|
67 | 67 | o_year; |
68 | 68 | """ |
69 | 69 |
|
70 | | -from datetime import datetime |
| 70 | +from datetime import date |
71 | 71 |
|
72 | 72 | import pyarrow as pa |
73 | 73 | from datafusion import SessionContext, col, lit |
74 | 74 | from datafusion import functions as F |
75 | 75 | from util import get_data_path |
76 | 76 |
|
77 | | -supplier_nation = lit("BRAZIL") |
78 | | -customer_region = lit("AMERICA") |
79 | | -part_of_interest = lit("ECONOMY ANODIZED STEEL") |
| 77 | +supplier_nation = "BRAZIL" |
| 78 | +customer_region = "AMERICA" |
| 79 | +part_of_interest = "ECONOMY ANODIZED STEEL" |
80 | 80 |
|
81 | | -START_DATE = "1995-01-01" |
82 | | -END_DATE = "1996-12-31" |
83 | | - |
84 | | -start_date = lit(datetime.strptime(START_DATE, "%Y-%m-%d").date()) |
85 | | -end_date = lit(datetime.strptime(END_DATE, "%Y-%m-%d").date()) |
| 81 | +START_DATE = date(1995, 1, 1) |
| 82 | +END_DATE = date(1996, 12, 31) |
86 | 83 |
|
87 | 84 |
|
88 | 85 | # Load the dataframes we need |
|
115 | 112 | # Limit orders to those in the specified range |
116 | 113 |
|
117 | 114 | df_orders = df_orders.filter( |
118 | | - col("o_orderdate") >= start_date, col("o_orderdate") <= end_date |
| 115 | + col("o_orderdate") >= lit(START_DATE), col("o_orderdate") <= lit(END_DATE) |
119 | 116 | ) |
120 | 117 |
|
121 | | -# Part 1: Find customers in the region |
| 118 | +# Pair each supplier with its nation name so every regional-customer row |
| 119 | +# below carries the supplier's nation and can be filtered inside the |
| 120 | +# aggregate with ``F.sum(..., filter=...)``. |
122 | 121 |
|
123 | | -# We want customers in region specified by region_of_interest. This will be used to compute |
124 | | -# the total sales of the part of interest. We want to know of those sales what fraction |
125 | | -# was supplied by the nation of interest. There is no guarantee that the nation of |
126 | | -# interest is within the region of interest. |
| 122 | +df_supplier_with_nation = df_supplier.join( |
| 123 | + df_nation, left_on="s_nationkey", right_on="n_nationkey" |
| 124 | +).select("s_suppkey", col("n_name").alias("supp_nation")) |
127 | 125 |
|
128 | | -# First we find all the sales that make up the basis. |
| 126 | +# Build every (part, lineitem, order, customer) row for customers in the |
| 127 | +# target region ordering the target part. Each row carries the supplier's |
| 128 | +# nation so the aggregate below can filter on it. |
129 | 129 |
|
130 | | -df_regional_customers = ( |
| 130 | +df = ( |
131 | 131 | df_region.filter(col("r_name") == customer_region) |
132 | 132 | .join(df_nation, left_on="r_regionkey", right_on="n_regionkey") |
133 | 133 | .join(df_customer, left_on="n_nationkey", right_on="c_nationkey") |
134 | 134 | .join(df_orders, left_on="c_custkey", right_on="o_custkey") |
135 | 135 | .join(df_lineitem, left_on="o_orderkey", right_on="l_orderkey") |
136 | 136 | .join(df_part, left_on="l_partkey", right_on="p_partkey") |
137 | | - .with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) |
138 | | -) |
139 | | - |
140 | | -# Part 2: Find suppliers from the nation |
141 | | - |
142 | | -# Now that we have all of the sales of that part in the specified region, we need |
143 | | -# to determine which of those came from suppliers in the nation we are interested in. |
144 | | - |
145 | | -df_national_suppliers = ( |
146 | | - df_nation.filter(col("n_name") == supplier_nation) |
147 | | - .join(df_supplier, left_on="n_nationkey", right_on="s_nationkey") |
148 | | - .select("s_suppkey") |
149 | | -) |
150 | | - |
151 | | - |
152 | | -# Part 3: Combine suppliers and customers and compute the market share |
153 | | - |
154 | | -# Left-outer join the national suppliers onto the regional sales. Rows from |
155 | | -# other suppliers get a NULL ``s_suppkey``, which the CASE expression uses |
156 | | -# to zero out the non-national volume. |
157 | | - |
158 | | -df = df_regional_customers.join( |
159 | | - df_national_suppliers, left_on="l_suppkey", right_on="s_suppkey", how="left" |
160 | | -).with_columns( |
161 | | - national_volume=F.when(col("s_suppkey").is_not_null(), col("volume")).otherwise( |
162 | | - lit(0.0) |
163 | | - ), |
164 | | - o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()), |
| 137 | + .join(df_supplier_with_nation, left_on="l_suppkey", right_on="s_suppkey") |
| 138 | + .with_columns( |
| 139 | + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), |
| 140 | + o_year=F.datepart(lit("year"), col("o_orderdate")).cast(pa.int32()), |
| 141 | + ) |
165 | 142 | ) |
166 | 143 |
|
167 | | - |
168 | | -# Aggregate, compute the share, and sort. |
169 | | - |
| 144 | +# Aggregate the total and national volumes per year via the ``filter`` |
| 145 | +# kwarg on ``F.sum`` (DataFrame form of SQL ``sum(...) FILTER (WHERE ...)``). |
| 146 | +# A filtered sum over zero matching rows is NULL, so ``coalesce`` turns a |
| 147 | +# year with no sale from the target nation into a 0.0 market share. |
170 | 148 | df = ( |
171 | 149 | df.aggregate( |
172 | 150 | ["o_year"], |
173 | 151 | [ |
174 | | - F.sum(col("volume")).alias("volume"), |
175 | | - F.sum(col("national_volume")).alias("national_volume"), |
| 152 | + F.sum(col("volume"), filter=col("supp_nation") == supplier_nation).alias( |
| 153 | + "national_volume" |
| 154 | + ), |
| 155 | + F.sum(col("volume")).alias("total_volume"), |
176 | 156 | ], |
177 | 157 | ) |
178 | | - .select("o_year", (col("national_volume") / col("volume")).alias("mkt_share")) |
| 158 | + .select( |
| 159 | + "o_year", |
| 160 | + (F.coalesce(col("national_volume"), lit(0.0)) / col("total_volume")).alias( |
| 161 | + "mkt_share" |
| 162 | + ), |
| 163 | + ) |
179 | 164 | .sort_by("o_year") |
180 | 165 | ) |
181 | 166 |
|
|
0 commit comments