Skip to content

Commit 8b4d579

Browse files
committed
Add to_batches() and interpolate() methods to DataFrame
- Add to_batches() as alias for collect() returning RecordBatch list - Add interpolate() method with forward_fill support - Add deprecation warning to collect() method - Add comprehensive tests for both methods Addresses items from RFC apache#875
1 parent c609dfa commit 8b4d579

2 files changed

Lines changed: 83 additions & 1 deletion

File tree

python/datafusion/dataframe.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, sort_or_default
43+
from datafusion.expr import Expr, SortExpr, sort_or_default, Window
4444
from datafusion.plan import ExecutionPlan, LogicalPlan
4545
from datafusion.record_batch import RecordBatchStream
46+
from datafusion.functions import col, nvl, last_value
47+
from datafusion.common import NullTreatment
4648

4749
if TYPE_CHECKING:
4850
import pathlib
@@ -360,6 +362,9 @@ def describe(self) -> DataFrame:
360362
"""
361363
return DataFrame(self.df.describe())
362364

365+
@deprecated(
366+
"schema() is deprecated. Use :py:meth:`~DataFrame.get_schema` instead"
367+
)
363368
def schema(self) -> pa.Schema:
364369
"""Return the :py:class:`pyarrow.Schema` of this DataFrame.
365370
@@ -370,6 +375,39 @@ def schema(self) -> pa.Schema:
370375
Describing schema of the DataFrame
371376
"""
372377
return self.df.schema()
378+
379+
def to_batches(self) -> list[pa.RecordBatch]:
380+
"""Convert DataFrame to list of RecordBatches."""
381+
return self.collect() # delegate to existing method
382+
383+
def interpolate(self, method: str = "forward_fill", **kwargs) -> DataFrame:
384+
"""Interpolate missing values per column.
385+
386+
Args:
387+
method: Interpolation method ('linear', 'forward_fill', 'backward_fill')
388+
389+
Returns:
390+
DataFrame with interpolated values
391+
392+
Raises:
393+
NotImplementedError: Linear interpolation not yet supported
394+
"""
395+
if method == "forward_fill":
396+
exprs = []
397+
for field in self.schema():
398+
window = Window(order_by=col(field.name))
399+
expr = nvl(col(field.name),last_value(col(field.name)).over(window)).alias(field.name)
400+
exprs.append(expr)
401+
return self.select(*exprs)
402+
403+
elif method == "backward_fill":
404+
raise NotImplementedError("backward_fill not yet implemented")
405+
406+
elif method == "linear":
407+
raise NotImplementedError("Linear interpolation requires complex window function logic")
408+
409+
else:
410+
raise ValueError(f"Unknown interpolation method: {method}")
373411

374412
@deprecated(
375413
"select_columns() is deprecated. Use :py:meth:`~DataFrame.select` instead"
@@ -592,6 +630,9 @@ def tail(self, n: int = 5) -> DataFrame:
592630
"""
593631
return DataFrame(self.df.limit(n, max(0, self.count() - n)))
594632

633+
@deprecated(
634+
"collect() returning RecordBatch list is deprecated. Use to_batches() for RecordBatch list or collect() will return DataFrame in future versions"
635+
)
595636
def collect(self) -> list[pa.RecordBatch]:
596637
"""Execute this :py:class:`DataFrame` and collect results into memory.
597638

python/tests/test_dataframe.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,47 @@ def get_header_style(self) -> str:
185185
"padding: 10px; border: 1px solid #3367d6;"
186186
)
187187

188+
def test_to_batches(df):
189+
"""Test to_batches method returns list of RecordBatches."""
190+
batches = df.to_batches()
191+
assert isinstance(batches, list)
192+
assert len(batches) > 0
193+
assert all(isinstance(batch, pa.RecordBatch) for batch in batches)
194+
195+
196+
collect_batches = df.collect()
197+
assert len(batches) == len(collect_batches)
198+
for i, batch in enumerate(batches):
199+
assert batch.equals(collect_batches[i])
200+
201+
202+
def test_interpolate_forward_fill(ctx):
203+
"""Test interpolate method with forward_fill."""
204+
205+
batch = pa.RecordBatch.from_arrays(
206+
[pa.array([1, None, 3, None]), pa.array([4.0, None, 6.0, None])],
207+
names=["int_col", "float_col"],
208+
)
209+
df = ctx.create_dataframe([[batch]])
210+
211+
result = df.interpolate("forward_fill")
212+
213+
assert isinstance(result, DataFrame)
214+
215+
216+
def test_interpolate_unsupported_method(ctx):
217+
"""Test interpolate with unsupported method raises error."""
218+
batch = pa.RecordBatch.from_arrays(
219+
[pa.array([1, 2, 3])], names=["a"]
220+
)
221+
df = ctx.create_dataframe([[batch]])
222+
223+
with pytest.raises(NotImplementedError, match="requires complex window"):
224+
df.interpolate("linear")
225+
226+
with pytest.raises(ValueError, match="Unknown interpolation method"):
227+
df.interpolate("unknown")
228+
188229

189230
def count_table_rows(html_content: str) -> int:
190231
"""Count the number of table rows in HTML content.

0 commit comments

Comments
 (0)