Skip to content

Commit a0c0fb9

Browse files
timsaucer and claude
committed
tpch examples: apply SKILL.md idioms across all 22 queries
Sweep every q01..q22 example for idiomatic DataFrame style as described in the repo-root SKILL.md: - ``col("x") == "s"`` in place of ``col("x") == lit("s")`` on comparison right-hand sides (auto-wrap applies). - Plain-name strings in ``select``/``aggregate``/``sort`` group/sort key lists when the key is a bare column. - Drop redundant ``how="inner"`` and single-element ``left_on``/``right_on`` list wrapping on equi-joins. - Collapse chained ``.filter(a).filter(b)`` runs into ``.filter(a, b)`` and chained ``.with_column`` runs into ``.with_columns(a=..., b=...)``. - ``df.sort_by(...)`` or plain-name ``df.sort(...)`` when no null-placement override is needed. - ``F.count_star()`` in place of ``F.count(col("x"))`` whenever the SQL reads ``count(*)``. - ``F.starts_with(col, lit(prefix))`` and ``~F.starts_with(...)`` in place of substring-prefix equality/inequality tricks. - ``F.in_list(col, [lit(...)])`` in place of ``~F.array_position(...).is_null()`` and in place of disjunctions of equality comparisons. - Searched ``F.when(cond, x).otherwise(y)`` in place of switched ``F.case(bool_expr).when(lit(True/False), x).end()`` forms. - Semi-joins as the DataFrame form of ``EXISTS`` (Q04); anti-joins as ``NOT EXISTS`` (Q22 was already using this idiom). - Whole-frame window aggregates as the DataFrame stand-in for a SQL scalar subquery (Q11/Q15/Q17/Q22). Individual query fixes of note: - Q16 — add the secondary sort keys (``p_brand``, ``p_type``, ``p_size``) that the TPC-H spec requires but the original DataFrame omitted. - Q22 — drop a stray ``df.show()`` mid-pipeline; replace the 0-based substring slice with ``F.left(col("c_phone"), lit(2))``. - Q14 — rewrite the promo/non-promo factor split as a searched CASE inside ``F.sum(...)`` so the DataFrame expression matches the reference SQL shape exactly. All 22 answer-file comparisons still pass at scale factor 1. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1878a46 commit a0c0fb9

22 files changed

Lines changed: 374 additions & 522 deletions

examples/tpch/q01_pricing_summary_report.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,31 +82,25 @@
8282

8383
# Aggregate the results
8484

85+
disc_price = col("l_extendedprice") * (lit(1) - col("l_discount"))
86+
8587
df = df.aggregate(
86-
[col("l_returnflag"), col("l_linestatus")],
88+
["l_returnflag", "l_linestatus"],
8789
[
8890
F.sum(col("l_quantity")).alias("sum_qty"),
8991
F.sum(col("l_extendedprice")).alias("sum_base_price"),
90-
F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias(
91-
"sum_disc_price"
92-
),
93-
F.sum(
94-
col("l_extendedprice")
95-
* (lit(1) - col("l_discount"))
96-
* (lit(1) + col("l_tax"))
97-
).alias("sum_charge"),
92+
F.sum(disc_price).alias("sum_disc_price"),
93+
F.sum(disc_price * (lit(1) + col("l_tax"))).alias("sum_charge"),
9894
F.avg(col("l_quantity")).alias("avg_qty"),
9995
F.avg(col("l_extendedprice")).alias("avg_price"),
10096
F.avg(col("l_discount")).alias("avg_disc"),
101-
F.count(col("l_returnflag")).alias(
102-
"count_order"
103-
), # Counting any column should return same result
97+
F.count_star().alias("count_order"),
10498
],
10599
)
106100

107101
# Sort per the expected result
108102

109-
df = df.sort(col("l_returnflag").sort(), col("l_linestatus").sort())
103+
df = df.sort_by("l_returnflag", "l_linestatus")
110104

111105
# Note: There appears to be a discrepancy between what is returned here and what is in the generated
112106
# answers file for the case of return flag N and line status O, but I did not investigate further.

examples/tpch/q02_minimum_cost_supplier.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -118,30 +118,25 @@
118118
# in the string where it is located.
119119

120120
df_part = df_part.filter(
121-
F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > lit(0)
122-
).filter(col("p_size") == lit(SIZE_OF_INTEREST))
121+
F.strpos(col("p_type"), lit(TYPE_OF_INTEREST)) > 0,
122+
col("p_size") == SIZE_OF_INTEREST,
123+
)
123124

124125
# Filter regions down to the one of interest
125126

126-
df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST))
127+
df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST)
127128

128129
# Now that we have the region, find suppliers in that region. Suppliers are tied to their nation
129130
# and nations are tied to the region.
130131

131-
df_nation = df_nation.join(
132-
df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner"
133-
)
134-
df_supplier = df_supplier.join(
135-
df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner"
136-
)
132+
df_nation = df_nation.join(df_region, left_on="n_regionkey", right_on="r_regionkey")
133+
df_supplier = df_supplier.join(df_nation, left_on="s_nationkey", right_on="n_nationkey")
137134

138135
# Now that we know who the potential suppliers are for the part, we can limit out part
139136
# supplies table down. We can further join down to the specific parts we've identified
140137
# as matching the request
141138

142-
df = df_partsupp.join(
143-
df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner"
144-
)
139+
df = df_partsupp.join(df_supplier, left_on="ps_suppkey", right_on="s_suppkey")
145140

146141
# Locate the minimum cost across all suppliers. There are multiple ways you could do this,
147142
# but one way is to create a window function across all suppliers, find the minimum, and
@@ -158,9 +153,9 @@
158153
),
159154
)
160155

161-
df = df.filter(col("min_cost") == col("ps_supplycost"))
162-
163-
df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner")
156+
df = df.filter(col("min_cost") == col("ps_supplycost")).join(
157+
df_part, left_on="ps_partkey", right_on="p_partkey"
158+
)
164159

165160
# From the problem statement, these are the values we wish to output
166161

@@ -178,12 +173,10 @@
178173
# Sort and display 100 entries
179174
df = df.sort(
180175
col("s_acctbal").sort(ascending=False),
181-
col("n_name").sort(),
182-
col("s_name").sort(),
183-
col("p_partkey").sort(),
184-
)
185-
186-
df = df.limit(100)
176+
"n_name",
177+
"s_name",
178+
"p_partkey",
179+
).limit(100)
187180

188181
# Show results
189182

examples/tpch/q03_shipping_priority.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,38 +75,34 @@
7575

7676
# Limit dataframes to the rows of interest
7777

78-
df_customer = df_customer.filter(col("c_mktsegment") == lit(SEGMENT_OF_INTEREST))
78+
df_customer = df_customer.filter(col("c_mktsegment") == SEGMENT_OF_INTEREST)
7979
df_orders = df_orders.filter(col("o_orderdate") < lit(DATE_OF_INTEREST))
8080
df_lineitem = df_lineitem.filter(col("l_shipdate") > lit(DATE_OF_INTEREST))
8181

8282
# Join all 3 dataframes
8383

84-
df = df_customer.join(
85-
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner"
86-
).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner")
84+
df = df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey").join(
85+
df_lineitem, left_on="o_orderkey", right_on="l_orderkey"
86+
)
8787

8888
# Compute the revenue
8989

9090
df = df.aggregate(
91-
[col("l_orderkey")],
91+
["l_orderkey"],
9292
[
9393
F.first_value(col("o_orderdate")).alias("o_orderdate"),
9494
F.first_value(col("o_shippriority")).alias("o_shippriority"),
9595
F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias("revenue"),
9696
],
9797
)
9898

99-
# Sort by priority
100-
101-
df = df.sort(col("revenue").sort(ascending=False), col("o_orderdate").sort())
102-
103-
# Only return 10 results
99+
# Sort by priority, take 10, and project in the order expected by the spec.
104100

105-
df = df.limit(10)
106-
107-
# Change the order that the columns are reported in just to match the spec
108-
109-
df = df.select("l_orderkey", "revenue", "o_orderdate", "o_shippriority")
101+
df = (
102+
df.sort(col("revenue").sort(ascending=False), "o_orderdate")
103+
.limit(10)
104+
.select("l_orderkey", "revenue", "o_orderdate", "o_shippriority")
105+
)
110106

111107
# Show result
112108

examples/tpch/q04_order_priority_checking.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -77,31 +77,23 @@
7777

7878
interval = pa.scalar((0, INTERVAL_DAYS, 0), type=pa.month_day_nano_interval())
7979

80-
# Limit results to cases where commitment date before receipt date, then
81-
# reduce to a single row per order so the join with the orders table is a
82-
# semantic EXISTS rather than a fan-out.
83-
df_lineitem = (
84-
df_lineitem.filter(col("l_commitdate") < col("l_receiptdate"))
85-
.select("l_orderkey")
86-
.distinct()
80+
# Keep only orders in the quarter of interest, then restrict to those that
81+
# have at least one late lineitem via a semi join (the DataFrame form of
82+
# ``EXISTS`` from the reference SQL).
83+
df_orders = df_orders.filter(
84+
col("o_orderdate") >= lit(date),
85+
col("o_orderdate") < lit(date) + lit(interval),
8786
)
8887

89-
# Limit orders to date range of interest
90-
df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter(
91-
col("o_orderdate") < lit(date) + lit(interval)
92-
)
88+
late_lineitems = df_lineitem.filter(col("l_commitdate") < col("l_receiptdate"))
9389

94-
# Perform the join to find only orders for which there are lineitems outside of expected range
9590
df = df_orders.join(
96-
df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner"
91+
late_lineitems, left_on="o_orderkey", right_on="l_orderkey", how="semi"
9792
)
9893

99-
# Based on priority, find the number of entries
100-
df = df.aggregate(
101-
[col("o_orderpriority")], [F.count(col("o_orderpriority")).alias("order_count")]
94+
# Count the number of orders in each priority group and sort.
95+
df = df.aggregate(["o_orderpriority"], [F.count_star().alias("order_count")]).sort_by(
96+
"o_orderpriority"
10297
)
10398

104-
# Sort the results
105-
df = df.sort(col("o_orderpriority").sort())
106-
10799
df.show()

examples/tpch/q05_local_supplier_volume.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,38 +95,32 @@
9595
)
9696

9797
# Restrict dataframes to cases of interest
98-
df_orders = df_orders.filter(col("o_orderdate") >= lit(date)).filter(
99-
col("o_orderdate") < lit(date) + lit(interval)
98+
df_orders = df_orders.filter(
99+
col("o_orderdate") >= lit(date),
100+
col("o_orderdate") < lit(date) + lit(interval),
100101
)
101102

102-
df_region = df_region.filter(col("r_name") == lit(REGION_OF_INTEREST))
103+
df_region = df_region.filter(col("r_name") == REGION_OF_INTEREST)
103104

104105
# Join all the dataframes
105106

106107
df = (
107-
df_customer.join(
108-
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner"
109-
)
110-
.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner")
108+
df_customer.join(df_orders, left_on="c_custkey", right_on="o_custkey")
109+
.join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey")
111110
.join(
112111
df_supplier,
113112
left_on=["l_suppkey", "c_nationkey"],
114113
right_on=["s_suppkey", "s_nationkey"],
115-
how="inner",
116114
)
117-
.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner")
118-
.join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner")
115+
.join(df_nation, left_on="s_nationkey", right_on="n_nationkey")
116+
.join(df_region, left_on="n_regionkey", right_on="r_regionkey")
119117
)
120118

121-
# Compute the final result
119+
# Compute the final result, then sort in descending order.
122120

123121
df = df.aggregate(
124-
[col("n_name")],
122+
["n_name"],
125123
[F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias("revenue")],
126-
)
127-
128-
# Sort in descending order
129-
130-
df = df.sort(col("revenue").sort(ascending=False))
124+
).sort(col("revenue").sort(ascending=False))
131125

132126
df.show()

examples/tpch/q06_forecasting_revenue_change.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,11 @@
7171

7272
# Filter down to lineitems of interest
7373

74-
df = (
75-
df_lineitem.filter(col("l_shipdate") >= lit(date))
76-
.filter(col("l_shipdate") < lit(date) + lit(interval))
77-
.filter(col("l_discount") >= lit(DISCOUT) - lit(DELTA))
78-
.filter(col("l_discount") <= lit(DISCOUT) + lit(DELTA))
79-
.filter(col("l_quantity") < lit(QUANTITY))
74+
df = df_lineitem.filter(
75+
col("l_shipdate") >= lit(date),
76+
col("l_shipdate") < lit(date) + lit(interval),
77+
col("l_discount").between(lit(DISCOUT - DELTA), lit(DISCOUT + DELTA)),
78+
col("l_quantity") < QUANTITY,
8079
)
8180

8281
# Add up all the "lost" revenue

examples/tpch/q07_volume_shipping.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@
111111

112112

113113
# Filter to time of interest
114-
df_lineitem = df_lineitem.filter(col("l_shipdate") >= start_date).filter(
115-
col("l_shipdate") <= end_date
114+
df_lineitem = df_lineitem.filter(
115+
col("l_shipdate") >= start_date, col("l_shipdate") <= end_date
116116
)
117117

118118

@@ -122,37 +122,33 @@
122122

123123
# Limit suppliers to either nation
124124
df_supplier = df_supplier.join(
125-
df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner"
126-
).select(col("s_suppkey"), col("n_name").alias("supp_nation"))
125+
df_nation, left_on="s_nationkey", right_on="n_nationkey"
126+
).select("s_suppkey", col("n_name").alias("supp_nation"))
127127

128128
# Limit customers to either nation
129129
df_customer = df_customer.join(
130-
df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner"
131-
).select(col("c_custkey"), col("n_name").alias("cust_nation"))
130+
df_nation, left_on="c_nationkey", right_on="n_nationkey"
131+
).select("c_custkey", col("n_name").alias("cust_nation"))
132132

133133
# Join up all the data frames from line items, and make sure the supplier and customer are in
134134
# different nations.
135135
df = (
136-
df_lineitem.join(
137-
df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner"
138-
)
139-
.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner")
140-
.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner")
136+
df_lineitem.join(df_orders, left_on="l_orderkey", right_on="o_orderkey")
137+
.join(df_customer, left_on="o_custkey", right_on="c_custkey")
138+
.join(df_supplier, left_on="l_suppkey", right_on="s_suppkey")
141139
.filter(col("cust_nation") != col("supp_nation"))
142140
)
143141

144142
# Extract out two values for every line item
145-
df = df.with_column(
146-
"l_year", F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32())
147-
).with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount")))
143+
df = df.with_columns(
144+
l_year=F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()),
145+
volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")),
146+
)
148147

149-
# Aggregate the results
148+
# Aggregate and sort per the spec.
150149
df = df.aggregate(
151-
[col("supp_nation"), col("cust_nation"), col("l_year")],
150+
["supp_nation", "cust_nation", "l_year"],
152151
[F.sum(col("volume")).alias("revenue")],
153-
)
154-
155-
# Sort based on problem statement requirements
156-
df = df.sort(col("supp_nation").sort(), col("cust_nation").sort(), col("l_year").sort())
152+
).sort_by("supp_nation", "cust_nation", "l_year")
157153

158154
df.show()

0 commit comments

Comments
 (0)