Skip to content

Commit 5796b53

Browse files
ntjohnson1 and claude committed
feat: pickle/dill support for Expr
Add `to_bytes` / `from_bytes` on `Expr` (Python wrapper) and the underlying `RawExpr` (Rust). Serialization uses `datafusion-proto`'s `Serializeable` trait, encoding function references by name. The Python wrapper implements `__getstate__` / `__setstate__` on top, so `pickle.dumps` / `dill.dumps` work out of the box.

Reconstruction resolves function names against the process-wide global `SessionContext` (introduced as settable in the previous commit). Built-in functions always roundtrip; user-defined functions roundtrip when registered on a context that has been installed via `SessionContext.set_as_global()`.

Adds `dill` to the dev dependency group and parametrized tests covering both serializers across columns, literals, binary ops, casts, between, aggregates, case/when, and a UDF with the global-ctx pattern.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 23543fc commit 5796b53

5 files changed

Lines changed: 230 additions & 0 deletions

File tree

crates/core/src/expr.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,12 @@ use datafusion::logical_expr::{
3131
Between, BinaryExpr, Case, Cast, Expr, ExprFuncBuilder, ExprFunctionExt, Like, LogicalPlan,
3232
Operator, TryCast, WindowFunctionDefinition, col, lit, lit_with_metadata,
3333
};
34+
use datafusion_proto::bytes::Serializeable;
35+
use datafusion_python_util::get_global_ctx;
3436
use pyo3::IntoPyObjectExt;
3537
use pyo3::basic::CompareOp;
3638
use pyo3::prelude::*;
39+
use pyo3::types::PyBytes;
3740
use window::PyWindowFrame;
3841

3942
use self::alias::PyAlias;
@@ -256,6 +259,29 @@ impl PyExpr {
256259
Ok(format!("Expr({})", self.expr))
257260
}
258261

262+
/// Serialize the underlying expression to bytes via the `datafusion-proto`
263+
/// wire format. Used by the Python `Expr` wrapper to implement
264+
/// `__getstate__` / `__setstate__`; also exposed directly so callers can
265+
/// persist or transmit expressions without going through `pickle`.
266+
fn to_bytes<'py>(&self, py: Python<'py>) -> PyDataFusionResult<Bound<'py, PyBytes>> {
267+
let bytes = self.expr.to_bytes()?;
268+
Ok(PyBytes::new(py, &bytes))
269+
}
270+
271+
/// Reconstruct a `RawExpr` from bytes produced by [`PyExpr::to_bytes`].
272+
///
273+
/// Function references (built-ins, UDFs, UDAFs, UDWFs) are resolved by
274+
/// name against the process-wide global `SessionContext`. Built-in
275+
/// functions are registered on every fresh context, so they always
276+
/// roundtrip. To roundtrip user-defined functions, register them on a
277+
/// context and call `SessionContext.set_as_global()` before unpickling.
278+
#[staticmethod]
279+
fn from_bytes(bytes: &[u8]) -> PyDataFusionResult<PyExpr> {
280+
let ctx = get_global_ctx();
281+
let expr = Expr::from_bytes_with_registry(bytes, ctx.as_ref())?;
282+
Ok(expr.into())
283+
}
284+
259285
fn __add__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
260286
Ok((self.expr.clone() + rhs.expr).into())
261287
}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ ignore-words-list = ["IST", "ans"]
188188
dev = [
189189
"arro3-core==0.6.5",
190190
"codespell==2.4.1",
191+
"dill>=0.3.8",
191192
"maturin>=1.8.1",
192193
"nanoarrow==0.8.0",
193194
"numpy>1.25.0;python_version<'3.14'",

python/datafusion/expr.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,31 @@ def __init__(self, expr: expr_internal.RawExpr) -> None:
410410
"""This constructor should not be called by the end user."""
411411
self.expr = expr
412412

413+
def to_bytes(self) -> bytes:
414+
"""Serialize this expression to bytes via the ``datafusion-proto`` wire format.
415+
416+
Function references (built-ins and UDFs/UDAFs/UDWFs) are encoded by
417+
name; on :py:meth:`from_bytes` the names are resolved against the
418+
process-wide global :py:class:`SessionContext`. Built-in functions
419+
always roundtrip; for user-defined functions, register them on a
420+
context and call :py:meth:`SessionContext.set_as_global` before
421+
loading.
422+
"""
423+
return self.expr.to_bytes()
424+
425+
@classmethod
426+
def from_bytes(cls, data: bytes) -> Expr:
427+
"""Inverse of :py:meth:`to_bytes`. See that method for caveats."""
428+
return cls(expr_internal.RawExpr.from_bytes(data))
429+
430+
def __getstate__(self) -> bytes:
431+
"""Serialize for ``pickle`` / ``dill``. Delegates to :py:meth:`to_bytes`."""
432+
return self.to_bytes()
433+
434+
def __setstate__(self, state: bytes) -> None:
435+
"""Inverse of :py:meth:`__getstate__`."""
436+
self.expr = expr_internal.RawExpr.from_bytes(state)
437+
413438
def to_variant(self) -> Any:
414439
"""Convert this expression into a python object if possible."""
415440
return self.expr.to_variant()

python/tests/test_pickle.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Pickle / dill roundtrip tests for :py:class:`datafusion.Expr`.
19+
20+
The wire format is ``datafusion-proto``'s ``LogicalExprNode``. Function
21+
references are encoded by name, so unpickling resolves them against the
22+
process-wide global :py:class:`SessionContext`. Tests that need a
23+
non-built-in function temporarily install a custom global context and
24+
restore the previous one.
25+
"""
26+
27+
import pickle
28+
from contextlib import contextmanager
29+
30+
import dill
31+
import pyarrow as pa
32+
import pytest
33+
from datafusion import SessionContext, col, lit, udf
34+
from datafusion import functions as f
35+
from datafusion.expr import Expr
36+
37+
38+
@pytest.fixture
39+
def ctx():
40+
return SessionContext()
41+
42+
43+
@pytest.fixture
44+
def df(ctx):
45+
batch = pa.RecordBatch.from_arrays(
46+
[pa.array([1, 2, 3]), pa.array([4, 5, None])],
47+
names=["a", "b"],
48+
)
49+
return ctx.create_dataframe([[batch]], name="t")
50+
51+
52+
@contextmanager
53+
def temporary_global_ctx(new_ctx):
54+
"""Install ``new_ctx`` as the process-wide global and restore on exit."""
55+
previous = SessionContext.global_ctx()
56+
new_ctx.set_as_global()
57+
try:
58+
yield
59+
finally:
60+
previous.set_as_global()
61+
62+
63+
@pytest.mark.parametrize("dumper", [pickle, dill], ids=["pickle", "dill"])
64+
@pytest.mark.parametrize(
65+
"build_expr",
66+
[
67+
pytest.param(lambda: col("a"), id="column"),
68+
pytest.param(lambda: lit(42), id="literal_int"),
69+
pytest.param(lambda: lit("hello"), id="literal_str"),
70+
pytest.param(lambda: col("a") + lit(1), id="binary_add"),
71+
pytest.param(lambda: (col("a") * lit(2)) - col("b"), id="binary_nested"),
72+
pytest.param(lambda: col("a").alias("renamed"), id="alias"),
73+
pytest.param(lambda: col("a").cast(pa.float64()), id="cast"),
74+
pytest.param(lambda: col("a").is_null(), id="is_null"),
75+
pytest.param(lambda: col("a").between(lit(1), lit(10)), id="between"),
76+
pytest.param(lambda: ~(col("a") > lit(0)), id="not_gt"),
77+
pytest.param(lambda: f.sum(col("a")), id="agg_sum"),
78+
pytest.param(
79+
lambda: f.case(col("a")).when(lit(1), lit("one")).end(),
80+
id="case_when",
81+
),
82+
],
83+
)
84+
def test_builtin_roundtrip(build_expr, dumper):
85+
"""Built-in expressions roundtrip via pickle and dill."""
86+
expr = build_expr()
87+
restored = dumper.loads(dumper.dumps(expr))
88+
assert isinstance(restored, Expr)
89+
# canonical_name() gives a full string form including function names,
90+
# so equal canonical names imply structural equivalence.
91+
assert restored.canonical_name() == expr.canonical_name()
92+
93+
94+
@pytest.mark.parametrize("dumper", [pickle, dill], ids=["pickle", "dill"])
95+
def test_pickled_expr_executes(df, dumper):
96+
"""A roundtripped expression evaluates to the same result as the original."""
97+
expr = (col("a") + lit(10)).alias("a_plus_ten")
98+
restored = dumper.loads(dumper.dumps(expr))
99+
100+
original = df.select(expr).collect()[0].column(0)
101+
after = df.select(restored).collect()[0].column(0)
102+
assert original == after
103+
assert original == pa.array([11, 12, 13], type=pa.int64())
104+
105+
106+
def test_udf_roundtrip_via_global_ctx():
107+
"""UDFs roundtrip when registered on the active global context.
108+
109+
Mirrors the documented usage of ``SessionContext.set_as_global``.
110+
"""
111+
is_null = udf(
112+
lambda x: x.is_null(),
113+
[pa.int64()],
114+
pa.bool_(),
115+
volatility="immutable",
116+
name="pickle_test_is_null",
117+
)
118+
119+
custom_ctx = SessionContext()
120+
custom_ctx.register_udf(is_null)
121+
122+
expr = is_null(col("b"))
123+
124+
with temporary_global_ctx(custom_ctx):
125+
data = pickle.dumps(expr)
126+
restored = pickle.loads(data) # noqa: S301
127+
assert restored.canonical_name() == expr.canonical_name()
128+
129+
# Also evaluate to confirm the UDF body is wired up post-roundtrip.
130+
batch = pa.RecordBatch.from_arrays([pa.array([1, None, 3])], names=["b"])
131+
df = custom_ctx.create_dataframe([[batch]], name="t_udf")
132+
result = df.select(restored.alias("nul")).collect()[0].column(0)
133+
assert result == pa.array([False, True, False])
134+
135+
136+
def test_udf_roundtrip_fails_without_registration():
137+
"""Without the UDF registered on the global context, unpickle errors out
138+
rather than silently substituting a different implementation."""
139+
is_null = udf(
140+
lambda x: x.is_null(),
141+
[pa.int64()],
142+
pa.bool_(),
143+
volatility="immutable",
144+
name="pickle_test_unknown_udf",
145+
)
146+
expr = is_null(col("b"))
147+
148+
data = pickle.dumps(expr)
149+
# The default global ctx does not have this UDF registered. Reconstruction
150+
# must raise rather than silently substitute a placeholder. DataFusion
151+
# surfaces this as a generic Python ``Exception`` whose message names the
152+
# missing function, so match on the function name.
153+
with pytest.raises(Exception, match="pickle_test_unknown_udf"):
154+
pickle.loads(data) # noqa: S301
155+
156+
157+
def test_getstate_returns_bytes():
158+
"""``__getstate__`` is exposed directly and returns raw bytes — useful for
159+
callers that want to persist or transmit expressions without pickle."""
160+
expr = col("a") + lit(1)
161+
state = expr.__getstate__()
162+
assert isinstance(state, bytes)
163+
assert len(state) > 0
164+
165+
rebuilt = Expr.__new__(Expr)
166+
rebuilt.__setstate__(state)
167+
assert rebuilt.canonical_name() == expr.canonical_name()

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)