Skip to content

Commit 7251258

Browse files
committed
Create typed DatasetFilters object
Why these changes are being introduced: Currently, both loading and filtering the dataset required a large number of keyword arguments on each method to align with dataset columns. Moving into read methods, which will also support filtering, this created a substantial amount of duplication and could become error-prone over time. How this addresses that need: * Creates a typed dictionary DatasetFilters that includes all columns or partitions that we can use when filtering the dataset * Each method that can filter the dataset accepts kwargs, but they are typed to this typed dictionary Side effects of this change: * None Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/TIMX-417
1 parent 41b70e3 commit 7251258

1 file changed

Lines changed: 31 additions & 112 deletions

File tree

timdex_dataset_api/dataset.py

Lines changed: 31 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import Iterator
88
from datetime import UTC, date, datetime
99
from functools import reduce
10-
from typing import TYPE_CHECKING
10+
from typing import TYPE_CHECKING, TypedDict, Unpack
1111

1212
import boto3
1313
import pyarrow as pa
@@ -45,18 +45,17 @@
4545
"day",
4646
]
4747

48-
# order must match assignment in TIMDEXDataset._get_filtered_dataset
49-
TIMDEX_DATASET_FILTER_COLUMNS = (
50-
"timdex_record_id",
51-
"source",
52-
"run_date",
53-
"run_type",
54-
"run_id",
55-
"action",
56-
"year",
57-
"month",
58-
"day",
59-
)
48+
49+
class DatasetFilters(TypedDict, total=False):
50+
timdex_record_id: str | None
51+
source: str | None
52+
run_date: str | date | None
53+
run_type: str | None
54+
run_id: str | None
55+
action: str | None
56+
year: str | None
57+
month: str | None
58+
day: str | None
6059

6160

6261
DEFAULT_BATCH_SIZE = 1_000
@@ -93,21 +92,11 @@ def row_count(self) -> int:
9392

9493
def load(
9594
self,
96-
*,
97-
timdex_record_id: str | None = None,
98-
source: str | None = None,
99-
run_date: str | date | None = None,
100-
run_type: str | None = None,
101-
run_id: str | None = None,
102-
action: str | None = None,
103-
year: str | None = None,
104-
month: str | None = None,
105-
day: str | None = None,
95+
**filters: Unpack[DatasetFilters],
10696
) -> None:
107-
"""Lazy load a pyarrow.dataset.Dataset to a TIMDEXDataset.
97+
"""Lazy load a pyarrow.dataset.Dataset and set to self.dataset.
10898
109-
This method sets a pyarrow.dataset.Dataset to the TIMDEXDataset.dataset
110-
attribute. Loading comprises of two main steps:
99+
Loading is comprised of two main steps:
111100
112101
- load: Lazily load full dataset. PyArrow will "discover" full dataset.
113102
Note: This step may take a couple of seconds but leans on PyArrow's
@@ -120,31 +109,13 @@ def load(
120109
raised when reading or writing data.
121110
122111
Args:
123-
All args are optional and must be passed in as keyword args.
124-
125-
Partition columns:
126-
- run_date (str | date | None, optional)
127-
- year (str | None, optional)
128-
- month (str | None, optional)
129-
- day (str | None, optional)
130-
131-
If 'run_date' is provided, partition filters are derived by parsing
132-
a datetime.date object from the 'run_date' value and extracting
133-
ymd values to use in filter expression.
134-
135-
Non-partition columns
136-
- timdex_record_id (str | None, optional)
137-
- source (str | None, optional)
138-
- run_type (str | None, optional)
139-
- run_id (str | None, optional)
140-
- action (str | None, optional)
141-
142-
Any values specified for each column will be used to filter
143-
the dataset to rows matching column-value pairs.
112+
- filters: kwargs typed via DatasetFilters TypedDict
113+
- Filters passed directly in method call, e.g. source="alma",
114+
run_date="2024-12-20", etc., but are typed according to DatasetFilters.
144115
"""
145116
start_time = time.perf_counter()
146117

147-
# lazy load full dataset
118+
# load dataset
148119
self.dataset = ds.dataset(
149120
self.source,
150121
schema=self.schema,
@@ -154,17 +125,7 @@ def load(
154125
)
155126

156127
# filter dataset
157-
self.dataset = self._get_filtered_dataset(
158-
timdex_record_id=timdex_record_id,
159-
source=source,
160-
run_date=run_date,
161-
run_type=run_type,
162-
run_id=run_id,
163-
action=action,
164-
year=year,
165-
month=month,
166-
day=day,
167-
)
128+
self.dataset = self._get_filtered_dataset(**filters)
168129

169130
logger.info(
170131
f"Dataset successfully loaded: '{self.location}', "
@@ -173,41 +134,18 @@ def load(
173134

174135
def _get_filtered_dataset(
175136
self,
176-
*,
177-
timdex_record_id: str | None = None,
178-
source: str | None = None,
179-
run_date: str | date | None = None,
180-
run_type: str | None = None,
181-
run_id: str | None = None,
182-
action: str | None = None,
183-
year: str | None = None,
184-
month: str | None = None,
185-
day: str | None = None,
137+
**filters: Unpack[DatasetFilters],
186138
) -> ds.Dataset:
187-
"""Lazy filter a pyarrow.dataset.Dataset on a TIMDEXDataset.
139+
"""Lazy filter self.dataset and return a new pyarrow Dataset object.
188140
189141
This method will construct a single pyarrow.compute.Expression
190142
that is combined from individual equality comparison predicates
191143
using the provided filters.
192144
193145
Args:
194-
All args are optional and must be passed in as keyword args.
195-
196-
Run date columns
197-
- run_date (str | date | None, optional)
198-
- year (str | None, optional)
199-
- month (str | None, optional)
200-
- day (str | None, optional)
201-
202-
If 'run_date' is provided, the 'year', 'month', and 'day' values
203-
are parsed from 'run_date'.
204-
205-
Other columns:
206-
- timdex_record_id (str | None, optional)
207-
- source (str | None, optional)
208-
- run_type (str | None, optional)
209-
- run_id (str | None, optional)
210-
- action (str | None, optional)
146+
- filters: kwargs typed via DatasetFilters TypedDict
147+
- Filters passed directly in method call, e.g. source="alma",
148+
run_date="2024-12-20", etc., but are typed according to DatasetFilters.
211149
212150
Raises:
213151
DatasetNotLoadedError: Raised if `self.dataset` is None.
@@ -222,32 +160,13 @@ def _get_filtered_dataset(
222160
if not self.dataset:
223161
raise DatasetNotLoadedError
224162

225-
# instantiate filters dict
226-
filters_dict = dict(
227-
zip(
228-
TIMDEX_DATASET_FILTER_COLUMNS,
229-
[
230-
timdex_record_id,
231-
source,
232-
run_date,
233-
run_type,
234-
run_id,
235-
action,
236-
year,
237-
month,
238-
day,
239-
],
240-
strict=False,
241-
),
242-
)
243-
244-
# get filters for partition columns ('run_date' or 'run_date' components)
245-
if run_date:
246-
filters_dict.update(self._parse_date_filters(run_date))
163+
# if run_date provided, derive year, month, and day partition filters and set
164+
if filters.get("run_date"):
165+
filters.update(self._parse_date_filters(filters["run_date"]))
247166

248167
# create filter expressions for element-wise equality comparisons
249168
expressions = []
250-
for field, value in filters_dict.items():
169+
for field, value in filters.items():
251170
if value:
252171
expressions.append(pc.equal(pc.field(field), value))
253172

@@ -264,7 +183,7 @@ def _get_filtered_dataset(
264183

265184
return self.dataset.filter(combined_expressions)
266185

267-
def _parse_date_filters(self, run_date: str | date | None) -> dict:
186+
def _parse_date_filters(self, run_date: str | date | None) -> DatasetFilters:
268187
"""Parse date filters from 'run_date'.
269188
270189
Args:
@@ -278,7 +197,7 @@ def _parse_date_filters(self, run_date: str | date | None) -> dict:
278197
from a provided 'run_date' str.
279198
280199
Returns:
281-
dict: 'run_date' filters.
200+
DatasetFilters[dict]: values for run_date, year, month, and day
282201
"""
283202
if isinstance(run_date, str):
284203
run_date_obj = strict_date_parse(run_date)

0 commit comments

Comments
 (0)