
Commit 0357716

timsaucer and claude authored
tpch examples: rewrite queries idiomatically and embed reference SQL (#1504)
* tpch examples: add reference SQL to each query, fix Q20

  - Append the canonical TPC-H reference SQL (from benchmarks/tpch/queries/) to each q01..q22 module docstring so readers can compare the DataFrame translation against the SQL at a glance.
  - Fix Q20: `df = df.filter(col("ps_availqty") > lit(0.5) * col("total_sold"))` was missing the assignment, so the filter was dropped from the pipeline.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* tpch examples: rewrite non-idiomatic queries in idiomatic DataFrame form

  Rewrite the seven TPC-H example queries that did not demonstrate the idiomatic DataFrame pattern. The remaining queries (Q02/Q11/Q15/Q17/Q22, which use window functions in place of correlated subqueries) are already idiomatic and are left unchanged.

  - Q04: replace `.aggregate([col("l_orderkey")], [])` with `.select("l_orderkey").distinct()`, which is the natural way to express "reduce to one row per order" on a DataFrame.
  - Q07: remove the CASE-as-filter on `n_name` and use `F.in_list(col("n_name"), [nation_1, nation_2])` instead. Drops a comment block that admitted the filter form was simpler.
  - Q08: rewrite the switched CASE `F.case(...).when(lit(False), ...)` as a searched `F.when(col(...).is_not_null(), ...).otherwise(...)`. That mirrors the reference SQL's `case when ... then ... else 0 end` shape.
  - Q12: replace `array_position(make_array(...), col)` with `F.in_list(col("l_shipmode"), [...])`. Same semantics, without routing through array construction / array search.
  - Q19: remove the pyarrow UDF that re-implemented a disjunctive predicate in Python. Build the same predicate in DataFusion by OR-combining one `in_list` + range-filter expression per brand. Keeps the per-brand constants in the existing `items_of_interest` dict.
  - Q20: use `F.starts_with` instead of an explicit substring slice. Replace the inner-join + `select(...).distinct()` tail with a semi join against a precomputed set of excess-quantity suppliers, so the supplier columns are preserved without deduplication after the fact.
  - Q21: replace the `array_agg` / `array_length` / `array_element` pipeline with two semi joins. One semi join keeps orders with more than one distinct supplier (stand-in for the reference SQL's `exists` subquery); the other keeps orders with exactly one late supplier (stand-in for the `not exists` subquery).

  All 22 answer-file comparisons and 22 plan-comparison diagnostics still pass (`pytest examples/tpch/_tests.py`: 44 passed).

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* tpch examples: align reference SQL constants with DataFrame queries

  The reference SQL embedded in each q01..q22 module docstring was carried over verbatim from ``benchmarks/tpch/queries/`` and uses a different set of TPC-H substitution parameters than the DataFrame examples (which are answer-file-validated at scale factor 1). Update each reference SQL to use the substitution parameters the DataFrame uses, so both expressions describe the same query and would produce the same results against the same data.

  Constants aligned:

  - Q01: ``90 days`` cutoff (DataFrame ``DAYS_BEFORE_FINAL = 90``).
  - Q02: ``p_size = 15``, ``p_type like '%BRASS'``, ``r_name = 'EUROPE'``.
  - Q04: base date ``1993-07-01`` (``3 month`` interval preserved per the "quarter of a year" wording).
  - Q05: ``r_name = 'ASIA'``.
  - Q06: ``l_discount between 0.06 - 0.01 and 0.06 + 0.01``.
  - Q07: nations ``'FRANCE'`` / ``'GERMANY'``.
  - Q08: ``r_name = 'AMERICA'``, ``p_type = 'ECONOMY ANODIZED STEEL'``, inner-case ``nation = 'BRAZIL'``.
  - Q09: ``p_name like '%green%'``.
  - Q10: base date ``1993-10-01`` (``3 month`` interval preserved).
  - Q11: ``n_name = 'GERMANY'``.
  - Q12: ship modes ``('MAIL', 'SHIP')``, base date ``1994-01-01``.
  - Q13: ``o_comment not like '%special%requests%'``.
  - Q14: base date ``1995-09-01``.
  - Q15: base date ``1996-01-01``.
  - Q16: ``p_brand <> 'Brand#45'``, ``p_type not like 'MEDIUM POLISHED%'``, sizes ``(49, 14, 23, 45, 19, 3, 36, 9)``.
  - Q17: ``p_brand = 'Brand#23'``, ``p_container = 'MED BOX'``.
  - Q18: ``sum(l_quantity) > 300``.
  - Q19: brands ``Brand#12`` / ``Brand#23`` / ``Brand#34`` with the matching minimum quantities (1, 10, 20).
  - Q20: ``p_name like 'forest%'``, base date ``1994-01-01``, ``n_name = 'CANADA'``.
  - Q21: ``n_name = 'SAUDI ARABIA'``.
  - Q22: country codes ``('13', '31', '23', '29', '30', '18', '17')``.

  Interval units (month / year) are preserved where the problem-statement text reads "given quarter", "given year", "given month". Q01 keeps the literal "days" unit because the TPC-H problem statement itself describes the cutoff in days.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* tpch examples: apply SKILL.md idioms across all 22 queries

  Sweep every q01..q22 example for idiomatic DataFrame style as described in the repo-root SKILL.md:

  - ``col("x") == "s"`` in place of ``col("x") == lit("s")`` on comparison right-hand sides (auto-wrap applies).
  - Plain-name strings in ``select``/``aggregate``/``sort`` group/sort key lists when the key is a bare column.
  - Drop redundant ``how="inner"`` and single-element ``left_on``/``right_on`` list wrapping on equi-joins.
  - Collapse chained ``.filter(a).filter(b)`` runs into ``.filter(a, b)`` and chained ``.with_column`` runs into ``.with_columns(a=..., b=...)``.
  - ``df.sort_by(...)`` or plain-name ``df.sort(...)`` when no null-placement override is needed.
  - ``F.count_star()`` in place of ``F.count(col("x"))`` whenever the SQL reads ``count(*)``.
  - ``F.starts_with(col, lit(prefix))`` and ``~F.starts_with(...)`` in place of substring-prefix equality/inequality tricks.
  - ``F.in_list(col, [lit(...)])`` in place of ``~F.array_position(...).is_null()`` and in place of disjunctions of equality comparisons.
  - Searched ``F.when(cond, x).otherwise(y)`` in place of switched ``F.case(bool_expr).when(lit(True/False), x).end()`` forms.
  - Semi joins as the DataFrame form of ``EXISTS`` (Q04); anti joins as ``NOT EXISTS`` (Q22 was already using this idiom).
  - Whole-frame window aggregates as the DataFrame stand-in for a SQL scalar subquery (Q11/Q15/Q17/Q22).

  Individual query fixes of note:

  - Q16: add the secondary sort keys (``p_brand``, ``p_type``, ``p_size``) that the TPC-H spec requires but the original DataFrame omitted.
  - Q22: drop a stray ``df.show()`` mid-pipeline; replace the 0-based substring slice with ``F.left(col("c_phone"), lit(2))``.
  - Q14: rewrite the promo/non-promo factor split as a searched CASE inside ``F.sum(...)`` so the DataFrame expression matches the reference SQL shape exactly.

  All 22 answer-file comparisons still pass at scale factor 1.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* tpch examples: more idiomatic aggregate FILTER, string funcs, date handling

  Additional sweep of the TPC-H DataFrame examples, informed by comparing against a fresh set of SKILL.md-only generations under ``examples/tpch/agentic_queries/``:

  - Q02: ``F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST))`` in place of ``F.strpos(col, lit) > 0``. The reference SQL is ``p_type like '%BRASS'``, which is an ends_with check, not a contains check. ``F.strpos > 0`` returned the correct rows on TPC-H data by coincidence but is semantically wrong.
  - Q09: ``F.contains(col("p_name"), lit(part_color))`` in place of ``F.strpos(col, lit) > 0``. The SQL is ``p_name like '%green%'``.
  - Q08, Q12, Q14: use the ``filter`` keyword on ``F.sum`` / ``F.count`` (the DataFrame form of SQL ``sum(...) FILTER (WHERE ...)``) instead of wrapping the aggregate input in ``F.when(cond, x).otherwise(0)``. Q08 also reorganises to inner-join the supplier's nation onto the regional sales, which removes the previous left-join + ``F.when(is_not_null, ...)`` dance.
  - Q15: compute the grand maximum revenue as a separate scalar aggregate and ``join_on(...)`` on equality, instead of the whole-frame window ``F.max`` + filter shape. Simpler plan, same result.
  - Q16: ``F.regexp_like(col, pattern)`` in place of ``F.regexp_match(col, pattern).is_not_null()``.
  - Q04, Q05, Q06, Q07, Q08, Q10, Q12, Q14, Q15, Q20: store both the start and the end of the date window as plain ``datetime.date`` objects and compare with ``lit(end_date)``, instead of carrying the start date + ``pa.month_day_nano_interval`` and adding them at query-build time. Drops unused ``pyarrow`` imports from the files that no longer need Arrow scalars.

  All 22 answer-file comparisons still pass at scale factor 1.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
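The Q20 fix in the first commit is worth dwelling on: DataFusion DataFrames are immutable, so `filter` returns a new frame and the original is untouched unless the result is rebound to a name. A minimal pure-Python stand-in (the `Frame` class below is a made-up illustration, not DataFusion's API) reproduces how the unassigned filter silently vanishes:

```python
class Frame:
    """Tiny stand-in for an immutable DataFrame: filter() returns a NEW Frame."""

    def __init__(self, rows):
        self.rows = list(rows)

    def filter(self, pred):
        return Frame(r for r in self.rows if pred(r))


rows = [
    {"ps_availqty": 10, "total_sold": 30},  # 10 is NOT > 0.5 * 30
    {"ps_availqty": 20, "total_sold": 30},  # 20 IS  > 0.5 * 30
]
df = Frame(rows)

# Buggy form from before the fix: the returned frame is discarded,
# so the predicate never enters the pipeline.
df.filter(lambda r: r["ps_availqty"] > 0.5 * r["total_sold"])
assert len(df.rows) == 2  # filter silently dropped

# Fixed form: rebind the name to the filtered frame.
df = df.filter(lambda r: r["ps_availqty"] > 0.5 * r["total_sold"])
assert len(df.rows) == 1  # only the 20 > 15 row survives
```

The commit's fix is exactly the `df = ` rebinding: the whole difference between the two halves above.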
1 parent c8bb9f7 commit 0357716

22 files changed

Lines changed: 1196 additions & 756 deletions
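One idiom the commit message leans on, the aggregate `filter` keyword (SQL's `sum(...) FILTER (WHERE ...)`), has a subtlety worth noting before the per-file diffs. A pure-Python sketch with made-up rows (plain lists standing in for columns) contrasts it with the `CASE ... ELSE 0` wrapping it replaces:

```python
rows = [
    {"price": 100.0, "promo": True},
    {"price": 50.0, "promo": False},
    {"price": 25.0, "promo": True},
]

# sum(price) FILTER (WHERE promo): only matching rows feed the aggregate.
filtered_sum = sum(r["price"] for r in rows if r["promo"])

# The CASE-wrapping form: every row feeds the aggregate, non-matching
# rows contributing a literal 0.
case_sum = sum(r["price"] if r["promo"] else 0.0 for r in rows)

assert filtered_sum == case_sum == 125.0  # identical for sum

# For count (and avg) the two are NOT interchangeable: FILTER drops the
# rows entirely, while CASE ... ELSE 0 still feeds a value per row.
filtered_count = sum(1 for r in rows if r["promo"])
case_count = len([r["price"] if r["promo"] else 0.0 for r in rows])
assert filtered_count == 2
assert case_count == 3
```

For `sum` the forms agree because the excluded rows contribute the neutral element 0; the `filter` keyword expresses the intent directly and stays correct for aggregates with no neutral element.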

examples/tpch/q01_pricing_summary_report.py

Lines changed: 31 additions & 13 deletions
@@ -27,6 +27,30 @@
 
 The above problem statement text is copyrighted by the Transaction Processing Performance Council
 as part of their TPC Benchmark H Specification revision 2.18.0.
+
+Reference SQL (from TPC-H specification, used by the benchmark suite)::
+
+    select
+        l_returnflag,
+        l_linestatus,
+        sum(l_quantity) as sum_qty,
+        sum(l_extendedprice) as sum_base_price,
+        sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
+        sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
+        avg(l_quantity) as avg_qty,
+        avg(l_extendedprice) as avg_price,
+        avg(l_discount) as avg_disc,
+        count(*) as count_order
+    from
+        lineitem
+    where
+        l_shipdate <= date '1998-12-01' - interval '90 days'
+    group by
+        l_returnflag,
+        l_linestatus
+    order by
+        l_returnflag,
+        l_linestatus;
 """
 
 import pyarrow as pa
@@ -58,31 +82,25 @@
 
 # Aggregate the results
 
+disc_price = col("l_extendedprice") * (lit(1) - col("l_discount"))
+
 df = df.aggregate(
-    [col("l_returnflag"), col("l_linestatus")],
+    ["l_returnflag", "l_linestatus"],
     [
         F.sum(col("l_quantity")).alias("sum_qty"),
         F.sum(col("l_extendedprice")).alias("sum_base_price"),
-        F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias(
-            "sum_disc_price"
-        ),
-        F.sum(
-            col("l_extendedprice")
-            * (lit(1) - col("l_discount"))
-            * (lit(1) + col("l_tax"))
-        ).alias("sum_charge"),
+        F.sum(disc_price).alias("sum_disc_price"),
+        F.sum(disc_price * (lit(1) + col("l_tax"))).alias("sum_charge"),
         F.avg(col("l_quantity")).alias("avg_qty"),
         F.avg(col("l_extendedprice")).alias("avg_price"),
         F.avg(col("l_discount")).alias("avg_disc"),
-        F.count(col("l_returnflag")).alias(
-            "count_order"
-        ),  # Counting any column should return same result
+        F.count_star().alias("count_order"),
     ],
 )
 
 # Sort per the expected result
 
-df = df.sort(col("l_returnflag").sort(), col("l_linestatus").sort())
+df = df.sort_by("l_returnflag", "l_linestatus")
 
 # Note: There appears to be a discrepancy between what is returned here and what is in the generated
 # answers file for the case of return flag N and line status O, but I did not investigate further.
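To make the Q01 change concrete: the aggregation groups on two keys, factors out the shared discounted-price expression, and counts rows per group, which is why `count(*)` maps to `F.count_star()` rather than counting some arbitrary column. A pure-Python sketch of the same computation over hypothetical rows (not TPC-H data, column names only follow the schema):

```python
from collections import defaultdict

lineitem = [
    {"l_returnflag": "N", "l_linestatus": "O", "l_quantity": 17,
     "l_extendedprice": 100.0, "l_discount": 0.05, "l_tax": 0.02},
    {"l_returnflag": "N", "l_linestatus": "O", "l_quantity": 36,
     "l_extendedprice": 200.0, "l_discount": 0.10, "l_tax": 0.06},
    {"l_returnflag": "R", "l_linestatus": "F", "l_quantity": 8,
     "l_extendedprice": 80.0, "l_discount": 0.00, "l_tax": 0.04},
]

# group by (l_returnflag, l_linestatus)
groups = defaultdict(list)
for r in lineitem:
    groups[(r["l_returnflag"], r["l_linestatus"])].append(r)

summary = []
for key in sorted(groups):  # order by l_returnflag, l_linestatus
    rs = groups[key]
    # shared subexpression, like the factored-out disc_price in the diff
    disc_price = [r["l_extendedprice"] * (1 - r["l_discount"]) for r in rs]
    summary.append({
        "key": key,
        "sum_qty": sum(r["l_quantity"] for r in rs),
        "sum_disc_price": sum(disc_price),
        "sum_charge": sum(dp * (1 + r["l_tax"]) for dp, r in zip(disc_price, rs)),
        "count_order": len(rs),  # count(*) counts rows per group
    })

assert [s["key"] for s in summary] == [("N", "O"), ("R", "F")]
assert summary[0]["sum_qty"] == 53
```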

examples/tpch/q02_minimum_cost_supplier.py

Lines changed: 63 additions & 24 deletions
@@ -27,6 +27,52 @@
 
 The above problem statement text is copyrighted by the Transaction Processing Performance Council
 as part of their TPC Benchmark H Specification revision 2.18.0.
+
+Reference SQL (from TPC-H specification, used by the benchmark suite)::
+
+    select
+        s_acctbal,
+        s_name,
+        n_name,
+        p_partkey,
+        p_mfgr,
+        s_address,
+        s_phone,
+        s_comment
+    from
+        part,
+        supplier,
+        partsupp,
+        nation,
+        region
+    where
+        p_partkey = ps_partkey
+        and s_suppkey = ps_suppkey
+        and p_size = 15
+        and p_type like '%BRASS'
+        and s_nationkey = n_nationkey
+        and n_regionkey = r_regionkey
+        and r_name = 'EUROPE'
+        and ps_supplycost = (
+            select
+                min(ps_supplycost)
+            from
+                partsupp,
+                supplier,
+                nation,
+                region
+            where
+                p_partkey = ps_partkey
+                and s_suppkey = ps_suppkey
+                and s_nationkey = n_nationkey
+                and n_regionkey = r_regionkey
+                and r_name = 'EUROPE'
+        )
+    order by
+        s_acctbal desc,
+        n_name,
+        s_name,
+        p_partkey limit 100;
 """
 
 import datafusion
@@ -67,35 +113,30 @@
     "r_regionkey", "r_name"
 )
 
-# Filter down parts. Part names contain the type of interest, so we can use strpos to find where
-# in the p_type column the word is. `strpos` will return 0 if not found, otherwise the position
-# in the string where it is located.
+# Filter down parts. The reference SQL uses ``p_type like '%BRASS'`` which
+# is an ``ends_with`` check; use the dedicated string function rather than
+# a manual substring match.
 
 df_part = df_part.filter(
-    F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > lit(0)
-).filter(col("p_size") == lit(SIZE_OF_INTEREST))
+    F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST)),
+    col("p_size") == SIZE_OF_INTEREST,
+)
 
 # Filter regions down to the one of interest
 
-df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST))
+df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST)
 
 # Now that we have the region, find suppliers in that region. Suppliers are tied to their nation
 # and nations are tied to the region.
 
-df_nation = df_nation.join(
-    df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner"
-)
-df_supplier = df_supplier.join(
-    df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner"
-)
+df_nation = df_nation.join(df_region, left_on="n_regionkey", right_on="r_regionkey")
+df_supplier = df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey")
 
 # Now that we know who the potential suppliers are for the part, we can limit out part
 # supplies table down. We can further join down to the specific parts we've identified
 # as matching the request
 
-df = df_partsupp.join(
-    df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner"
-)
+df = df_partsupp.join(df_supplier, left_on="ps_suppkey", right_on="s_suppkey")
 
 # Locate the minimum cost across all suppliers. There are multiple ways you could do this,
 # but one way is to create a window function across all suppliers, find the minimum, and
@@ -112,9 +153,9 @@
     ),
 )
 
-df = df.filter(col("min_cost") == col("ps_supplycost"))
-
-df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner")
+df = df.filter(col("min_cost") == col("ps_supplycost")).join(
+    df_part, left_on="ps_partkey", right_on="p_partkey"
+)
 
 # From the problem statement, these are the values we wish to output
 
@@ -132,12 +173,10 @@
 # Sort and display 100 entries
 df = df.sort(
     col("s_acctbal").sort(ascending=False),
-    col("n_name").sort(),
-    col("s_name").sort(),
-    col("p_partkey").sort(),
-)
-
-df = df.limit(100)
+    "n_name",
+    "s_name",
+    "p_partkey",
+).limit(100)
 
 # Show results
 
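The Q02 comment describes the min-cost pattern as a window aggregate followed by an equality filter. In pure-Python terms the two steps look like this; the rows are hypothetical, and partitioning the minimum by `ps_partkey` is an assumption about the window definition that the visible diff elides:

```python
partsupp = [
    {"ps_partkey": 1, "ps_suppkey": 10, "ps_supplycost": 5.0},
    {"ps_partkey": 1, "ps_suppkey": 11, "ps_supplycost": 3.0},
    {"ps_partkey": 2, "ps_suppkey": 10, "ps_supplycost": 7.0},
]

# Step 1: window-style minimum over each partition (the part), attached
# to every row of that partition rather than collapsing the rows.
min_cost = {}
for r in partsupp:
    k = r["ps_partkey"]
    min_cost[k] = min(min_cost.get(k, float("inf")), r["ps_supplycost"])

# Step 2: keep the rows whose cost equals their partition's minimum,
# mirroring filter(col("min_cost") == col("ps_supplycost")).
cheapest = [r for r in partsupp if r["ps_supplycost"] == min_cost[r["ps_partkey"]]]

assert {(r["ps_partkey"], r["ps_suppkey"]) for r in cheapest} == {(1, 11), (2, 10)}
```

The window form keeps all supplier columns on the surviving rows, which is what lets the query project `s_name`, `s_address`, and friends afterwards without a second join back.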

examples/tpch/q03_shipping_priority.py

Lines changed: 36 additions & 15 deletions
@@ -25,6 +25,31 @@
 
 The above problem statement text is copyrighted by the Transaction Processing Performance Council
 as part of their TPC Benchmark H Specification revision 2.18.0.
+
+Reference SQL (from TPC-H specification, used by the benchmark suite)::
+
+    select
+        l_orderkey,
+        sum(l_extendedprice * (1 - l_discount)) as revenue,
+        o_orderdate,
+        o_shippriority
+    from
+        customer,
+        orders,
+        lineitem
+    where
+        c_mktsegment = 'BUILDING'
+        and c_custkey = o_custkey
+        and l_orderkey = o_orderkey
+        and o_orderdate < date '1995-03-15'
+        and l_shipdate > date '1995-03-15'
+    group by
+        l_orderkey,
+        o_orderdate,
+        o_shippriority
+    order by
+        revenue desc,
+        o_orderdate limit 10;
 """
 
 from datafusion import SessionContext, col, lit
@@ -50,38 +75,34 @@
 
 # Limit dataframes to the rows of interest
 
-df_customer = df_customer.filter(col("c_mktsegment") == lit(SEGMENT_OF_INTEREST))
+df_customer = df_customer.filter(col("c_mktsegment") == SEGMENT_OF_INTEREST)
 df_orders = df_orders.filter(col("o_orderdate") < lit(DATE_OF_INTEREST))
 df_lineitem = df_lineitem.filter(col("l_shipdate") > lit(DATE_OF_INTEREST))
 
 # Join all 3 dataframes
 
-df = df_customer.join(
-    df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner"
-).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner")
+df = df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey").join(
+    df_lineitem, left_on="o_orderkey", right_on="l_orderkey"
+)
 
 # Compute the revenue
 
 df = df.aggregate(
-    [col("l_orderkey")],
+    ["l_orderkey"],
     [
         F.first_value(col("o_orderdate")).alias("o_orderdate"),
        F.first_value(col("o_shippriority")).alias("o_shippriority"),
        F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias("revenue"),
     ],
 )
 
-# Sort by priority
-
-df = df.sort(col("revenue").sort(ascending=False), col("o_orderdate").sort())
-
-# Only return 10 results
+# Sort by priority, take 10, and project in the order expected by the spec.
 
-df = df.limit(10)
-
-# Change the order that the columns are reported in just to match the spec
-
-df = df.select("l_orderkey", "revenue", "o_orderdate", "o_shippriority")
+df = (
+    df.sort(col("revenue").sort(ascending=False), "o_orderdate")
+    .limit(10)
+    .select("l_orderkey", "revenue", "o_orderdate", "o_shippriority")
+)
 
 # Show result
 

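Q03's tail orders by revenue descending with the order date as an ascending tie-breaker, then keeps a fixed number of rows. The same top-N shape can be sketched in pure Python with `heapq`; the aggregates below are made up and the limit is 3 rather than 10 for brevity:

```python
import heapq

# Hypothetical (l_orderkey, revenue, o_orderdate) rows after the group-by.
aggregated = [
    (1, 120.0, "1995-03-10"),
    (2, 300.0, "1995-03-12"),
    (3, 300.0, "1995-03-01"),
    (4, 50.0, "1995-03-14"),
]

# order by revenue desc, o_orderdate asc, limit 3: negate revenue in the
# sort key so nsmallest gives revenue-descending with date as tie-breaker.
top3 = heapq.nsmallest(3, aggregated, key=lambda t: (-t[1], t[2]))

assert [t[0] for t in top3] == [3, 2, 1]
```

Orders 2 and 3 tie on revenue, and the earlier order date wins, exactly as the `revenue desc, o_orderdate` sort in the reference SQL dictates.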
examples/tpch/q04_order_priority_checking.py

Lines changed: 38 additions & 29 deletions
@@ -24,18 +24,40 @@
 
 The above problem statement text is copyrighted by the Transaction Processing Performance Council
 as part of their TPC Benchmark H Specification revision 2.18.0.
+
+Reference SQL (from TPC-H specification, used by the benchmark suite)::
+
+    select
+        o_orderpriority,
+        count(*) as order_count
+    from
+        orders
+    where
+        o_orderdate >= date '1993-07-01'
+        and o_orderdate < date '1993-07-01' + interval '3' month
+        and exists (
+            select
+                *
+            from
+                lineitem
+            where
+                l_orderkey = o_orderkey
+                and l_commitdate < l_receiptdate
+        )
+    group by
+        o_orderpriority
+    order by
+        o_orderpriority;
 """
 
-from datetime import datetime
+from datetime import date
 
-import pyarrow as pa
 from datafusion import SessionContext, col, lit
 from datafusion import functions as F
 from util import get_data_path
 
-# Ideally we could put 3 months into the interval. See note below.
-INTERVAL_DAYS = 92
-DATE_OF_INTEREST = "1993-07-01"
+QUARTER_START = date(1993, 7, 1)
+QUARTER_END = date(1993, 10, 1)
 
 # Load the dataframes we need
 
@@ -48,36 +70,23 @@
     "l_orderkey", "l_commitdate", "l_receiptdate"
 )
 
-# Create a date object from the string
-date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date()
-
-interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval())
-
-# Limit results to cases where commitment date before receipt date
-# Aggregate the results so we only get one row to join with the order table.
-# Alternately, and likely more idiomatic is instead of `.aggregate` you could
-# do `.select("l_orderkey").distinct()`. The goal here is to show
-# multiple examples of how to use Data Fusion.
-df_lineitem = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate")).aggregate(
-    [col("l_orderkey")], []
+# Keep only orders in the quarter of interest, then restrict to those that
+# have at least one late lineitem via a semi join (the DataFrame form of
+# ``EXISTS`` from the reference SQL).
+df_orders = df_orders.filter(
+    col("o_orderdate") >= lit(QUARTER_START),
+    col("o_orderdate") < lit(QUARTER_END),
 )
 
-# Limit orders to date range of interest
-df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter(
-    col("o_orderdate") < lit(date) + lit(interval)
-)
+late_lineitems = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate"))
 
-# Perform the join to find only orders for which there are lineitems outside of expected range
 df = df_orders.join(
-    df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner"
+    late_lineitems, left_on="o_orderkey", right_on="l_orderkey", how="semi"
 )
 
-# Based on priority, find the number of entries
-df = df.aggregate(
-    [col("o_orderpriority")], [F.count(col("o_orderpriority")).alias("order_count")]
+# Count the number of orders in each priority group and sort.
+df = df.aggregate(["o_orderpriority"], [F.count_star().alias("order_count")]).sort_by(
+    "o_orderpriority"
 )
 
-# Sort the results
-df = df.sort(col("o_orderpriority").sort())
-
 df.show()
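The semi join that replaces Q04's inner join is the key move: it keeps each left row at most once when any matching right row exists, which is exactly SQL `EXISTS`, and it is why the per-priority `count(*)` is not inflated by orders with several late lineitems. A pure-Python sketch with made-up rows:

```python
orders = [
    {"o_orderkey": 1, "o_orderpriority": "1-URGENT"},
    {"o_orderkey": 2, "o_orderpriority": "1-URGENT"},
    {"o_orderkey": 3, "o_orderpriority": "5-LOW"},
]
lineitem = [
    {"l_orderkey": 1, "l_commitdate": "1993-08-01", "l_receiptdate": "1993-08-05"},  # late
    {"l_orderkey": 1, "l_commitdate": "1993-08-02", "l_receiptdate": "1993-08-09"},  # late
    {"l_orderkey": 3, "l_commitdate": "1993-08-03", "l_receiptdate": "1993-08-01"},  # on time
]

# Semi join: membership test against the set of keys with a late lineitem.
late_keys = {r["l_orderkey"] for r in lineitem
             if r["l_commitdate"] < r["l_receiptdate"]}
kept = [o for o in orders if o["o_orderkey"] in late_keys]

# Order 1 appears once despite TWO late lineitems; an inner join would
# have produced it twice and double-counted it in the aggregate.
assert [o["o_orderkey"] for o in kept] == [1]
```

An anti join (the `NOT EXISTS` idiom the commit message mentions for Q22) is the same sketch with the membership test negated.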