|
18 | 18 | use std::collections::HashMap; |
19 | 19 |
|
20 | 20 | use datafusion::common::{Column, ScalarValue, TableReference}; |
21 | | -use datafusion::execution::FunctionRegistry; |
22 | | -use datafusion::functions_aggregate::all_default_aggregate_functions; |
23 | | -use datafusion::functions_window::all_default_window_functions; |
24 | | -use datafusion::logical_expr::expr::{ |
25 | | - Alias, FieldMetadata, NullTreatment as DFNullTreatment, WindowFunction, WindowFunctionParams, |
26 | | -}; |
27 | | -use datafusion::logical_expr::{Expr, ExprFunctionExt, WindowFrame, WindowFunctionDefinition, lit}; |
| 21 | +use datafusion::logical_expr::expr::{Alias, FieldMetadata, NullTreatment as DFNullTreatment}; |
| 22 | +use datafusion::logical_expr::{Expr, ExprFunctionExt, lit}; |
28 | 23 | use datafusion::{functions, functions_aggregate, functions_window}; |
29 | 24 | use pyo3::prelude::*; |
30 | 25 | use pyo3::wrap_pyfunction; |
31 | 26 |
|
32 | 27 | use crate::common::data_type::{NullTreatment, PyScalarValue}; |
33 | | -use crate::context::PySessionContext; |
34 | | -use crate::errors::{PyDataFusionError, PyDataFusionResult}; |
| 28 | +use crate::errors::PyDataFusionResult; |
35 | 29 | use crate::expr::PyExpr; |
36 | 30 | use crate::expr::conditional_expr::PyCaseBuilder; |
37 | 31 | use crate::expr::sort_expr::{PySortExpr, to_sort_expressions}; |
@@ -306,126 +300,6 @@ fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> { |
306 | 300 | Ok(PyCaseBuilder::new(None).when(when, then)) |
307 | 301 | } |
308 | 302 |
|
309 | | -/// Helper function to find the appropriate window function. |
310 | | -/// |
311 | | -/// Search procedure: |
312 | | -/// 1) Search built in window functions, which are being deprecated. |
313 | | -/// 1) If a session context is provided: |
314 | | -/// 1) search User Defined Aggregate Functions (UDAFs) |
315 | | -/// 1) search registered window functions |
316 | | -/// 1) search registered aggregate functions |
317 | | -/// 1) If no function has been found, search default aggregate functions. |
318 | | -/// |
319 | | -/// NOTE: we search the built-ins first because the `UDAF` versions currently do not have the same behavior. |
320 | | -fn find_window_fn( |
321 | | - name: &str, |
322 | | - ctx: Option<PySessionContext>, |
323 | | -) -> PyDataFusionResult<WindowFunctionDefinition> { |
324 | | - if let Some(ctx) = ctx { |
325 | | - // search UDAFs |
326 | | - let udaf = ctx |
327 | | - .ctx |
328 | | - .udaf(name) |
329 | | - .map(WindowFunctionDefinition::AggregateUDF) |
330 | | - .ok(); |
331 | | - |
332 | | - if let Some(udaf) = udaf { |
333 | | - return Ok(udaf); |
334 | | - } |
335 | | - |
336 | | - let session_state = ctx.ctx.state(); |
337 | | - |
338 | | - // search registered window functions |
339 | | - let window_fn = session_state |
340 | | - .window_functions() |
341 | | - .get(name) |
342 | | - .map(|f| WindowFunctionDefinition::WindowUDF(f.clone())); |
343 | | - |
344 | | - if let Some(window_fn) = window_fn { |
345 | | - return Ok(window_fn); |
346 | | - } |
347 | | - |
348 | | - // search registered aggregate functions |
349 | | - let agg_fn = session_state |
350 | | - .aggregate_functions() |
351 | | - .get(name) |
352 | | - .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone())); |
353 | | - |
354 | | - if let Some(agg_fn) = agg_fn { |
355 | | - return Ok(agg_fn); |
356 | | - } |
357 | | - } |
358 | | - |
359 | | - // search default aggregate functions |
360 | | - let agg_fn = all_default_aggregate_functions() |
361 | | - .iter() |
362 | | - .find(|v| v.name() == name || v.aliases().contains(&name.to_string())) |
363 | | - .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone())); |
364 | | - |
365 | | - if let Some(agg_fn) = agg_fn { |
366 | | - return Ok(agg_fn); |
367 | | - } |
368 | | - |
369 | | - // search default window functions |
370 | | - let window_fn = all_default_window_functions() |
371 | | - .iter() |
372 | | - .find(|v| v.name() == name || v.aliases().contains(&name.to_string())) |
373 | | - .map(|f| WindowFunctionDefinition::WindowUDF(f.clone())); |
374 | | - |
375 | | - if let Some(window_fn) = window_fn { |
376 | | - return Ok(window_fn); |
377 | | - } |
378 | | - |
379 | | - Err(PyDataFusionError::Common(format!( |
380 | | - "window function `{name}` not found" |
381 | | - ))) |
382 | | -} |
383 | | - |
384 | | -/// Creates a new Window function expression |
385 | | -#[allow(clippy::too_many_arguments)] |
386 | | -#[pyfunction] |
387 | | -#[pyo3(signature = (name, args, partition_by=None, order_by=None, window_frame=None, filter=None, distinct=false, ctx=None))] |
388 | | -fn window( |
389 | | - name: &str, |
390 | | - args: Vec<PyExpr>, |
391 | | - partition_by: Option<Vec<PyExpr>>, |
392 | | - order_by: Option<Vec<PySortExpr>>, |
393 | | - window_frame: Option<PyWindowFrame>, |
394 | | - filter: Option<PyExpr>, |
395 | | - distinct: bool, |
396 | | - ctx: Option<PySessionContext>, |
397 | | -) -> PyResult<PyExpr> { |
398 | | - let fun = find_window_fn(name, ctx)?; |
399 | | - |
400 | | - let window_frame = window_frame |
401 | | - .map(|w| w.into()) |
402 | | - .unwrap_or(WindowFrame::new(order_by.as_ref().map(|v| !v.is_empty()))); |
403 | | - let filter = filter.map(|f| f.expr.into()); |
404 | | - |
405 | | - Ok(PyExpr { |
406 | | - expr: datafusion::logical_expr::Expr::WindowFunction(Box::new(WindowFunction { |
407 | | - fun, |
408 | | - params: WindowFunctionParams { |
409 | | - args: args.into_iter().map(|x| x.expr).collect::<Vec<_>>(), |
410 | | - partition_by: partition_by |
411 | | - .unwrap_or_default() |
412 | | - .into_iter() |
413 | | - .map(|x| x.expr) |
414 | | - .collect::<Vec<_>>(), |
415 | | - order_by: order_by |
416 | | - .unwrap_or_default() |
417 | | - .into_iter() |
418 | | - .map(|x| x.into()) |
419 | | - .collect::<Vec<_>>(), |
420 | | - window_frame, |
421 | | - filter, |
422 | | - distinct, |
423 | | - null_treatment: None, |
424 | | - }, |
425 | | - })), |
426 | | - }) |
427 | | -} |
428 | | - |
429 | 303 | // Generates a [pyo3] wrapper for associated aggregate functions. |
430 | 304 | // All of the builder options are exposed to the python internal |
431 | 305 | // function and we rely on the wrappers to only use those that |
@@ -1186,7 +1060,6 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { |
1186 | 1060 | m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision |
1187 | 1061 | m.add_wrapped(wrap_pyfunction!(var_pop))?; |
1188 | 1062 | m.add_wrapped(wrap_pyfunction!(var_sample))?; |
1189 | | - m.add_wrapped(wrap_pyfunction!(window))?; |
1190 | 1063 | m.add_wrapped(wrap_pyfunction!(regr_avgx))?; |
1191 | 1064 | m.add_wrapped(wrap_pyfunction!(regr_avgy))?; |
1192 | 1065 | m.add_wrapped(wrap_pyfunction!(regr_count))?; |
|
0 commit comments