Skip to content

Commit 7200857

Browse files
author
ShreyeshArangath
committed
plan caching
1 parent e1d0c81 commit 7200857

4 files changed

Lines changed: 88 additions & 3 deletions

File tree

python/datafusion/plan.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from __future__ import annotations
2121

22+
import datetime
23+
2224
from typing import TYPE_CHECKING, Any
2325

2426
import datafusion._internal as df_internal
@@ -290,6 +292,11 @@ def value(self) -> int | None:
290292
"""
291293
return self._raw.value
292294

295+
@property
296+
def value_as_datetime(self) -> datetime.datetime | None:
297+
"""The value as a UTC datetime for timestamp metrics, or ``None``."""
298+
return self._raw.value_as_datetime()
299+
293300
@property
294301
def partition(self) -> int | None:
295302
"""The 0-based partition index this metric applies to.

python/tests/test_plans.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import datetime
19+
20+
import pytest
21+
1822
from datafusion import (
1923
ExecutionPlan,
2024
LogicalPlan,
@@ -150,3 +154,38 @@ def test_execute_stream_partitioned_metrics() -> None:
150154
if ms.output_rows is not None
151155
]
152156
assert 2 in output_rows_values
157+
158+
159+
def test_value_as_datetime() -> None:
160+
ctx = SessionContext()
161+
ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
162+
df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
163+
df.collect()
164+
plan = df.execution_plan()
165+
166+
for _, ms in plan.collect_metrics():
167+
for metric in ms.metrics():
168+
if metric.name in ("start_timestamp", "end_timestamp"):
169+
dt = metric.value_as_datetime
170+
assert dt is None or isinstance(dt, datetime.datetime)
171+
if dt is not None:
172+
assert dt.tzinfo is not None
173+
else:
174+
assert metric.value_as_datetime is None
175+
176+
177+
def test_collect_twice_reuses_plan() -> None:
178+
ctx = SessionContext()
179+
ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")
180+
df = ctx.sql("SELECT * FROM t WHERE column1 > 1")
181+
182+
df.collect()
183+
df.collect()
184+
185+
plan = df.execution_plan()
186+
output_rows_values = [
187+
ms.output_rows
188+
for _, ms in plan.collect_metrics()
189+
if ms.output_rows is not None
190+
]
191+
assert len(output_rows_values) > 0

src/dataframe.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,19 @@ impl PyDataFrame {
387387
&self,
388388
py: Python,
389389
) -> PyDataFusionResult<(Arc<dyn DFExecutionPlan>, Arc<TaskContext>)> {
390-
let df = self.df.as_ref().clone();
391-
let plan = wait_for_future(py, df.create_physical_plan())??;
392-
*self.last_plan.lock() = Some(Arc::clone(&plan));
390+
let plan = {
391+
let cached = self.last_plan.lock();
392+
cached.as_ref().map(Arc::clone)
393+
};
394+
let plan = match plan {
395+
Some(p) => p,
396+
None => {
397+
let df = self.df.as_ref().clone();
398+
let new_plan = wait_for_future(py, df.create_physical_plan())??;
399+
*self.last_plan.lock() = Some(Arc::clone(&new_plan));
400+
new_plan
401+
}
402+
};
393403
let task_ctx = Arc::new(self.df.as_ref().task_ctx());
394404
Ok((plan, task_ctx))
395405
}

src/metrics.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,35 @@ impl PyMetric {
122122
}
123123
}
124124

125+
/// Returns the value as a Python `datetime` for `StartTimestamp` / `EndTimestamp`
126+
/// metrics, or `None` for all other metric types.
127+
fn value_as_datetime<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
128+
match self.metric.value() {
129+
MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) => {
130+
match ts.value() {
131+
Some(dt) => {
132+
let nanos = dt.timestamp_nanos_opt()
133+
.ok_or_else(|| PyErr::new::<pyo3::exceptions::PyOverflowError, _>(
134+
"timestamp out of range"
135+
))?;
136+
let datetime_mod = py.import("datetime")?;
137+
let datetime_cls = datetime_mod.getattr("datetime")?;
138+
let tz_utc = datetime_mod.getattr("timezone")?.getattr("utc")?;
139+
let secs = nanos / 1_000_000_000;
140+
let micros = (nanos % 1_000_000_000) / 1_000;
141+
let result = datetime_cls.call_method1(
142+
"fromtimestamp",
143+
(secs as f64 + micros as f64 / 1_000_000.0, tz_utc),
144+
)?;
145+
Ok(Some(result))
146+
}
147+
None => Ok(None),
148+
}
149+
}
150+
_ => Ok(None),
151+
}
152+
}
153+
125154
#[getter]
126155
fn partition(&self) -> Option<usize> {
127156
self.metric.partition()

0 commit comments

Comments
 (0)