Skip to content

Commit 0d16d17

Browse files
timsaucer and claude
committed
tpch examples: more idiomatic aggregate FILTER, string funcs, date handling
Additional sweep of the TPC-H DataFrame examples informed by comparing against a fresh set of SKILL.md-only generations under ``examples/tpch/agentic_queries/``: - Q02: ``F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST))`` in place of ``F.strpos(col, lit) > 0``. The reference SQL is ``p_type like '%BRASS'``, which is an ends_with check, not contains. ``F.strpos > 0`` returned the correct rows on TPC-H data by coincidence but is semantically wrong. - Q09: ``F.contains(col("p_name"), lit(part_color))`` in place of ``F.strpos(col, lit) > 0``. The SQL is ``p_name like '%green%'``. - Q08, Q12, Q14: use the ``filter`` keyword on ``F.sum`` / ``F.count`` — the DataFrame form of SQL ``sum(...) FILTER (WHERE ...)`` — instead of wrapping the aggregate input in ``F.when(cond, x).otherwise(0)``. Q08 also reorganises to inner-join the supplier's nation onto the regional sales, which removes the previous left-join + ``F.when(is_not_null, ...)`` dance. - Q15: compute the grand maximum revenue as a separate scalar aggregate and ``join_on(...)`` on equality, instead of the whole-frame window ``F.max`` + filter shape. Simpler plan, same result. - Q16: ``F.regexp_like(col, pattern)`` in place of ``F.regexp_match(col, pattern).is_not_null()``. - Q04, Q05, Q06, Q07, Q08, Q10, Q12, Q14, Q15, Q20: store both the start and the end of the date window as plain ``datetime.date`` objects and compare with ``lit(end_date)``, instead of carrying the start date + ``pa.month_day_nano_interval`` and adding them at query-build time. Drops unused ``pyarrow`` imports from the files that no longer need Arrow scalars. All 22 answer-file comparisons still pass at scale factor 1. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a0c0fb9 commit 0d16d17

13 files changed

Lines changed: 128 additions & 190 deletions

examples/tpch/q02_minimum_cost_supplier.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,12 @@
113113
"r_regionkey", "r_name"
114114
)
115115

116-
# Filter down parts. Part names contain the type of interest, so we can use strpos to find where
117-
# in the p_type column the word is. `strpos` will return 0 if not found, otherwise the position
118-
# in the string where it is located.
116+
# Filter down parts. The reference SQL uses ``p_type like '%BRASS'`` which
117+
# is an ``ends_with`` check; use the dedicated string function rather than
118+
# a manual substring match.
119119

120120
df_part = df_part.filter(
121-
F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > 0,
121+
F.ends_with(col("p_type"), lit(TYPE_OF_INTEREST)),
122122
col("p_size") == SIZE_OF_INTEREST,
123123
)
124124

examples/tpch/q04_order_priority_checking.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,14 @@
5050
o_orderpriority;
5151
"""
5252

53-
from datetime import datetime
53+
from datetime import date
5454

55-
import pyarrow as pa
5655
from datafusion import SessionContext, col, lit
5756
from datafusion import functions as F
5857
from util import get_data_path
5958

60-
# Ideally we could put 3 months into the interval. See note below.
61-
INTERVAL_DAYS = 92
62-
DATE_OF_INTEREST = "1993-07-01"
59+
QUARTER_START = date(1993, 7, 1)
60+
QUARTER_END = date(1993, 10, 1)
6361

6462
# Load the dataframes we need
6563

@@ -72,17 +70,12 @@
7270
"l_orderkey", "l_commitdate", "l_receiptdate"
7371
)
7472

75-
# Create a date object from the string
76-
date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date()
77-
78-
interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval())
79-
8073
# Keep only orders in the quarter of interest, then restrict to those that
8174
# have at least one late lineitem via a semi join (the DataFrame form of
8275
# ``EXISTS`` from the reference SQL).
8376
df_orders = df_orders.filter(
84-
col("o_orderdate") >= lit(date),
85-
col("o_orderdate") < lit(date) + lit(interval),
77+
col("o_orderdate") >= lit(QUARTER_START),
78+
col("o_orderdate") < lit(QUARTER_END),
8679
)
8780

8881
late_lineitems = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate"))

examples/tpch/q05_local_supplier_volume.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,16 @@
5656
revenue desc;
5757
"""
5858

59-
from datetime import datetime
59+
from datetime import date
6060

61-
import pyarrow as pa
6261
from datafusion import SessionContext, col, lit
6362
from datafusion import functions as F
6463
from util import get_data_path
6564

66-
DATE_OF_INTEREST = "1994-01-01"
67-
INTERVAL_DAYS = 365
65+
YEAR_START = date(1994, 1, 1)
66+
YEAR_END = date(1995, 1, 1)
6867
REGION_OF_INTEREST = "ASIA"
6968

70-
date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date()
71-
72-
interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval())
73-
7469
# Load the dataframes we need
7570

7671
ctx = SessionContext()
@@ -96,8 +91,8 @@
9691

9792
# Restrict dataframes to cases of interest
9893
df_orders = df_orders.filter(
99-
col("o_orderdate") >= lit(date),
100-
col("o_orderdate") < lit(date) + lit(interval),
94+
col("o_orderdate") >= lit(YEAR_START),
95+
col("o_orderdate") < lit(YEAR_END),
10196
)
10297

10398
df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST)

examples/tpch/q06_forecasting_revenue_change.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,20 @@
4141
and l_quantity < 24;
4242
"""
4343

44-
from datetime import datetime
44+
from datetime import date
4545

46-
import pyarrow as pa
4746
from datafusion import SessionContext, col, lit
4847
from datafusion import functions as F
4948
from util import get_data_path
5049

5150
# Variables from the example query
5251

53-
DATE_OF_INTEREST = "1994-01-01"
52+
YEAR_START = date(1994, 1, 1)
53+
YEAR_END = date(1995, 1, 1)
5454
DISCOUT = 0.06
5555
DELTA = 0.01
5656
QUANTITY = 24
5757

58-
INTERVAL_DAYS = 365
59-
60-
date = datetime.strptime(DATE_OF_INTEREST, "%Y-%m-%d").date()
61-
62-
interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval())
63-
6458
# Load the dataframes we need
6559

6660
ctx = SessionContext()
@@ -72,8 +66,8 @@
7266
# Filter down to lineitems of interest
7367

7468
df = df_lineitem.filter(
75-
col("l_shipdate") >= lit(date),
76-
col("l_shipdate") < lit(date) + lit(interval),
69+
col("l_shipdate") >= lit(YEAR_START),
70+
col("l_shipdate") < lit(YEAR_END),
7771
col("l_discount").between(lit(DISCOUT - DELTA), lit(DISCOUT + DELTA)),
7872
col("l_quantity") < QUANTITY,
7973
)

examples/tpch/q07_volume_shipping.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
l_year;
7171
"""
7272

73-
from datetime import datetime
73+
from datetime import date
7474

7575
import pyarrow as pa
7676
from datafusion import SessionContext, col, lit
@@ -82,11 +82,8 @@
8282
nation_1 = lit("FRANCE")
8383
nation_2 = lit("GERMANY")
8484

85-
START_DATE = "1995-01-01"
86-
END_DATE = "1996-12-31"
87-
88-
start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date())
89-
end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date())
85+
START_DATE = date(1995, 1, 1)
86+
END_DATE = date(1996, 12, 31)
9087

9188

9289
# Load the dataframes we need
@@ -112,7 +109,7 @@
112109

113110
# Filter to time of interest
114111
df_lineitem = df_lineitem.filter(
115-
col("l_shipdate") >= start_date, col("l_shipdate") <= end_date
112+
col("l_shipdate") >= lit(START_DATE), col("l_shipdate") <= lit(END_DATE)
116113
)
117114

118115

examples/tpch/q08_market_share.py

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,19 @@
6767
o_year;
6868
"""
6969

70-
from datetime import datetime
70+
from datetime import date
7171

7272
import pyarrow as pa
7373
from datafusion import SessionContext, col, lit
7474
from datafusion import functions as F
7575
from util import get_data_path
7676

77-
supplier_nation = lit("BRAZIL")
78-
customer_region = lit("AMERICA")
79-
part_of_interest = lit("ECONOMY ANODIZED STEEL")
77+
supplier_nation = "BRAZIL"
78+
customer_region = "AMERICA"
79+
part_of_interest = "ECONOMY ANODIZED STEEL"
8080

81-
START_DATE = "1995-01-01"
82-
END_DATE = "1996-12-31"
83-
84-
start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date())
85-
end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date())
81+
START_DATE = date(1995, 1, 1)
82+
END_DATE = date(1996, 12, 31)
8683

8784

8885
# Load the dataframes we need
@@ -115,67 +112,55 @@
115112
# Limit orders to those in the specified range
116113

117114
df_orders = df_orders.filter(
118-
col("o_orderdate") >= start_date, col("o_orderdate") <= end_date
115+
col("o_orderdate") >= lit(START_DATE), col("o_orderdate") <= lit(END_DATE)
119116
)
120117

121-
# Part 1: Find customers in the region
118+
# Pair each supplier with its nation name so every regional-customer row
119+
# below carries the supplier's nation and can be filtered inside the
120+
# aggregate with ``F.sum(..., filter=...)``.
122121

123-
# We want customers in region specified by region_of_interest. This will be used to compute
124-
# the total sales of the part of interest. We want to know of those sales what fraction
125-
# was supplied by the nation of interest. There is no guarantee that the nation of
126-
# interest is within the region of interest.
122+
df_supplier_with_nation = df_supplier.join(
123+
df_nation, left_on="s_nationkey", right_on="n_nationkey"
124+
).select("s_suppkey", col("n_name").alias("supp_nation"))
127125

128-
# First we find all the sales that make up the basis.
126+
# Build every (part, lineitem, order, customer) row for customers in the
127+
# target region ordering the target part. Each row carries the supplier's
128+
# nation so we can aggregate on it below.
129129

130-
df_regional_customers = (
130+
df = (
131131
df_region.filter(col("r_name") == customer_region)
132132
.join(df_nation, left_on="r_regionkey", right_on="n_regionkey")
133133
.join(df_customer, left_on="n_nationkey", right_on="c_nationkey")
134134
.join(df_orders, left_on="c_custkey", right_on="o_custkey")
135135
.join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey")
136136
.join(df_part, left_on="l_partkey", right_on="p_partkey")
137-
.with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount")))
138-
)
139-
140-
# Part 2: Find suppliers from the nation
141-
142-
# Now that we have all of the sales of that part in the specified region, we need
143-
# to determine which of those came from suppliers in the nation we are interested in.
144-
145-
df_national_suppliers = (
146-
df_nation.filter(col("n_name") == supplier_nation)
147-
.join(df_supplier, left_on="n_nationkey", right_on="s_nationkey")
148-
.select("s_suppkey")
149-
)
150-
151-
152-
# Part 3: Combine suppliers and customers and compute the market share
153-
154-
# Left-outer join the national suppliers onto the regional sales. Rows from
155-
# other suppliers get a NULL ``s_suppkey``, which the CASE expression uses
156-
# to zero out the non-national volume.
157-
158-
df = df_regional_customers.join(
159-
df_national_suppliers, left_on="l_suppkey", right_on="s_suppkey", how="left"
160-
).with_columns(
161-
national_volume=F.when(col("s_suppkey").is_not_null(), col("volume")).otherwise(
162-
lit(0.0)
163-
),
164-
o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()),
137+
.join(df_supplier_with_nation, left_on="l_suppkey", right_on="s_suppkey")
138+
.with_columns(
139+
volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")),
140+
o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()),
141+
)
165142
)
166143

167-
168-
# Aggregate, compute the share, and sort.
169-
144+
# Aggregate the total and national volumes per year via the ``filter``
145+
# kwarg on ``F.sum`` (DataFrame form of SQL ``sum(... ) FILTER (WHERE ...)``).
146+
# ``coalesce`` handles the case where no sale came from the target nation
147+
# for a given year.
170148
df = (
171149
df.aggregate(
172150
["o_year"],
173151
[
174-
F.sum(col("volume")).alias("volume"),
175-
F.sum(col("national_volume")).alias("national_volume"),
152+
F.sum(col("volume"), filter=col("supp_nation") == supplier_nation).alias(
153+
"national_volume"
154+
),
155+
F.sum(col("volume")).alias("total_volume"),
176156
],
177157
)
178-
.select("o_year", (col("national_volume") / col("volume")).alias("mkt_share"))
158+
.select(
159+
"o_year",
160+
(F.coalesce(col("national_volume"), lit(0.0)) / col("total_volume")).alias(
161+
"mkt_share"
162+
),
163+
)
179164
.sort_by("o_year")
180165
)
181166

examples/tpch/q09_product_type_profit_measure.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from datafusion import functions as F
7070
from util import get_data_path
7171

72-
part_color = lit("green")
72+
part_color = "green"
7373

7474
# Load the dataframes we need
7575

@@ -98,9 +98,10 @@
9898
)
9999

100100
# Limit possible parts to the color specified, then walk the joins down to the
101-
# line-item rows we need and attach the supplier's nation.
101+
# line-item rows we need and attach the supplier's nation. ``F.contains``
102+
# maps directly to the reference SQL's ``p_name like '%green%'``.
102103
df = (
103-
df_part.filter(F.strpos(col("p_name"), part_color) > 0)
104+
df_part.filter(F.contains(col("p_name"), lit(part_color)))
104105
.join(df_lineitem, left_on="p_partkey", right_on="l_partkey")
105106
.join(df_supplier, left_on="l_suppkey", right_on="s_suppkey")
106107
.join(df_orders, left_on="l_orderkey", right_on="o_orderkey")

examples/tpch/q10_returned_item_reporting.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,14 @@
6363
revenue desc limit 20;
6464
"""
6565

66-
from datetime import datetime
66+
from datetime import date
6767

68-
import pyarrow as pa
6968
from datafusion import SessionContext, col, lit
7069
from datafusion import functions as F
7170
from util import get_data_path
7271

73-
DATE_START_OF_QUARTER = "1993-10-01"
74-
75-
date_start_of_quarter = lit(datetime.strptime(DATE_START_OF_QUARTER, "%Y-%m-%d").date())
76-
77-
interval_one_quarter = lit(pa.scalar((0, 92, 0), type=pa.month_day_nano_interval()))
72+
QUARTER_START = date(1993, 10, 1)
73+
QUARTER_END = date(1994, 1, 1)
7874

7975
# Load the dataframes we need
8076

@@ -108,8 +104,8 @@
108104

109105
df = (
110106
df_orders.filter(
111-
col("o_orderdate") >= date_start_of_quarter,
112-
col("o_orderdate") < date_start_of_quarter + interval_one_quarter,
107+
col("o_orderdate") >= lit(QUARTER_START),
108+
col("o_orderdate") < lit(QUARTER_END),
113109
)
114110
.join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey")
115111
.aggregate(

0 commit comments

Comments
 (0)