Skip to content

Commit 3e5e707

Browse files
Remove option to load dataset using partition prefix
Why these changes are being introduced: * When a partition prefix is constructed for the dataset with year, month, and day (or a combination of them), and appended to the 'source' arg in ds.dataset, the resulting dataset will contain 'None' values for any partition columns used in the prefix. It is then problematic with the "post-load" filtering step, which again attempts to filter on those partition columns. It was believed that this would be somewhat inconsequential, but additional testing revealed this was not the case. These changes simplify the TIMDEXDataset.load method by instead relying on PyArrow's efficient dataset discovery and reading processes. For more details, see comment on PR #31: #31 (review). How this addresses that need: * Remove '_get_partition_prefixes' private method * Update '_parse_date_filters' to raise TypeError * Update unit tests Side effects of this change: * None Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/TIMX-425
1 parent 3daca7b commit 3e5e707

3 files changed

Lines changed: 80 additions & 161 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ classifiers = [
2222
]
2323

2424
dependencies = [
25+
"attrs",
2526
"boto3",
2627
"duckdb",
2728
"pandas",

tests/test_dataset.py

Lines changed: 41 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_dataset_load_local_sets_filesystem_and_dataset_success(
3535
result = timdex_dataset.load()
3636

3737
mock_pyarrow_ds.assert_called_once_with(
38-
"local/path/to/dataset/",
38+
"local/path/to/dataset",
3939
schema=timdex_dataset.schema,
4040
format="parquet",
4141
partitioning="hive",
@@ -59,7 +59,7 @@ def test_dataset_load_s3_sets_filesystem_and_dataset_success(
5959

6060
mock_get_s3_fs.assert_called_once()
6161
mock_pyarrow_ds.assert_called_once_with(
62-
"bucket/path/to/dataset/",
62+
"bucket/path/to/dataset",
6363
schema=timdex_dataset.schema,
6464
format="parquet",
6565
partitioning="hive",
@@ -69,60 +69,55 @@ def test_dataset_load_s3_sets_filesystem_and_dataset_success(
6969
assert result is None
7070

7171

72-
@patch("timdex_dataset_api.dataset.fs.LocalFileSystem")
73-
@patch("timdex_dataset_api.dataset.ds.dataset")
74-
def test_dataset_load_with_partition_prefix_via_run_date_success(
75-
mock_pyarrow_ds, mock_local_fs
76-
):
77-
mock_local_fs.return_value = MagicMock()
78-
mock_pyarrow_ds.return_value = MagicMock()
72+
def test_dataset_load_without_filters_success(fixed_local_dataset):
73+
fixed_local_dataset.load()
7974

80-
timdex_dataset = TIMDEXDataset(location="local/path/to/dataset")
81-
timdex_dataset.load(run_date="2024-12-01")
75+
assert os.path.exists(fixed_local_dataset.location)
76+
assert fixed_local_dataset.row_count == 5_000 # noqa: PLR2004
8277

83-
mock_pyarrow_ds.assert_called_once_with(
84-
"local/path/to/dataset/year=2024/month=12/day=01",
85-
schema=timdex_dataset.schema,
86-
format="parquet",
87-
partitioning="hive",
88-
filesystem=mock_local_fs.return_value,
89-
)
9078

79+
def test_dataset_load_with_run_date_str_filters_success(fixed_local_dataset):
80+
fixed_local_dataset.load(run_date="2024-12-01")
9181

92-
@patch("timdex_dataset_api.dataset.fs.LocalFileSystem")
93-
@patch("timdex_dataset_api.dataset.ds.dataset")
94-
def test_dataset_load_with_partition_prefix_via_run_date_components_success(
95-
mock_pyarrow_ds, mock_local_fs
96-
):
97-
mock_local_fs.return_value = MagicMock()
98-
mock_pyarrow_ds.return_value = MagicMock()
82+
assert os.path.exists(fixed_local_dataset.location)
83+
assert fixed_local_dataset.row_count == 5_000 # noqa: PLR2004
9984

100-
timdex_dataset = TIMDEXDataset(location="local/path/to/dataset")
101-
timdex_dataset.load(year="2024")
10285

103-
mock_pyarrow_ds.assert_called_once_with(
104-
"local/path/to/dataset/year=2024",
105-
schema=timdex_dataset.schema,
106-
format="parquet",
107-
partitioning="hive",
108-
filesystem=mock_local_fs.return_value,
109-
)
86+
def test_dataset_load_with_run_date_obj_filters_success(fixed_local_dataset):
87+
fixed_local_dataset.load(run_date=date(2024, 12, 1))
88+
89+
assert os.path.exists(fixed_local_dataset.location)
90+
assert fixed_local_dataset.row_count == 5_000 # noqa: PLR2004
11091

11192

112-
def test_dataset_load_no_filters_success(fixed_local_dataset):
113-
fixed_local_dataset.load()
93+
def test_dataset_load_with_ymd_filters_success(fixed_local_dataset):
94+
fixed_local_dataset.load(year="2024", month="12", day="01")
11495

11596
assert os.path.exists(fixed_local_dataset.location)
11697
assert fixed_local_dataset.row_count == 5_000 # noqa: PLR2004
11798

11899

119-
def test_dataset_load_and_filter_by_non_partition_field_success(fixed_local_dataset):
100+
def test_dataset_load_with_single_nonpartition_filters_success(fixed_local_dataset):
120101
fixed_local_dataset.load(timdex_record_id="alma:0")
121102

122103
assert fixed_local_dataset.row_count == 1
123104

124105

125-
def test_dataset_get_filtered_dataset_by_all_fields_success(fixed_local_dataset):
106+
def test_dataset_load_with_multi_nonpartition_filters_success(fixed_local_dataset):
107+
fixed_local_dataset.load(
108+
timdex_record_id="alma:0",
109+
source="alma",
110+
run_type="daily",
111+
run_id="abc123",
112+
action="index",
113+
)
114+
115+
assert fixed_local_dataset.row_count == 1
116+
117+
118+
def test_dataset_get_filtered_dataset_with_multi_nonpartition_filters_success(
119+
fixed_local_dataset,
120+
):
126121
fixed_local_dataset.load() # initial load dataset, no filters passed
127122

128123
filtered_local_dataset = fixed_local_dataset._get_filtered_dataset(
@@ -138,7 +133,9 @@ def test_dataset_get_filtered_dataset_by_all_fields_success(fixed_local_dataset)
138133
assert filtered_local_df["timdex_record_id"].iloc[0] == "alma:0"
139134

140135

141-
def test_dataset_get_filtered_dataset_by_single_fields_success(fixed_local_dataset):
136+
def test_dataset_get_filtered_dataset_with_single_nonpartition_success(
137+
fixed_local_dataset,
138+
):
142139
fixed_local_dataset.load() # initial load dataset, no filters passed
143140

144141
filtered_local_dataset = fixed_local_dataset._get_filtered_dataset(
@@ -152,7 +149,7 @@ def test_dataset_get_filtered_dataset_by_single_fields_success(fixed_local_datas
152149
assert filtered_local_df["run_id"].unique() == ["abc123"]
153150

154151

155-
def test_dataset_get_filtered_dataset_by_run_date_str_successs(fixed_local_dataset):
152+
def test_dataset_get_filtered_dataset_with_run_date_str_successs(fixed_local_dataset):
156153
fixed_local_dataset.load() # initial load dataset, no filters passed
157154

158155
filtered_local_dataset = fixed_local_dataset._get_filtered_dataset(
@@ -166,7 +163,7 @@ def test_dataset_get_filtered_dataset_by_run_date_str_successs(fixed_local_datas
166163
assert empty_local_dataset.count_rows() == 0
167164

168165

169-
def test_dataset_get_filtered_dataset_by_run_date_date_success(fixed_local_dataset):
166+
def test_dataset_get_filtered_dataset_with_run_date_obj_success(fixed_local_dataset):
170167
fixed_local_dataset.load() # initial load dataset, no filters passed
171168

172169
filtered_local_dataset = fixed_local_dataset._get_filtered_dataset(
@@ -182,7 +179,7 @@ def test_dataset_get_filtered_dataset_by_run_date_date_success(fixed_local_datas
182179
assert empty_local_dataset.count_rows() == 0
183180

184181

185-
def test_dataset_get_filtered_dataset_by_run_date_components_success(fixed_local_dataset):
182+
def test_dataset_get_filtered_dataset_with_ymd_success(fixed_local_dataset):
186183
fixed_local_dataset.load() # initial load dataset, no filters passed
187184

188185
filtered_local_dataset = fixed_local_dataset._get_filtered_dataset(year="2024")
@@ -194,13 +191,13 @@ def test_dataset_get_filtered_dataset_by_run_date_components_success(fixed_local
194191
assert empty_local_dataset.count_rows() == 0
195192

196193

197-
def test_dataset_get_filtered_dataset_by_run_date_if_invalid_type_raise_error(
194+
def test_dataset_get_filtered_dataset_with_run_date_invalid_raise_error(
198195
fixed_local_dataset,
199196
):
200197
fixed_local_dataset.load() # initial load dataset, no filters passed
201198

202199
with pytest.raises(
203-
ValueError,
200+
TypeError,
204201
match=(
205202
"Provided 'run_date' value must be a string matching format '%Y-%m-%d' "
206203
"or a datetime.date."
@@ -209,36 +206,6 @@ def test_dataset_get_filtered_dataset_by_run_date_if_invalid_type_raise_error(
209206
_ = fixed_local_dataset._get_filtered_dataset(run_date=999)
210207

211208

212-
def test_dataset_get_partition_prefixes_with_run_date_success():
213-
timdex_dataset = TIMDEXDataset(location="s3://bucket/path/to/dataset")
214-
215-
assert (
216-
timdex_dataset._get_partition_prefixes(run_date="2024-12-01")
217-
== "year=2024/month=12/day=01"
218-
)
219-
220-
221-
def test_dataset_get_partition_prefixes_without_run_date_success():
222-
timdex_dataset = TIMDEXDataset(location="s3://bucket/path/to/dataset")
223-
224-
assert (
225-
timdex_dataset._get_partition_prefixes(year="2024", month="12", day="01")
226-
) == "year=2024/month=12/day=01"
227-
assert (
228-
timdex_dataset._get_partition_prefixes(year="2024", month="12")
229-
== "year=2024/month=12"
230-
)
231-
assert timdex_dataset._get_partition_prefixes(year="2024") == "year=2024"
232-
233-
234-
def test_dataset_get_partition_prefixes_without_run_date_raise_error():
235-
timdex_dataset = TIMDEXDataset(location="s3://bucket/path/to/dataset")
236-
with pytest.raises(
237-
ValueError, match="Insufficient arguments to construct a valid partition prefix."
238-
):
239-
assert timdex_dataset._get_partition_prefixes(month="12", day="01")
240-
241-
242209
def test_dataset_get_s3_filesystem_success(mocker):
243210
mocked_s3_filesystem = mocker.spy(fs, "S3FileSystem")
244211
s3_filesystem = TIMDEXDataset.get_s3_filesystem()

timdex_dataset_api/dataset.py

Lines changed: 38 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import itertools
44
import operator
5-
import os
65
import time
76
import uuid
87
from collections.abc import Iterator
@@ -110,9 +109,11 @@ def load(
110109
This method sets a pyarrow.dataset.Dataset to the TIMDEXDataset.dataset
111110
attribute. Loading comprises of two main steps:
112111
113-
- pre load: Append a partition prefix to self.source using either 'run_date'
114-
or 'run_date' components to skip reading unnecessary data partitions.
115-
- post load: Lazily filter TIMDEXDataset.
112+
- load: Lazily load full dataset. PyArrow will "discover" full dataset.
113+
Note: This step may take a couple of seconds but leans on PyArrow's
114+
parquet reading processes.
115+
- filter: Lazily filter rows in the PyArrow dataset by conditions on
116+
TIMDEX_DATASET_FILTER_COLUMNS.
116117
117118
The dataset is loaded via the expected schema as defined by module constant
118119
TIMDEX_DATASET_SCHEMA. If the target dataset differs in any way, errors may be
@@ -127,14 +128,9 @@ def load(
127128
- month (str | None, optional)
128129
- day (str | None, optional)
129130
130-
If 'run_date' is provided, partition prefixes are derived by parsing
131-
'run_date' into its individual components.
132-
- 'run_date' str values must match date format "%Y-%m-%d".
133-
134-
If 'run_date' is not provided, partition prefixes are derived using
135-
provided values for individual 'run_date' components: year, month, day.
136-
See TIMDEXDataset.get_partition_prefix to see accepted combination
137-
of args.
131+
If 'run_date' is provided, partition filters are derived by parsing
132+
a datetime.date object from the 'run_date' value and extracting
133+
ymd values to use in filter expression.
138134
139135
Non-partition columns
140136
- timdex_record_id (str | None, optional)
@@ -148,22 +144,16 @@ def load(
148144
"""
149145
start_time = time.perf_counter()
150146

151-
source_path = self.source
152-
if isinstance(self.source, str):
153-
source_path = os.path.join(
154-
self.source, self._get_partition_prefixes(run_date, year, month, day)
155-
)
156-
157-
# pre load: load dataset lazily, with an optional partition prefix
147+
# lazy load full dataset
158148
self.dataset = ds.dataset(
159-
source_path,
149+
self.source,
160150
schema=self.schema,
161151
format="parquet",
162152
partitioning="hive",
163153
filesystem=self.filesystem,
164154
)
165155

166-
# post load: filter dataset
156+
# filter dataset
167157
self.dataset = self._get_filtered_dataset(
168158
timdex_record_id=timdex_record_id,
169159
source=source,
@@ -252,7 +242,8 @@ def _get_filtered_dataset(
252242
)
253243

254244
# get filters for partition columns ('run_date' or 'run_date' components)
255-
filters_dict.update(self._parse_date_filters(run_date))
245+
if run_date:
246+
filters_dict.update(self._parse_date_filters(run_date))
256247

257248
# create filter expressions for element-wise equality comparisons
258249
expressions = []
@@ -273,78 +264,38 @@ def _get_filtered_dataset(
273264

274265
return self.dataset.filter(combined_expressions)
275266

276-
def _get_partition_prefixes(
277-
self,
278-
run_date: str | date | None = None,
279-
year: str | None = None,
280-
month: str | None = None,
281-
day: str | None = None,
282-
) -> str:
283-
"""Derive partition prefixes from provided 'run_date' or 'run_date' components.
284-
285-
Argument 'run_date' is a date string formatted as "YYYY-MM-DD". If not provided,
286-
the arguments 'year', 'month', and 'date' (also string values) must be provided
287-
in specific combinations:
267+
def _parse_date_filters(self, run_date: str | date | None) -> dict:
268+
"""Parse date filters from 'run_date'.
288269
289-
- year, month, day
290-
- year, month
291-
- year
270+
Args:
271+
run_date (str | date | None): If str, the value must match the
272+
date format "%Y-%m-%d"; if date, ymd values are extracted
273+
as str.
292274
293-
Any other combinations are insufficient to construct a valid partition prefix.
275+
Raises:
276+
TypeError: Raised when 'run_date' is an invalid type.
277+
ValueError: Raised when either a datetime.date object cannot be parsed
278+
from a provided 'run_date' str.
294279
295-
Returns a string of partition prefixes: "year=2024/month=12/day=01".
280+
Returns:
281+
dict: 'run_date' filters.
296282
"""
297-
if run_date:
298-
run_date_filters = self._parse_date_filters(run_date)
299-
return (
300-
f"year={run_date_filters["year"]}/"
301-
f"month={run_date_filters["month"]}/"
302-
f"day={run_date_filters["day"]}"
303-
)
304-
305-
partition_prefixes = []
306-
if year and month and day:
307-
partition_prefixes.extend([year, month, day])
308-
elif year and month and day is None:
309-
partition_prefixes.extend([year, month])
310-
elif year and month is None and day is None:
311-
partition_prefixes.extend([year])
312-
elif year is None and month is None and day is None:
313-
return ""
283+
if isinstance(run_date, str):
284+
run_date_obj = strict_date_parse(run_date)
285+
elif isinstance(run_date, date):
286+
run_date_obj = run_date
314287
else:
315-
raise ValueError(
316-
"Insufficient arguments to construct a valid partition prefix."
288+
raise TypeError(
289+
"Provided 'run_date' value must be a string matching format "
290+
"'%Y-%m-%d' or a datetime.date."
317291
)
318292

319-
partition_prefixes_dict = dict(
320-
zip(TIMDEX_DATASET_PARTITION_COLUMNS, partition_prefixes, strict=False),
321-
)
322-
return "/".join(
323-
f"{partition_column}={partition_value}"
324-
for partition_column, partition_value in partition_prefixes_dict.items()
325-
)
326-
327-
def _parse_date_filters(self, run_date: str | date | None) -> dict:
328-
date_filters = {}
329-
if run_date is not None:
330-
if isinstance(run_date, str):
331-
run_date_obj = strict_date_parse(run_date)
332-
elif isinstance(run_date, date):
333-
run_date_obj = run_date
334-
else:
335-
raise ValueError(
336-
"Provided 'run_date' value must be a string matching format "
337-
"'%Y-%m-%d' or a datetime.date."
338-
)
339-
date_filters.update(
340-
{
341-
"run_date": run_date_obj,
342-
"year": run_date_obj.strftime("%Y"),
343-
"month": run_date_obj.strftime("%m"),
344-
"day": run_date_obj.strftime("%d"),
345-
}
346-
)
347-
return date_filters
293+
return {
294+
"run_date": run_date_obj,
295+
"year": run_date_obj.strftime("%Y"),
296+
"month": run_date_obj.strftime("%m"),
297+
"day": run_date_obj.strftime("%d"),
298+
}
348299

349300
@staticmethod
350301
def get_s3_filesystem() -> fs.FileSystem:

0 commit comments

Comments
 (0)