|
111 | 111 |
|
112 | 112 |
|
113 | 113 | # Filter to time of interest |
114 | | -df_lineitem = df_lineitem.filter(col("l_shipdate") >= start_date).filter( |
115 | | - col("l_shipdate") <= end_date |
| 114 | +df_lineitem = df_lineitem.filter( |
| 115 | + col("l_shipdate") >= start_date, col("l_shipdate") <= end_date |
116 | 116 | ) |
117 | 117 |
|
118 | 118 |
|
|
122 | 122 |
|
123 | 123 | # Limit suppliers to either nation |
124 | 124 | df_supplier = df_supplier.join( |
125 | | - df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" |
126 | | -).select(col("s_suppkey"), col("n_name").alias("supp_nation")) |
| 125 | + df_nation, left_on="s_nationkey", right_on="n_nationkey" |
| 126 | +).select("s_suppkey", col("n_name").alias("supp_nation")) |
127 | 127 |
|
128 | 128 | # Limit customers to either nation |
129 | 129 | df_customer = df_customer.join( |
130 | | - df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner" |
131 | | -).select(col("c_custkey"), col("n_name").alias("cust_nation")) |
| 130 | + df_nation, left_on="c_nationkey", right_on="n_nationkey" |
| 131 | +).select("c_custkey", col("n_name").alias("cust_nation")) |
132 | 132 |
|
133 | 133 | # Join up all the data frames from line items, and make sure the supplier and customer are in |
134 | 134 | # different nations. |
135 | 135 | df = ( |
136 | | - df_lineitem.join( |
137 | | - df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner" |
138 | | - ) |
139 | | - .join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") |
140 | | - .join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") |
| 136 | + df_lineitem.join(df_orders, left_on="l_orderkey", right_on="o_orderkey") |
| 137 | + .join(df_customer, left_on="o_custkey", right_on="c_custkey") |
| 138 | + .join(df_supplier, left_on="l_suppkey", right_on="s_suppkey") |
141 | 139 | .filter(col("cust_nation") != col("supp_nation")) |
142 | 140 | ) |
143 | 141 |
|
144 | 142 | # Extract out two values for every line item |
145 | | -df = df.with_column( |
146 | | - "l_year", F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()) |
147 | | -).with_column("volume", col("l_extendedprice") * (lit(1.0) - col("l_discount"))) |
| 143 | +df = df.with_columns( |
| 144 | + l_year=F.datepart(lit("year"), col("l_shipdate")).cast(pa.int32()), |
| 145 | + volume=col("l_extendedprice") * (lit(1.0) - col("l_discount")), |
| 146 | +) |
148 | 147 |
|
149 | | -# Aggregate the results |
| 148 | +# Aggregate and sort per the spec. |
150 | 149 | df = df.aggregate( |
151 | | - [col("supp_nation"), col("cust_nation"), col("l_year")], |
| 150 | + ["supp_nation", "cust_nation", "l_year"], |
152 | 151 | [F.sum(col("volume")).alias("revenue")], |
153 | | -) |
154 | | - |
155 | | -# Sort based on problem statement requirements |
156 | | -df = df.sort(col("supp_nation").sort(), col("cust_nation").sort(), col("l_year").sort()) |
| 152 | +).sort_by("supp_nation", "cust_nation", "l_year") |
157 | 153 |
|
158 | 154 | df.show() |
0 commit comments