
Commit 91f96cb

timsaucer and claude committed
tpch examples: rewrite non-idiomatic queries in idiomatic DataFrame form
Rewrite the seven TPC-H example queries that did not demonstrate the idiomatic DataFrame pattern. The remaining queries (Q02/Q11/Q15/Q17/Q22, which use window functions in place of correlated subqueries) already are idiomatic and are left unchanged. - Q04: replace `.aggregate([col("l_orderkey")], [])` with `.select("l_orderkey").distinct()`, which is the natural way to express "reduce to one row per order" on a DataFrame. - Q07: remove the CASE-as-filter on `n_name` and use `F.in_list(col("n_name"), [nation_1, nation_2])` instead. Drops a comment block that admitted the filter form was simpler. - Q08: rewrite the switched CASE `F.case(...).when(lit(False), ...)` as a searched `F.when(col(...).is_not_null(), ...).otherwise(...)`. That mirrors the reference SQL's `case when ... then ... else 0 end` shape. - Q12: replace `array_position(make_array(...), col)` with `F.in_list(col("l_shipmode"), [...])`. Same semantics, without routing through array construction / array search. - Q19: remove the pyarrow UDF that re-implemented a disjunctive predicate in Python. Build the same predicate in DataFusion by OR-combining one `in_list` + range-filter expression per brand. Keeps the per-brand constants in the existing `items_of_interest` dict. - Q20: use `F.starts_with` instead of an explicit substring slice. Replace the inner-join + `select(...).distinct()` tail with a semi join against a precomputed set of excess-quantity suppliers so the supplier columns are preserved without deduplication after the fact. - Q21: replace the `array_agg` / `array_length` / `array_element` pipeline with two semi joins. One semi join keeps orders with more than one distinct supplier (stand-in for the reference SQL's `exists` subquery), the other keeps orders with exactly one late supplier (stand-in for the `not exists` subquery). All 22 answer-file comparisons and 22 plan-comparison diagnostics still pass (`pytest examples/tpch/_tests.py`: 44 passed). 
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e808db8 commit 91f96cb

7 files changed

Lines changed: 131 additions & 172 deletions

examples/tpch/q04_order_priority_checking.py

Lines changed: 7 additions & 7 deletions

@@ -77,13 +77,13 @@
 
 interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval())
 
-# Limit results to cases where commitment date before receipt date
-# Aggregate the results so we only get one row to join with the order table.
-# Alternately, and likely more idiomatic is instead of `.aggregate` you could
-# do `.select("l_orderkey").distinct()`. The goal here is to show
-# multiple examples of how to use Data Fusion.
-df_lineitem = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")).aggregate(
-    [col("l_orderkey")], []
+# Limit results to cases where commitment date before receipt date, then
+# reduce to a single row per order so the join with the orders table is a
+# semantic EXISTS rather than a fan-out.
+df_lineitem = (
+    df_lineitem.filter(col("l_commitdate") < col("l_receiptdate"))
+    .select("l_orderkey")
+    .distinct()
 )
 
 # Limit orders to date range of interest
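The distinct-then-join shape above is an EXISTS in disguise: deduplicating the keys before the join prevents one order from being matched once per late line item. A minimal pure-Python sketch of that effect (toy rows and values, not the DataFusion API):

```python
# Toy line items: order 1 has two late rows, order 2 has none (hypothetical data).
lineitems = [
    {"l_orderkey": 1, "l_commitdate": 5, "l_receiptdate": 7},
    {"l_orderkey": 1, "l_commitdate": 3, "l_receiptdate": 9},
    {"l_orderkey": 2, "l_commitdate": 8, "l_receiptdate": 6},
]
orders = [{"o_orderkey": 1}, {"o_orderkey": 2}]

# filter(...).select("l_orderkey").distinct(): one key per late order.
late_keys = {
    r["l_orderkey"] for r in lineitems if r["l_commitdate"] < r["l_receiptdate"]
}

# Joining orders against the deduplicated keys yields at most one row per
# order (EXISTS semantics); joining the raw filtered rows would emit order 1
# twice.
matched = [o for o in orders if o["o_orderkey"] in late_keys]
```

The same reasoning is why the commit message calls the old `.aggregate([col("l_orderkey")], [])` form a workaround: it also deduplicated, just less obviously.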

examples/tpch/q07_volume_shipping.py

Lines changed: 2 additions & 14 deletions

@@ -116,20 +116,8 @@
 )
 
 
-# A simpler way to do the following operation is to use a filter, but we also want to demonstrate
-# how to use case statements. Here we are assigning `n_name` to be itself when it is either of
-# the two nations of interest. Since there is no `otherwise()` statement, any values that do
-# not match these will result in a null value and then get filtered out.
-#
-# To do the same using a simple filter would be:
-# df_nation = df_nation.filter((F.col("n_name") == nation_1) | (F.col("n_name") == nation_2)) # noqa: ERA001
-df_nation = df_nation.with_column(
-    "n_name",
-    F.case(col("n_name"))
-    .when(nation_1, col("n_name"))
-    .when(nation_2, col("n_name"))
-    .end(),
-).filter(~col("n_name").is_null())
+# Limit the nation table to the two nations of interest.
+df_nation = df_nation.filter(F.in_list(col("n_name"), [nation_1, nation_2]))
 
 
 # Limit suppliers to either nation
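The deleted CASE form relied on null propagation (no `otherwise()` means unmatched rows become null, then get filtered out); `in_list` states the membership test directly. A pure-Python sketch of why the two forms select the same rows (toy values, not the DataFusion API):

```python
nations = ["FRANCE", "GERMANY", "BRAZIL"]
nation_1, nation_2 = "FRANCE", "GERMANY"

# Old form: CASE with no otherwise() -> None for non-matches, then drop None.
cased = [n if n in (nation_1, nation_2) else None for n in nations]
old_result = [n for n in cased if n is not None]

# New form: direct membership test, same surviving rows in one step.
new_result = [n for n in nations if n in (nation_1, nation_2)]
```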

examples/tpch/q08_market_share.py

Lines changed: 5 additions & 4 deletions

@@ -186,12 +186,13 @@
     df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left"
 )
 
-# Use a case statement to compute the volume sold by suppliers in the nation of interest
+# Use a searched CASE (``F.when(...).otherwise(...)``) to keep only the
+# volume attributable to suppliers in the nation of interest. This mirrors
+# the ``case when nation = '...' then volume else 0 end`` form of the
+# reference SQL rather than dispatching on a boolean subject.
 df = df.with_column(
     "national_volume",
-    F.case(col("s_suppkey").is_null())
-    .when(lit(value=False), col("volume"))
-    .otherwise(lit(0.0)),
+    F.when(col("s_suppkey").is_not_null(), col("volume")).otherwise(lit(0.0)),
 )
 
 df = df.with_column(
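Row by row, the searched CASE reduces to a plain conditional on the left join's null marker. A pure-Python sketch of the same null check (toy rows, hypothetical values, not the DataFusion API):

```python
# Rows from a left join: s_suppkey is None when no national supplier matched.
rows = [
    {"s_suppkey": 10, "volume": 100.0},
    {"s_suppkey": None, "volume": 40.0},
]

# Searched CASE: "when s_suppkey is not null then volume else 0.0".
national_volume = [
    (r["volume"] if r["s_suppkey"] is not None else 0.0) for r in rows
]
```

The old switched form dispatched on the boolean subject `s_suppkey is null` matched against `False`, which computes the same thing but reads backwards.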

examples/tpch/q12_ship_mode_order_priority.py

Lines changed: 3 additions & 14 deletions

@@ -91,20 +91,9 @@
     col("l_receiptdate") < lit(date) + lit(interval)
 )
 
-# Note: It is not recommended to use array_has because it treats the second argument as an argument
-# so if you pass it col("l_shipmode") it will pass the entire array to process which is very slow.
-# Instead check the position of the entry is not null.
-df = df.filter(
-    ~F.array_position(
-        F.make_array(lit(SHIP_MODE_1), lit(SHIP_MODE_2)), col("l_shipmode")
-    ).is_null()
-)
-
-# Since we have only two values, it's much easier to do this as a filter where the l_shipmode
-# matches either of the two values, but we want to show doing some array operations in this
-# example. If you want to see this done with filters, comment out the above line and uncomment
-# this one.
-# df = df.filter((col("l_shipmode") == lit(SHIP_MODE_1)) | (col("l_shipmode") == lit(SHIP_MODE_2))) # noqa: ERA001
+# Restrict to the two ship modes of interest. ``in_list`` maps directly to
+# the ``l_shipmode in ('FOB', 'SHIP')`` clause of the reference SQL.
+df = df.filter(F.in_list(col("l_shipmode"), [lit(SHIP_MODE_1), lit(SHIP_MODE_2)]))
 
 
 # We need order priority, so join order df to line item

examples/tpch/q19_discounted_revenue.py

Lines changed: 27 additions & 53 deletions

@@ -64,8 +64,7 @@
 );
 """
 
-import pyarrow as pa
-from datafusion import SessionContext, col, lit, udf
+from datafusion import SessionContext, col, lit
 from datafusion import functions as F
 from util import get_data_path
 
@@ -114,59 +113,34 @@
 df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner")
 
 
-# Create the user defined function (UDF) definition that does the work
-def is_of_interest(
-    brand_arr: pa.Array,
-    container_arr: pa.Array,
-    quantity_arr: pa.Array,
-    size_arr: pa.Array,
-) -> pa.Array:
-    """
-    The purpose of this function is to demonstrate how a UDF works, taking as input a pyarrow Array
-    and generating a resultant Array. The length of the inputs should match and there should be the
-    same number of rows in the output.
-    """
-    result = []
-    for idx, brand_val in enumerate(brand_arr):
-        brand = brand_val.as_py()
-        if brand in items_of_interest:
-            values_of_interest = items_of_interest[brand]
-
-            container_matches = (
-                container_arr[idx].as_py() in values_of_interest["containers"]
-            )
-
-            quantity = quantity_arr[idx].as_py()
-            quantity_matches = (
-                values_of_interest["min_quantity"]
-                <= quantity
-                <= values_of_interest["min_quantity"] + 10
-            )
-
-            size = size_arr[idx].as_py()
-            size_matches = 1 <= size <= values_of_interest["max_size"]
-
-            result.append(container_matches and quantity_matches and size_matches)
-        else:
-            result.append(False)
-
-    return pa.array(result)
-
-
-# Turn the above function into a UDF that DataFusion can understand
-is_of_interest_udf = udf(
-    is_of_interest,
-    [pa.utf8(), pa.utf8(), pa.decimal128(15, 2), pa.int32()],
-    pa.bool_(),
-    "stable",
-)
-
-# Filter results using the above UDF
-df = df.filter(
-    is_of_interest_udf(
-        col("p_brand"), col("p_container"), col("l_quantity"), col("p_size")
-    )
-)
+# Build one OR-combined predicate per brand. Each disjunct encodes the
+# brand-specific container list, quantity window, and size range from the
+# reference SQL. This mirrors the SQL ``where (... brand A ...) or (... brand
+# B ...) or (... brand C ...)`` form directly, without a UDF.
+def _brand_predicate(
+    brand: str, min_quantity: int, containers: list[str], max_size: int
+):
+    return (
+        (col("p_brand") == lit(brand))
+        & F.in_list(col("p_container"), [lit(c) for c in containers])
+        & (col("l_quantity") >= lit(min_quantity))
+        & (col("l_quantity") <= lit(min_quantity + 10))
+        & (col("p_size") >= lit(1))
+        & (col("p_size") <= lit(max_size))
+    )
+
+
+predicate = None
+for brand, params in items_of_interest.items():
+    part_predicate = _brand_predicate(
+        brand,
+        params["min_quantity"],
+        params["containers"],
+        params["max_size"],
+    )
+    predicate = part_predicate if predicate is None else predicate | part_predicate
+
+df = df.filter(predicate)
 
 df = df.aggregate(
     [],
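The `predicate = part_predicate if predicate is None else predicate | part_predicate` accumulation is a fold over OR. A pure-Python sketch of the same shape using `functools.reduce`, with plain row-predicate functions standing in for DataFusion expressions (toy brand constants, hypothetical names):

```python
import operator
from functools import reduce

# Per-brand constants, mirroring the shape of the items_of_interest dict.
items_of_interest = {
    "Brand#12": {"min_quantity": 1, "containers": {"SM CASE", "SM BOX"}, "max_size": 5},
    "Brand#23": {"min_quantity": 10, "containers": {"MED BAG"}, "max_size": 10},
}


def brand_predicate(brand, params):
    # Returns a row -> bool function; each clause mirrors one conjunct of
    # the DataFusion expression built per brand in the diff above.
    def pred(row):
        return (
            row["p_brand"] == brand
            and row["p_container"] in params["containers"]
            and params["min_quantity"] <= row["l_quantity"] <= params["min_quantity"] + 10
            and 1 <= row["p_size"] <= params["max_size"]
        )

    return pred


preds = [brand_predicate(b, p) for b, p in items_of_interest.items()]


def combined(row):
    # OR-combine: a row qualifies if any brand's disjunct accepts it.
    return reduce(operator.or_, (p(row) for p in preds))
```

With DataFusion `Expr` objects, `reduce(operator.or_, exprs)` works the same way because `Expr` overloads `|`; the explicit loop in the diff just avoids assuming a non-empty dict.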

examples/tpch/q20_potential_part_promotion.py

Lines changed: 35 additions & 31 deletions

@@ -100,42 +100,46 @@
 
 interval = pa.scalar((0, 365, 0), type=pa.month_day_nano_interval())
 
-# Filter down dataframes
+# Filter down dataframes. ``starts_with`` reads more naturally than an
+# explicit substring slice and maps directly to the reference SQL's
+# ``p_name like 'forest%'`` clause.
 df_nation = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST))
-df_part = df_part.filter(
-    F.substring(col("p_name"), lit(0), lit(len(COLOR_OF_INTEREST) + 1))
-    == lit(COLOR_OF_INTEREST)
+df_part = df_part.filter(F.starts_with(col("p_name"), lit(COLOR_OF_INTEREST)))
+
+# Compute the total quantity of interesting parts shipped by each (part,
+# supplier) pair within the year of interest.
+totals = (
+    df_lineitem.filter(col("l_shipdate") >= lit(date))
+    .filter(col("l_shipdate") < lit(date) + lit(interval))
+    .join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner")
+    .aggregate(
+        [col("l_partkey"), col("l_suppkey")],
+        [F.sum(col("l_quantity")).alias("total_sold")],
+    )
 )
 
-df = df_lineitem.filter(col("l_shipdate") >= lit(date)).filter(
-    col("l_shipdate") < lit(date) + lit(interval)
+# Keep only (part, supplier) pairs whose available quantity exceeds 50% of
+# the total shipped. The result already contains one row per supplier of
+# interest, so we can semi-join the supplier table rather than inner-join
+# and deduplicate afterwards.
+excess_suppliers = (
+    df_partsupp.join(
+        totals,
+        left_on=["ps_partkey", "ps_suppkey"],
+        right_on=["l_partkey", "l_suppkey"],
+        how="inner",
+    )
+    .filter(col("ps_availqty") > lit(0.5) * col("total_sold"))
+    .select(col("ps_suppkey").alias("suppkey"))
+    .distinct()
 )
 
-# This will filter down the line items to the parts of interest
-df = df.join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner")
+# Limit to suppliers in the nation of interest and pick out the two
+# requested columns.
+df = df_supplier.join(
+    df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner"
+).join(excess_suppliers, left_on="s_suppkey", right_on="suppkey", how="semi")
 
-# Compute the total sold and limit ourselves to individual supplier/part combinations
-df = df.aggregate(
-    [col("l_partkey"), col("l_suppkey")], [F.sum(col("l_quantity")).alias("total_sold")]
-)
-
-df = df.join(
-    df_partsupp,
-    left_on=["l_partkey", "l_suppkey"],
-    right_on=["ps_partkey", "ps_suppkey"],
-    how="inner",
-)
-
-# Find cases of excess quantity
-df = df.filter(col("ps_availqty") > lit(0.5) * col("total_sold"))
-
-# We could do these joins earlier, but now limit to the nation of interest suppliers
-df = df.join(df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner")
-df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner")
-
-# Restrict to the requested data per the problem statement
-df = df.select("s_name", "s_address").distinct()
-
-df = df.sort(col("s_name").sort())
+df = df.select("s_name", "s_address").sort(col("s_name").sort())
 
 df.show()
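The semi join here carries the whole argument: it keeps each left-side supplier at most once when any match exists on the right, and never pulls in right-side columns, which is why the trailing `.distinct()` became unnecessary. A pure-Python sketch of those semantics (toy rows, hypothetical names):

```python
suppliers = [
    {"s_suppkey": 1, "s_name": "A", "s_address": "addr1"},
    {"s_suppkey": 2, "s_name": "B", "s_address": "addr2"},
]
# The right side of the semi join; duplicates here model what an inner join
# would otherwise fan out on.
excess = [{"suppkey": 1}, {"suppkey": 1}]

# Semi join: membership test against the right side's key set. Each left row
# appears at most once, and no right-side columns leak into the result.
right_keys = {r["suppkey"] for r in excess}
semi = [s for s in suppliers if s["s_suppkey"] in right_keys]
```

An inner join against `excess` would have produced supplier A twice, forcing the `select(...).distinct()` cleanup the commit removes.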

examples/tpch/q21_suppliers_kept_orders_waiting.py

Lines changed: 52 additions & 49 deletions

@@ -92,65 +92,68 @@
 )
 
 # Limit to suppliers in the nation of interest
-df_suppliers_of_interest = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST))
-
-df_suppliers_of_interest = df_suppliers_of_interest.join(
-    df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner"
+df_suppliers_of_interest = df_nation.filter(
+    col("n_name") == lit(NATION_OF_INTEREST)
+).join(df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner")
+
+# Line items for orders that have status 'F'. This is the candidate set of
+# (order, supplier) pairs we reason about below.
+failed_order_lineitems = df_lineitem.join(
+    df_orders.filter(col("o_orderstatus") == lit("F")),
+    left_on="l_orderkey",
+    right_on="o_orderkey",
+    how="inner",
 )
 
-# Find the failed orders and all their line items
-df = df_orders.filter(col("o_orderstatus") == lit("F"))
-
-df = df_lineitem.join(df, left_on="l_orderkey", right_on="o_orderkey", how="inner")
-
-# Identify the line items for which the order is failed due to.
-df = df.with_column(
-    "failed_supp",
-    F.case(col("l_receiptdate") > col("l_commitdate"))
-    .when(lit(value=True), col("l_suppkey"))
-    .end(),
+# Line items whose receipt was late. This corresponds to ``l1`` in the
+# reference SQL.
+late_lineitems = failed_order_lineitems.filter(
+    col("l_receiptdate") > col("l_commitdate")
 )
 
-# There are other ways we could do this but the purpose of this example is to work with rows where
-# an element is an array of values. In this case, we will create two columns of arrays. One will be
-# an array of all of the suppliers who made up this order. That way we can filter the dataframe for
-# only orders where this array is larger than one for multiple supplier orders. The second column
-# is all of the suppliers who failed to make their commitment. We can filter the second column for
-# arrays with size one. That combination will give us orders that had multiple suppliers where only
-# one failed. Use distinct=True in the blow aggregation so we don't get multiple line items from the
-# same supplier reported in either array.
-df = df.aggregate(
-    [col("o_orderkey")],
-    [
-        F.array_agg(col("l_suppkey"), distinct=True).alias("all_suppliers"),
-        F.array_agg(
-            col("failed_supp"), filter=col("failed_supp").is_not_null(), distinct=True
-        ).alias("failed_suppliers"),
-    ],
+# Orders that had more than one distinct supplier. Expressed as
+# ``count(distinct l_suppkey) > 1``. Stands in for the reference SQL's
+# ``exists (... l2.l_suppkey <> l1.l_suppkey ...)`` subquery.
+multi_supplier_orders = (
+    failed_order_lineitems.select("l_orderkey", "l_suppkey")
+    .distinct()
+    .aggregate([col("l_orderkey")], [F.count(col("l_suppkey")).alias("n_suppliers")])
+    .filter(col("n_suppliers") > lit(1))
+    .select("l_orderkey")
 )
 
-# This is the check described above which will identify single failed supplier in a multiple
-# supplier order.
-df = df.filter(F.array_length(col("failed_suppliers")) == lit(1)).filter(
-    F.array_length(col("all_suppliers")) > lit(1)
+# Orders where exactly one distinct supplier was late. Stands in for the
+# reference SQL's ``not exists (... l3.l_suppkey <> l1.l_suppkey and l3 is
+# also late ...)`` subquery: if only one supplier on the order was late,
+# nobody else on the same order was late.
+single_late_supplier_orders = (
+    late_lineitems.select("l_orderkey", "l_suppkey")
+    .distinct()
+    .aggregate(
+        [col("l_orderkey")], [F.count(col("l_suppkey")).alias("n_late_suppliers")]
+    )
+    .filter(col("n_late_suppliers") == lit(1))
+    .select("l_orderkey")
 )
 
-# Since we have an array we know is exactly one element long, we can extract that single value.
-df = df.select(
-    col("o_orderkey"), F.array_element(col("failed_suppliers"), lit(1)).alias("suppkey")
+# Keep late line items whose order qualifies on both counts. Semi joins
+# preserve the left-side columns without fanning out on the right.
+df = late_lineitems.join(multi_supplier_orders, on="l_orderkey", how="semi").join(
+    single_late_supplier_orders, on="l_orderkey", how="semi"
 )
 
-# Join to the supplier of interest list for the nation of interest
-df = df.join(
-    df_suppliers_of_interest, left_on=["suppkey"], right_on=["s_suppkey"], how="inner"
+# Attach the supplier name for suppliers in the nation of interest, count
+# one row per qualifying order, and return the top 100.
+df = (
+    df.join(
+        df_suppliers_of_interest,
+        left_on="l_suppkey",
+        right_on="s_suppkey",
+        how="inner",
+    )
+    .aggregate([col("s_name")], [F.count(col("l_orderkey")).alias("numwait")])
+    .sort(col("numwait").sort(ascending=False), col("s_name").sort())
+    .limit(100)
 )
 
-# Count how many orders that supplier is the only failed supplier for
-df = df.aggregate([col("s_name")], [F.count(col("o_orderkey")).alias("numwait")])
-
-# Return in descending order
-df = df.sort(col("numwait").sort(ascending=False), col("s_name").sort())
-
-df = df.limit(100)
-
 df.show()
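The two semi joins hinge on the claim that distinct-supplier counts can stand in for the SQL `exists` / `not exists` subqueries. A pure-Python sketch of that equivalence on toy (order, supplier) pairs (hypothetical data, not the DataFusion API):

```python
from collections import Counter

# (l_orderkey, l_suppkey) pairs for 'F'-status orders; order 1 has two
# suppliers, order 2 has one. Duplicates model repeated line items.
pairs = [(1, 10), (1, 11), (1, 10), (2, 20)]
late_pairs = [(1, 10), (2, 20)]  # pairs whose receipt was late

# count(distinct l_suppkey) per order: dedupe pairs first, then count,
# mirroring the select(...).distinct().aggregate(...) chains in the diff.
suppliers_per_order = Counter(o for o, _ in set(pairs))
late_per_order = Counter(o for o, _ in set(late_pairs))

multi_supplier = {o for o, n in suppliers_per_order.items() if n > 1}  # EXISTS
single_late = {o for o, n in late_per_order.items() if n == 1}  # NOT EXISTS

# Late pairs whose order qualifies on both counts: the two semi joins.
qualifying = [p for p in late_pairs if p[0] in multi_supplier and p[0] in single_late]
```

Order 2 is excluded because it has only one supplier overall; order 1 survives because two suppliers participated but only supplier 10 was late, so supplier 10 alone kept that order waiting.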
