77from collections .abc import Iterator
88from datetime import UTC , date , datetime
99from functools import reduce
10- from typing import TYPE_CHECKING
10+ from typing import TYPE_CHECKING , TypedDict , Unpack
1111
1212import boto3
1313import pyarrow as pa
4545 "day" ,
4646]
4747
48- # order must match assignment in TIMDEXDataset._get_filtered_dataset
49- TIMDEX_DATASET_FILTER_COLUMNS = (
50- "timdex_record_id" ,
51- "source" ,
52- "run_date" ,
53- "run_type" ,
54- "run_id" ,
55- "action" ,
56- "year" ,
57- "month" ,
58- "day" ,
59- )
48+
49+ class DatasetFilters (TypedDict , total = False ):
50+ timdex_record_id : str | None
51+ source : str | None
52+ run_date : str | date | None
53+ run_type : str | None
54+ run_id : str | None
55+ action : str | None
56+ year : str | None
57+ month : str | None
58+ day : str | None
6059
6160
6261DEFAULT_BATCH_SIZE = 1_000
@@ -93,21 +92,11 @@ def row_count(self) -> int:
9392
9493 def load (
9594 self ,
96- * ,
97- timdex_record_id : str | None = None ,
98- source : str | None = None ,
99- run_date : str | date | None = None ,
100- run_type : str | None = None ,
101- run_id : str | None = None ,
102- action : str | None = None ,
103- year : str | None = None ,
104- month : str | None = None ,
105- day : str | None = None ,
95+ ** filters : Unpack [DatasetFilters ],
10696 ) -> None :
107- """Lazy load a pyarrow.dataset.Dataset to a TIMDEXDataset .
97+ """Lazy load a pyarrow.dataset.Dataset and set to self.dataset .
10898
109- This method sets a pyarrow.dataset.Dataset to the TIMDEXDataset.dataset
110- attribute. Loading comprises of two main steps:
99+ Loading is comprised of two main steps:
111100
112101 - load: Lazily load full dataset. PyArrow will "discover" full dataset.
113102 Note: This step may take a couple of seconds but leans on PyArrow's
@@ -120,31 +109,13 @@ def load(
120109 raised when reading or writing data.
121110
122111 Args:
123- All args are optional and must be passed in as keyword args.
124-
125- Partition columns:
126- - run_date (str | date | None, optional)
127- - year (str | None, optional)
128- - month (str | None, optional)
129- - day (str | None, optional)
130-
131- If 'run_date' is provided, partition filters are derived by parsing
132- a datetime.date object from the 'run_date' value and extracting
133- ymd values to use in filter expression.
134-
135- Non-partition columns
136- - timdex_record_id (str | None, optional)
137- - source (str | None, optional)
138- - run_type (str | None, optional)
139- - run_id (str | None, optional)
140- - action (str | None, optional)
141-
142- Any values specified for each column will be used to filter
143- the dataset to rows matching column-value pairs.
112+ - filters: kwargs typed via DatasetFilters TypedDict
113+ - Filters passed directly in method call, e.g. source="alma",
114+ run_date="2024-12-20", etc., but are typed according to DatasetFilters.
144115 """
145116 start_time = time .perf_counter ()
146117
147- # lazy load full dataset
118+ # load dataset
148119 self .dataset = ds .dataset (
149120 self .source ,
150121 schema = self .schema ,
@@ -154,17 +125,7 @@ def load(
154125 )
155126
156127 # filter dataset
157- self .dataset = self ._get_filtered_dataset (
158- timdex_record_id = timdex_record_id ,
159- source = source ,
160- run_date = run_date ,
161- run_type = run_type ,
162- run_id = run_id ,
163- action = action ,
164- year = year ,
165- month = month ,
166- day = day ,
167- )
128+ self .dataset = self ._get_filtered_dataset (** filters )
168129
169130 logger .info (
170131 f"Dataset successfully loaded: '{ self .location } ', "
@@ -173,41 +134,18 @@ def load(
173134
174135 def _get_filtered_dataset (
175136 self ,
176- * ,
177- timdex_record_id : str | None = None ,
178- source : str | None = None ,
179- run_date : str | date | None = None ,
180- run_type : str | None = None ,
181- run_id : str | None = None ,
182- action : str | None = None ,
183- year : str | None = None ,
184- month : str | None = None ,
185- day : str | None = None ,
137+ ** filters : Unpack [DatasetFilters ],
186138 ) -> ds .Dataset :
187- """Lazy filter a pyarrow .dataset.Dataset on a TIMDEXDataset .
139+ """Lazy filter self .dataset and return a new pyarrow Dataset object .
188140
189141 This method will construct a single pyarrow.compute.Expression
190142 that is combined from individual equality comparison predicates
191143 using the provided filters.
192144
193145 Args:
194- All args are optional and must be passed in as keyword args.
195-
196- Run date columns
197- - run_date (str | date | None, optional)
198- - year (str | None, optional)
199- - month (str | None, optional)
200- - day (str | None, optional)
201-
202- If 'run_date' is provided, the 'year', 'month', and 'day' values
203- are parsed from 'run_date'.
204-
205- Other columns:
206- - timdex_record_id (str | None, optional)
207- - source (str | None, optional)
208- - run_type (str | None, optional)
209- - run_id (str | None, optional)
210- - action (str | None, optional)
146+ - filters: kwargs typed via DatasetFilters TypedDict
147+ - Filters passed directly in method call, e.g. source="alma",
148+ run_date="2024-12-20", etc., but are typed according to DatasetFilters.
211149
212150 Raises:
213151 DatasetNotLoadedError: Raised if `self.dataset` is None.
@@ -222,32 +160,13 @@ def _get_filtered_dataset(
222160 if not self .dataset :
223161 raise DatasetNotLoadedError
224162
225- # instantiate filters dict
226- filters_dict = dict (
227- zip (
228- TIMDEX_DATASET_FILTER_COLUMNS ,
229- [
230- timdex_record_id ,
231- source ,
232- run_date ,
233- run_type ,
234- run_id ,
235- action ,
236- year ,
237- month ,
238- day ,
239- ],
240- strict = False ,
241- ),
242- )
243-
244- # get filters for partition columns ('run_date' or 'run_date' components)
245- if run_date :
246- filters_dict .update (self ._parse_date_filters (run_date ))
163+ # if run_date provided, derive year, month, and day partition filters and set
164+ if filters .get ("run_date" ):
165+ filters .update (self ._parse_date_filters (filters ["run_date" ]))
247166
248167 # create filter expressions for element-wise equality comparisons
249168 expressions = []
250- for field , value in filters_dict .items ():
169+ for field , value in filters .items ():
251170 if value :
252171 expressions .append (pc .equal (pc .field (field ), value ))
253172
@@ -264,7 +183,7 @@ def _get_filtered_dataset(
264183
265184 return self .dataset .filter (combined_expressions )
266185
267- def _parse_date_filters (self , run_date : str | date | None ) -> dict :
186+ def _parse_date_filters (self , run_date : str | date | None ) -> DatasetFilters :
268187 """Parse date filters from 'run_date'.
269188
270189 Args:
@@ -278,7 +197,7 @@ def _parse_date_filters(self, run_date: str | date | None) -> dict:
278197 from a provided 'run_date' str.
279198
280199 Returns:
281- dict: ' run_date' filters.
200+ DatasetFilters[ dict]: values for run_date, year, month, and day
282201 """
283202 if isinstance (run_date , str ):
284203 run_date_obj = strict_date_parse (run_date )
0 commit comments