Skip to content

Commit f6bed11

Browse files
timsaucer and claude committed
fix: use explicit None checks, widen numeric type hints, and add tests
Replace 7 fragile truthiness checks (`x.expr if x else None`) with explicit `is not None` checks to prevent silent `None` when zero-valued literals are passed.

Widen `log`/`power`/`pow` type hints to `Expr | int | float`, with `# noqa: PYI041` for clarity.

Add unit tests for the `coerce_to_expr` helpers and integration tests for pythonic calling conventions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 05f9ab9 commit f6bed11

3 files changed

Lines changed: 161 additions & 10 deletions

File tree

python/datafusion/functions.py

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1190,7 +1190,7 @@ def ln(arg: Expr) -> Expr:
11901190
return Expr(f.ln(arg.expr))
11911191

11921192

1193-
def log(base: Expr | float, num: Expr) -> Expr:
1193+
def log(base: Expr | int | float, num: Expr) -> Expr: # noqa: PYI041
11941194
"""Returns the logarithm of a number for a particular ``base``.
11951195
11961196
Examples:
@@ -1416,7 +1416,7 @@ def position(string: Expr, substring: Expr | str) -> Expr:
14161416
return strpos(string, substring)
14171417

14181418

1419-
def power(base: Expr, exponent: Expr | float) -> Expr:
1419+
def power(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041
14201420
"""Returns ``base`` raised to the power of ``exponent``.
14211421
14221422
Examples:
@@ -1432,7 +1432,7 @@ def power(base: Expr, exponent: Expr | float) -> Expr:
14321432
return Expr(f.power(base.expr, exponent.expr))
14331433

14341434

1435-
def pow(base: Expr, exponent: Expr | float) -> Expr:
1435+
def pow(base: Expr, exponent: Expr | int | float) -> Expr: # noqa: PYI041
14361436
"""Returns ``base`` raised to the power of ``exponent``.
14371437
14381438
See Also:
@@ -1486,7 +1486,11 @@ def regexp_like(
14861486
"""
14871487
regex = coerce_to_expr(regex)
14881488
flags = coerce_to_expr_or_none(flags)
1489-
return Expr(f.regexp_like(string.expr, regex.expr, flags.expr if flags else None))
1489+
return Expr(
1490+
f.regexp_like(
1491+
string.expr, regex.expr, flags.expr if flags is not None else None
1492+
)
1493+
)
14901494

14911495

14921496
def regexp_match(
@@ -1518,7 +1522,11 @@ def regexp_match(
15181522
"""
15191523
regex = coerce_to_expr(regex)
15201524
flags = coerce_to_expr_or_none(flags)
1521-
return Expr(f.regexp_match(string.expr, regex.expr, flags.expr if flags else None))
1525+
return Expr(
1526+
f.regexp_match(
1527+
string.expr, regex.expr, flags.expr if flags is not None else None
1528+
)
1529+
)
15221530

15231531

15241532
def regexp_replace(
@@ -1565,7 +1573,7 @@ def regexp_replace(
15651573
string.expr,
15661574
pattern.expr,
15671575
replacement.expr,
1568-
flags.expr if flags else None,
1576+
flags.expr if flags is not None else None,
15691577
)
15701578
)
15711579

@@ -1606,8 +1614,8 @@ def regexp_count(
16061614
f.regexp_count(
16071615
string.expr,
16081616
pattern.expr,
1609-
start.expr if start else None,
1610-
flags.expr if flags else None,
1617+
start.expr if start is not None else None,
1618+
flags.expr if flags is not None else None,
16111619
)
16121620
)
16131621

@@ -3587,7 +3595,12 @@ def array_slice(
35873595
end = coerce_to_expr(end)
35883596
stride = coerce_to_expr_or_none(stride)
35893597
return Expr(
3590-
f.array_slice(array.expr, begin.expr, end.expr, stride.expr if stride else None)
3598+
f.array_slice(
3599+
array.expr,
3600+
begin.expr,
3601+
end.expr,
3602+
stride.expr if stride is not None else None,
3603+
)
35913604
)
35923605

35933606

@@ -3886,7 +3899,7 @@ def string_to_array(
38863899
f.string_to_array(
38873900
string.expr,
38883901
delimiter.expr,
3889-
null_string.expr if null_string else None,
3902+
null_string.expr if null_string is not None else None,
38903903
)
38913904
)
38923905

python/tests/test_expr.py

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,8 @@
5353
TransactionEnd,
5454
TransactionStart,
5555
Values,
56+
coerce_to_expr,
57+
coerce_to_expr_or_none,
5658
ensure_expr,
5759
ensure_expr_list,
5860
)
@@ -1026,6 +1028,49 @@ def test_ensure_expr_list_bytearray():
10261028
ensure_expr_list(bytearray(b"a"))
10271029

10281030

1031+
def test_coerce_to_expr_passes_expr_through():
1032+
e = col("a")
1033+
result = coerce_to_expr(e)
1034+
assert isinstance(result, type(e))
1035+
assert str(result) == str(e)
1036+
1037+
1038+
def test_coerce_to_expr_wraps_int():
1039+
result = coerce_to_expr(42)
1040+
assert isinstance(result, type(lit(42)))
1041+
1042+
1043+
def test_coerce_to_expr_wraps_str():
1044+
result = coerce_to_expr("hello")
1045+
assert isinstance(result, type(lit("hello")))
1046+
1047+
1048+
def test_coerce_to_expr_wraps_float():
1049+
result = coerce_to_expr(3.14)
1050+
assert isinstance(result, type(lit(3.14)))
1051+
1052+
1053+
def test_coerce_to_expr_wraps_bool():
1054+
result = coerce_to_expr(True) # noqa: FBT003
1055+
assert isinstance(result, type(lit(True)))
1056+
1057+
1058+
def test_coerce_to_expr_or_none_returns_none():
1059+
assert coerce_to_expr_or_none(None) is None
1060+
1061+
1062+
def test_coerce_to_expr_or_none_wraps_value():
1063+
result = coerce_to_expr_or_none(42)
1064+
assert isinstance(result, type(lit(42)))
1065+
1066+
1067+
def test_coerce_to_expr_or_none_passes_expr_through():
1068+
e = col("a")
1069+
result = coerce_to_expr_or_none(e)
1070+
assert isinstance(result, type(e))
1071+
assert str(result) == str(e)
1072+
1073+
10291074
@pytest.mark.parametrize(
10301075
"value",
10311076
[

python/tests/test_functions.py

Lines changed: 93 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2099,3 +2099,96 @@ def test_gen_series_with_step():
20992099
f.gen_series(literal(1), literal(10), literal(3)).alias("v")
21002100
).collect()
21012101
assert result[0].column(0)[0].as_py() == [1, 4, 7, 10]
2102+
2103+
2104+
class TestPythonicNativeTypes:
2105+
"""Tests for accepting native Python types instead of requiring lit()."""
2106+
2107+
def test_split_part_native(self):
2108+
ctx = SessionContext()
2109+
df = ctx.from_pydict({"a": ["a,b,c"]})
2110+
result = df.select(f.split_part(column("a"), ",", 2).alias("s")).collect()
2111+
assert result[0].column(0)[0].as_py() == "b"
2112+
2113+
def test_encode_native_str(self):
2114+
ctx = SessionContext()
2115+
df = ctx.from_pydict({"a": ["hello"]})
2116+
result = df.select(f.encode(column("a"), "base64").alias("e")).collect()
2117+
assert result[0].column(0)[0].as_py() == "aGVsbG8"
2118+
2119+
def test_date_part_native_str(self):
2120+
ctx = SessionContext()
2121+
df = ctx.from_pydict({"a": ["2021-07-15T00:00:00"]})
2122+
df = df.select(f.to_timestamp(column("a")).alias("a"))
2123+
result = df.select(f.date_part("year", column("a")).alias("y")).collect()
2124+
assert result[0].column(0)[0].as_py() == 2021
2125+
2126+
def test_date_trunc_native_str(self):
2127+
ctx = SessionContext()
2128+
df = ctx.from_pydict({"a": ["2021-07-15T12:34:56"]})
2129+
df = df.select(f.to_timestamp(column("a")).alias("a"))
2130+
result = df.select(f.date_trunc("month", column("a")).alias("t")).collect()
2131+
assert str(result[0].column(0)[0].as_py()) == "2021-07-01 00:00:00"
2132+
2133+
def test_left_native_int(self):
2134+
ctx = SessionContext()
2135+
df = ctx.from_pydict({"a": ["the cat"]})
2136+
result = df.select(f.left(column("a"), 3).alias("l")).collect()
2137+
assert result[0].column(0)[0].as_py() == "the"
2138+
2139+
def test_round_native_int(self):
2140+
ctx = SessionContext()
2141+
df = ctx.from_pydict({"a": [1.567]})
2142+
result = df.select(f.round(column("a"), 2).alias("r")).collect()
2143+
assert result[0].column(0)[0].as_py() == 1.57
2144+
2145+
def test_regexp_count_native(self):
2146+
ctx = SessionContext()
2147+
df = ctx.from_pydict({"a": ["abcabc"]})
2148+
result = df.select(
2149+
f.regexp_count(column("a"), "abc", start=4, flags="i").alias("c")
2150+
).collect()
2151+
assert result[0].column(0)[0].as_py() == 1
2152+
2153+
def test_log_native_int(self):
2154+
ctx = SessionContext()
2155+
df = ctx.from_pydict({"a": [100.0]})
2156+
result = df.select(f.log(10, column("a")).alias("l")).collect()
2157+
assert result[0].column(0)[0].as_py() == 2.0
2158+
2159+
def test_power_native_int(self):
2160+
ctx = SessionContext()
2161+
df = ctx.from_pydict({"a": [2.0]})
2162+
result = df.select(f.power(column("a"), 3).alias("p")).collect()
2163+
assert result[0].column(0)[0].as_py() == 8.0
2164+
2165+
def test_array_slice_native(self):
2166+
ctx = SessionContext()
2167+
df = ctx.from_pydict({"a": [[1, 2, 3, 4]]})
2168+
result = df.select(f.array_slice(column("a"), 2, 3).alias("s")).collect()
2169+
assert result[0].column(0)[0].as_py() == [2, 3]
2170+
2171+
def test_string_to_array_native(self):
2172+
ctx = SessionContext()
2173+
df = ctx.from_pydict({"a": ["hello,NA,world"]})
2174+
result = df.select(
2175+
f.string_to_array(column("a"), ",", null_string="NA").alias("v")
2176+
).collect()
2177+
assert result[0].column(0)[0].as_py() == ["hello", None, "world"]
2178+
2179+
def test_regexp_replace_native(self):
2180+
ctx = SessionContext()
2181+
df = ctx.from_pydict({"a": ["a1 b2 c3"]})
2182+
result = df.select(
2183+
f.regexp_replace(column("a"), r"\d+", "X", flags="g").alias("r")
2184+
).collect()
2185+
assert result[0].column(0)[0].as_py() == "aX bX cX"
2186+
2187+
def test_backward_compat_with_lit(self):
2188+
"""Verify that existing code using lit() still works."""
2189+
ctx = SessionContext()
2190+
df = ctx.from_pydict({"a": ["a,b,c"]})
2191+
result = df.select(
2192+
f.split_part(column("a"), literal(","), literal(2)).alias("s")
2193+
).collect()
2194+
assert result[0].column(0)[0].as_py() == "b"

0 commit comments

Comments
 (0)