Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
use datafusion_python_util::{
create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx,
get_tokio_runtime, spawn_future, wait_for_future,
get_tokio_runtime, set_global_ctx, spawn_future, wait_for_future,
};
use object_store::ObjectStore;
use pyo3::IntoPyObjectExt;
Expand Down Expand Up @@ -407,11 +407,22 @@ impl PySessionContext {
#[staticmethod]
#[pyo3(signature = ())]
pub fn global_ctx() -> PyResult<Self> {
    // `get_global_ctx` already hands back an owned `Arc` snapshot of the
    // process-wide context, so no additional clone is needed here.
    let ctx = get_global_ctx();
    // Build the logical codec from the snapshot before moving it into `Self`.
    let logical_codec = Self::default_logical_codec(&ctx);
    Ok(Self { ctx, logical_codec })
}

/// Install this context as the process-wide global `SessionContext`.
///
/// After this call, `SessionContext.global_ctx()` (as well as Rust helpers
/// that fall back to the global context, such as the module-level
/// `read_parquet` / `read_csv` / etc. helpers) observe this context.
/// References obtained from `global_ctx()` before the call keep the context
/// they already hold.
pub fn set_as_global(&self) {
    let replacement = self.ctx.clone();
    set_global_ctx(replacement);
}

/// Register an object store with the given name
#[pyo3(signature = (scheme, store, host=None))]
pub fn register_object_store(
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ impl PyDataFrame {
Some(f) => f
.parse::<datafusion::common::format::ExplainFormat>()
.map_err(|e| {
PyDataFusionError::Common(format!("Invalid explain format '{}': {}", f, e))
PyDataFusionError::Common(format!("Invalid explain format '{f}': {e}"))
})?,
None => datafusion::common::format::ExplainFormat::Indent,
};
Expand Down
26 changes: 26 additions & 0 deletions crates/core/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@ use datafusion::logical_expr::{
Between, BinaryExpr, Case, Cast, Expr, ExprFuncBuilder, ExprFunctionExt, Like, LogicalPlan,
Operator, TryCast, WindowFunctionDefinition, col, lit, lit_with_metadata,
};
use datafusion_proto::bytes::Serializeable;
use datafusion_python_util::get_global_ctx;
use pyo3::IntoPyObjectExt;
use pyo3::basic::CompareOp;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use window::PyWindowFrame;

use self::alias::PyAlias;
Expand Down Expand Up @@ -256,6 +259,29 @@ impl PyExpr {
Ok(format!("Expr({})", self.expr))
}

/// Encode the wrapped expression with the `datafusion-proto` wire format
/// and return the result as a Python `bytes` object.
///
/// The Python `Expr` wrapper builds `__getstate__` / `__setstate__` on top
/// of this; it is also exposed directly so callers can persist or transmit
/// expressions without going through `pickle`.
fn to_bytes<'py>(&self, py: Python<'py>) -> PyDataFusionResult<Bound<'py, PyBytes>> {
    let encoded = self.expr.to_bytes()?;
    let py_bytes = PyBytes::new(py, &encoded);
    Ok(py_bytes)
}

/// Rebuild a `RawExpr` from bytes produced by [`PyExpr::to_bytes`].
///
/// Function references (built-ins, UDFs, UDAFs, UDWFs) are looked up by
/// name in the process-wide global `SessionContext`. Built-in functions
/// are registered on every fresh context, so they always roundtrip. To
/// roundtrip user-defined functions, register them on a context and call
/// `SessionContext.set_as_global()` before unpickling.
#[staticmethod]
fn from_bytes(bytes: &[u8]) -> PyDataFusionResult<PyExpr> {
    // The global context acts as the function registry for name resolution.
    let registry = get_global_ctx();
    let decoded = Expr::from_bytes_with_registry(bytes, registry.as_ref())?;
    let restored: PyExpr = decoded.into();
    Ok(restored)
}

fn __add__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
Ok((self.expr.clone() + rhs.expr).into())
}
Expand Down
65 changes: 60 additions & 5 deletions crates/util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use std::future::Future;
use std::ptr::NonNull;
use std::sync::{Arc, OnceLock};
use std::sync::{Arc, OnceLock, RwLock};
use std::time::Duration;

use datafusion::datasource::TableProvider;
Expand Down Expand Up @@ -59,11 +59,29 @@ pub fn is_ipython_env(py: Python) -> &'static bool {
})
}

/// Returns the process-wide slot that stores the global DataFusion
/// `SessionContext`.
///
/// The slot is created on first use and starts out holding a freshly
/// constructed `SessionContext`; the `RwLock` allows the stored `Arc` to be
/// swapped later.
fn global_ctx_slot() -> &'static RwLock<Arc<SessionContext>> {
    static SLOT: OnceLock<RwLock<Arc<SessionContext>>> = OnceLock::new();
    SLOT.get_or_init(|| {
        let default_ctx = Arc::new(SessionContext::new());
        RwLock::new(default_ctx)
    })
}

/// Utility to get the Global DataFusion CTX.
///
/// Returns an owned `Arc<SessionContext>` snapshot. The underlying slot can be
/// replaced via [`set_global_ctx`]; existing snapshots are unaffected.
#[inline]
pub fn get_global_ctx() -> &'static Arc<SessionContext> {
static CTX: OnceLock<Arc<SessionContext>> = OnceLock::new();
CTX.get_or_init(|| Arc::new(SessionContext::new()))
pub fn get_global_ctx() -> Arc<SessionContext> {
global_ctx_slot()
.read()
.expect("global SessionContext lock poisoned")
.clone()
}

/// Swap the Global DataFusion CTX for `ctx`.
///
/// Every later call to [`get_global_ctx`] returns the new context. `Arc`s
/// that were cloned out of the slot earlier keep pointing at the context they
/// were taken from.
///
/// # Panics
///
/// Panics if the lock guarding the global slot has been poisoned.
pub fn set_global_ctx(ctx: Arc<SessionContext>) {
    let mut slot = global_ctx_slot()
        .write()
        .expect("global SessionContext lock poisoned");
    *slot = ctx;
}

/// Utility to collect rust futures with GIL released and respond to
Expand Down Expand Up @@ -224,3 +242,40 @@ pub fn ffi_logical_codec_from_pycapsule(obj: Bound<PyAny>) -> PyResult<FFI_Logic

Ok(codec.clone())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Puts the captured context back into the global slot when dropped.
    ///
    /// Restoring from `Drop` means the process-wide slot is reset even when an
    /// assertion in the test panics, so a failing run does not leave a
    /// replaced global behind for other tests in the same binary.
    struct RestoreGlobalCtx(Arc<SessionContext>);

    impl Drop for RestoreGlobalCtx {
        fn drop(&mut self) {
            set_global_ctx(Arc::clone(&self.0));
        }
    }

    /// The global slot must round-trip a custom `SessionContext`. Since the
    /// global is process-wide, this test only asserts identity through a
    /// single set/get cycle and restores the prior value when it finishes
    /// (on success or failure) so the test is independent of ordering with
    /// other tests in the binary.
    #[test]
    fn set_global_ctx_replaces_default() {
        let _restore = RestoreGlobalCtx(get_global_ctx());
        let custom = Arc::new(SessionContext::new());
        let custom_ptr = Arc::as_ptr(&custom);

        set_global_ctx(Arc::clone(&custom));
        let observed = get_global_ctx();
        assert_eq!(
            Arc::as_ptr(&observed),
            custom_ptr,
            "get_global_ctx should return the context installed by set_global_ctx",
        );

        // A snapshot taken before the swap should be unaffected after another
        // set_global_ctx call, because get_global_ctx clones the Arc.
        let snapshot = get_global_ctx();
        let replacement = Arc::new(SessionContext::new());
        set_global_ctx(replacement);
        assert_eq!(
            Arc::as_ptr(&snapshot),
            custom_ptr,
            "previously cloned snapshots must not be invalidated by set_global_ctx",
        );
    }
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ ignore-words-list = ["IST", "ans"]
dev = [
"arro3-core==0.6.5",
"codespell==2.4.1",
"dill>=0.3.8",
"maturin>=1.8.1",
"nanoarrow==0.8.0",
"numpy>1.25.0;python_version<'3.14'",
Expand Down
16 changes: 16 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,22 @@ def global_ctx(cls) -> SessionContext:
wrapper.ctx = internal_ctx
return wrapper

def set_as_global(self) -> None:
    """Make this context the process-wide global ``SessionContext``.

    Once installed, :meth:`SessionContext.global_ctx` (and the module-level
    helpers in :mod:`datafusion.io` that fall back to the global context)
    hand back this context. References previously obtained from
    ``global_ctx()`` are not invalidated.

    Example::

        ctx = SessionContext()
        ctx.register_udf(my_udf)
        ctx.set_as_global()
    """
    raw_ctx = self.ctx
    raw_ctx.set_as_global()

def enable_url_table(self) -> SessionContext:
"""Control if local files can be queried as tables.

Expand Down
25 changes: 25 additions & 0 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,31 @@ def __init__(self, expr: expr_internal.RawExpr) -> None:
"""This constructor should not be called by the end user."""
self.expr = expr

def to_bytes(self) -> bytes:
    """Encode this expression as bytes using the ``datafusion-proto`` wire format.

    Function references (built-ins and UDFs/UDAFs/UDWFs) are stored by name.
    When the bytes are loaded again with :py:meth:`from_bytes`, those names
    are resolved against the process-wide global
    :py:class:`SessionContext`. Built-in functions always roundtrip; for
    user-defined functions, register them on a context and call
    :py:meth:`SessionContext.set_as_global` before loading.
    """
    raw_expr = self.expr
    return raw_expr.to_bytes()

@classmethod
def from_bytes(cls, data: bytes) -> Expr:
    """Rebuild an expression from :py:meth:`to_bytes` output.

    See :py:meth:`to_bytes` for the caveats on how function names are resolved.
    """
    raw_expr = expr_internal.RawExpr.from_bytes(data)
    return cls(raw_expr)

def __getstate__(self) -> bytes:
"""Serialize for ``pickle`` / ``dill``. Delegates to :py:meth:`to_bytes`."""
return self.to_bytes()

def __setstate__(self, state: bytes) -> None:
    """Restore the wrapped expression from bytes produced by :py:meth:`__getstate__`."""
    raw_expr = expr_internal.RawExpr.from_bytes(state)
    self.expr = raw_expr

def to_variant(self) -> Any:
"""Convert this expression into a python object if possible."""
return self.expr.to_variant()
Expand Down
167 changes: 167 additions & 0 deletions python/tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Pickle / dill roundtrip tests for :py:class:`datafusion.Expr`.

The wire format is ``datafusion-proto``'s ``LogicalExprNode``. Function
references are encoded by name, so unpickling resolves them against the
process-wide global :py:class:`SessionContext`. Tests that need a
non-built-in function temporarily install a custom global context and
restore the previous one.
"""

import pickle
from contextlib import contextmanager

import dill
import pyarrow as pa
import pytest
from datafusion import SessionContext, col, lit, udf
from datafusion import functions as f
from datafusion.expr import Expr


@pytest.fixture
def ctx():
    """Provide a fresh ``SessionContext`` for each test."""
    session = SessionContext()
    return session


@pytest.fixture
def df(ctx):
    """Provide a DataFrame named ``t`` with columns ``a`` and ``b``."""
    columns = [pa.array([1, 2, 3]), pa.array([4, 5, None])]
    batch = pa.RecordBatch.from_arrays(columns, names=["a", "b"])
    return ctx.create_dataframe([[batch]], name="t")


@contextmanager
def temporary_global_ctx(new_ctx):
    """Swap in ``new_ctx`` as the process-wide global for the ``with`` body.

    The context that was global on entry is reinstalled on exit, including
    when the body raises.
    """
    saved_ctx = SessionContext.global_ctx()
    new_ctx.set_as_global()
    try:
        yield
    finally:
        saved_ctx.set_as_global()


@pytest.mark.parametrize("dumper", [pickle, dill], ids=["pickle", "dill"])
@pytest.mark.parametrize(
    "build_expr",
    [
        pytest.param(lambda: col("a"), id="column"),
        pytest.param(lambda: lit(42), id="literal_int"),
        pytest.param(lambda: lit("hello"), id="literal_str"),
        pytest.param(lambda: col("a") + lit(1), id="binary_add"),
        pytest.param(lambda: (col("a") * lit(2)) - col("b"), id="binary_nested"),
        pytest.param(lambda: col("a").alias("renamed"), id="alias"),
        pytest.param(lambda: col("a").cast(pa.float64()), id="cast"),
        pytest.param(lambda: col("a").is_null(), id="is_null"),
        pytest.param(lambda: col("a").between(lit(1), lit(10)), id="between"),
        pytest.param(lambda: ~(col("a") > lit(0)), id="not_gt"),
        pytest.param(lambda: f.sum(col("a")), id="agg_sum"),
        pytest.param(
            lambda: f.case(col("a")).when(lit(1), lit("one")).end(),
            id="case_when",
        ),
    ],
)
def test_builtin_roundtrip(build_expr, dumper):
    """Expressions built only from built-ins survive a pickle / dill roundtrip."""
    original = build_expr()
    payload = dumper.dumps(original)
    roundtripped = dumper.loads(payload)
    assert isinstance(roundtripped, Expr)
    # canonical_name() renders the whole expression, function names included,
    # so matching canonical names are used as the equivalence check here.
    assert roundtripped.canonical_name() == original.canonical_name()


@pytest.mark.parametrize("dumper", [pickle, dill], ids=["pickle", "dill"])
def test_pickled_expr_executes(df, dumper):
    """Evaluating a roundtripped expression gives the same column as the original."""
    source_expr = (col("a") + lit(10)).alias("a_plus_ten")
    payload = dumper.dumps(source_expr)
    roundtripped = dumper.loads(payload)

    expected_column = df.select(source_expr).collect()[0].column(0)
    actual_column = df.select(roundtripped).collect()[0].column(0)
    assert expected_column == actual_column
    assert expected_column == pa.array([11, 12, 13], type=pa.int64())


def test_udf_roundtrip_via_global_ctx():
    """A UDF registered on the active global context survives a pickle roundtrip.

    Follows the usage documented for ``SessionContext.set_as_global``.
    """
    null_check = udf(
        lambda x: x.is_null(),
        [pa.int64()],
        pa.bool_(),
        volatility="immutable",
        name="pickle_test_is_null",
    )

    registry_ctx = SessionContext()
    registry_ctx.register_udf(null_check)

    original = null_check(col("b"))

    with temporary_global_ctx(registry_ctx):
        payload = pickle.dumps(original)
        roundtripped = pickle.loads(payload)  # noqa: S301
        assert roundtripped.canonical_name() == original.canonical_name()

    # Run the restored expression as well, to check the UDF body is still
    # callable after the roundtrip.
    batch = pa.RecordBatch.from_arrays([pa.array([1, None, 3])], names=["b"])
    frame = registry_ctx.create_dataframe([[batch]], name="t_udf")
    evaluated = frame.select(roundtripped.alias("nul")).collect()[0].column(0)
    assert evaluated == pa.array([False, True, False])


def test_udf_roundtrip_fails_without_registration():
    """Unpickling raises when the UDF is not registered on the global context.

    The failure must be an error, not a silent substitution of a different
    implementation.
    """
    unregistered = udf(
        lambda x: x.is_null(),
        [pa.int64()],
        pa.bool_(),
        volatility="immutable",
        name="pickle_test_unknown_udf",
    )
    original = unregistered(col("b"))

    payload = pickle.dumps(original)
    # This UDF is never registered on the default global ctx, so rebuilding
    # the expression has to raise instead of substituting a placeholder.
    # DataFusion reports this as a generic Python ``Exception`` whose message
    # names the missing function, hence the match on the function name.
    with pytest.raises(Exception, match="pickle_test_unknown_udf"):
        pickle.loads(payload)  # noqa: S301


def test_getstate_returns_bytes():
    """``__getstate__`` can be called directly and yields non-empty raw bytes.

    This is the path for callers that want to persist or transmit expressions
    without pickle; ``__setstate__`` on a bare instance must accept the result.
    """
    original = col("a") + lit(1)
    payload = original.__getstate__()
    assert isinstance(payload, bytes)
    assert len(payload) > 0

    bare = Expr.__new__(Expr)
    bare.__setstate__(payload)
    assert bare.canonical_name() == original.canonical_name()
Loading
Loading