
Commit 614d6fa

Refactor to configuration object
Why these changes are being introduced:

With the addition of two read configurations that would be passed between multiple methods, the dataset module reached the point where a centralized configuration object would be helpful. Additionally, we have learned that per-operation configurations are rare; settings are much more likely to be set once during TIMDEXDataset init, or even as env vars for the duration of the library import.

How this addresses that need:

Creates a dataclass TIMDEXDatasetConfig that is passed to TIMDEXDataset on init. This class provides a typed object, with sensible defaults, that is shared throughout all read and write methods.

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/TIMX-468
1 parent ad3f6bd commit 614d6fa

3 files changed

Lines changed: 88 additions & 69 deletions
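
In short, reads and writes now share one configuration object created at init. A minimal sketch of the resulting interface, based on the diff below (the dataset path is illustrative):

    from timdex_dataset_api.dataset import TIMDEXDataset, TIMDEXDatasetConfig

    # Defaults (or TDA_* env vars) apply when no config is passed.
    timdex_dataset = TIMDEXDataset(location="path/to/dataset")

    # Or pass an explicit config, shared by all read and write methods.
    config = TIMDEXDatasetConfig(read_batch_size=10_000, max_rows_per_file=50_000)
    timdex_dataset = TIMDEXDataset(location="path/to/dataset", config=config)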


tests/test_dataset.py

Lines changed: 29 additions & 7 deletions
@@ -1,4 +1,4 @@
-# ruff: noqa: S105, S106, SLF001
+# ruff: noqa: S105, S106, SLF001, PLR2004
 import os
 from datetime import date
 from unittest.mock import MagicMock, patch
@@ -7,7 +7,11 @@
 import pytest
 from pyarrow import fs

-from timdex_dataset_api.dataset import DatasetNotLoadedError, TIMDEXDataset
+from timdex_dataset_api.dataset import (
+    DatasetNotLoadedError,
+    TIMDEXDataset,
+    TIMDEXDatasetConfig,
+)


 @pytest.mark.parametrize(
@@ -23,6 +27,24 @@ def test_dataset_init_success(location, expected_file_system, expected_source):
     assert timdex_dataset.source == expected_source


+def test_dataset_init_env_vars_set_config(monkeypatch, local_dataset_location):
+    default_timdex_dataset = TIMDEXDataset(location=local_dataset_location)
+    default_read_batch_config = default_timdex_dataset.config.read_batch_size
+    assert default_read_batch_config == 1_000
+
+    monkeypatch.setenv("TDA_READ_BATCH_SIZE", "100_000")
+    env_var_timdex_dataset = TIMDEXDataset(location=local_dataset_location)
+    env_var_read_batch_config = env_var_timdex_dataset.config.read_batch_size
+    assert env_var_read_batch_config == 100_000
+
+
+def test_dataset_init_custom_config_object(monkeypatch, local_dataset_location):
+    config = TIMDEXDatasetConfig()
+    config.max_rows_per_file = 42
+    timdex_dataset = TIMDEXDataset(location=local_dataset_location, config=config)
+    assert timdex_dataset.config.max_rows_per_file == 42
+
+
 @patch("timdex_dataset_api.dataset.fs.LocalFileSystem")
 @patch("timdex_dataset_api.dataset.ds.dataset")
 def test_dataset_load_local_sets_filesystem_and_dataset_success(
@@ -73,28 +95,28 @@ def test_dataset_load_without_filters_success(fixed_local_dataset):
     fixed_local_dataset.load()

     assert os.path.exists(fixed_local_dataset.location)
-    assert fixed_local_dataset.row_count == 5_000  # noqa: PLR2004
+    assert fixed_local_dataset.row_count == 5_000


 def test_dataset_load_with_run_date_str_filters_success(fixed_local_dataset):
     fixed_local_dataset.load(run_date="2024-12-01")

     assert os.path.exists(fixed_local_dataset.location)
-    assert fixed_local_dataset.row_count == 5_000  # noqa: PLR2004
+    assert fixed_local_dataset.row_count == 5_000


 def test_dataset_load_with_run_date_obj_filters_success(fixed_local_dataset):
     fixed_local_dataset.load(run_date=date(2024, 12, 1))

     assert os.path.exists(fixed_local_dataset.location)
-    assert fixed_local_dataset.row_count == 5_000  # noqa: PLR2004
+    assert fixed_local_dataset.row_count == 5_000


 def test_dataset_load_with_ymd_filters_success(fixed_local_dataset):
     fixed_local_dataset.load(year="2024", month="12", day="01")

     assert os.path.exists(fixed_local_dataset.location)
-    assert fixed_local_dataset.row_count == 5_000  # noqa: PLR2004
+    assert fixed_local_dataset.row_count == 5_000


 def test_dataset_load_with_single_nonpartition_filters_success(fixed_local_dataset):
@@ -158,7 +180,7 @@ def test_dataset_get_filtered_dataset_with_or_nonpartition_filters_success(
         timdex_record_id=["alma:0", "alma:1"]
     )
     filtered_local_df = filtered_local_dataset.to_table().to_pandas()
-    assert len(filtered_local_df) == 2  # noqa: PLR2004
+    assert len(filtered_local_df) == 2
     assert filtered_local_df["timdex_record_id"].tolist() == ["alma:0", "alma:1"]

tests/test_write.py

Lines changed: 7 additions & 7 deletions
@@ -8,7 +8,6 @@

 from tests.utils import generate_sample_records
 from timdex_dataset_api.dataset import (
-    MAX_ROWS_PER_FILE,
     TIMDEX_DATASET_SCHEMA,
     TIMDEXDataset,
 )
@@ -28,28 +27,29 @@ def test_dataset_write_records_to_new_local_dataset(
 def test_dataset_write_default_max_rows_per_file(new_local_dataset, sample_records_iter):
     """Default is 100k rows per file, therefore writing 200,033 records should result in
     3 files (x2 @ 100k rows, x1 @ 33 rows)."""
+    default_max_rows_per_file = new_local_dataset.config.max_rows_per_file
     total_records = 200_033

     new_local_dataset.write(sample_records_iter(total_records))
     new_local_dataset.load()

     assert new_local_dataset.row_count == total_records
     assert len(new_local_dataset.dataset.files) == math.ceil(
-        total_records / MAX_ROWS_PER_FILE
+        total_records / default_max_rows_per_file
     )


 def test_dataset_write_record_batches_uses_batch_size(
     new_local_dataset, sample_records_iter
 ):
     total_records = 101
-    batch_size = 50
+    new_local_dataset.config.write_batch_size = 50
     batches = list(
-        new_local_dataset.create_record_batches(
-            sample_records_iter(total_records), batch_size=batch_size
-        )
+        new_local_dataset.create_record_batches(sample_records_iter(total_records))
+    )
+    assert len(batches) == math.ceil(
+        total_records / new_local_dataset.config.write_batch_size
     )
-    assert len(batches) == math.ceil(total_records / batch_size)


 def test_dataset_write_to_multiple_locations_raise_error(sample_records_iter):

timdex_dataset_api/dataset.py

Lines changed: 52 additions & 55 deletions
@@ -3,9 +3,11 @@
 import itertools
 import json
 import operator
+import os
 import time
 import uuid
 from collections.abc import Iterator
+from dataclasses import dataclass, field
 from datetime import UTC, date, datetime
 from functools import reduce
 from typing import TYPE_CHECKING, TypedDict, Unpack
@@ -61,27 +63,57 @@ class DatasetFilters(TypedDict, total=False):
     day: str | None


-DEFAULT_BATCH_SIZE = 1_000
-MAX_ROWS_PER_GROUP = DEFAULT_BATCH_SIZE
-MAX_ROWS_PER_FILE = 100_000
-DEFAULT_BATCH_READ_AHEAD = 0
-DEFAULT_FRAGMENT_READ_AHEAD = 0
+@dataclass
+class TIMDEXDatasetConfig:
+    """Configurations for dataset operations.

+    - read_batch_size: row size of batches read, affecting memory consumption
+    - write_batch_size: row size of batches written, directly affecting row group size in
+      final parquet files
+    - max_rows_per_group: max number of rows per row group in a parquet file
+    - max_rows_per_file: max number of rows in a single parquet file
+    - batch_read_ahead: number of batches to optimistically read ahead when batch reading
+      from a dataset; pyarrow default is 16
+    - fragment_read_ahead: number of fragments to optimistically read ahead when batch
+      reading from a dataset; pyarrow default is 4
+    """

-def strict_date_parse(date_string: str) -> date:
-    return datetime.strptime(date_string, "%Y-%m-%d").astimezone(UTC).date()
+    read_batch_size: int = field(
+        default_factory=lambda: int(os.getenv("TDA_READ_BATCH_SIZE", "1_000"))
+    )
+    write_batch_size: int = field(
+        default_factory=lambda: int(os.getenv("TDA_WRITE_BATCH_SIZE", "1_000"))
+    )
+    max_rows_per_group: int = field(
+        default_factory=lambda: int(os.getenv("TDA_MAX_ROWS_PER_GROUP", "1_000"))
+    )
+    max_rows_per_file: int = field(
+        default_factory=lambda: int(os.getenv("TDA_MAX_ROWS_PER_FILE", "100_000"))
+    )
+    batch_read_ahead: int = field(
+        default_factory=lambda: int(os.getenv("TDA_BATCH_READ_AHEAD", "0"))
+    )
+    fragment_read_ahead: int = field(
+        default_factory=lambda: int(os.getenv("TDA_FRAGMENT_READ_AHEAD", "0"))
+    )


 class TIMDEXDataset:

-    def __init__(self, location: str | list[str]):
+    def __init__(
+        self,
+        location: str | list[str],
+        config: TIMDEXDatasetConfig | None = None,
+    ):
         """Initialize TIMDEXDataset object.

         Args:
             location (str | list[str]): Local filesystem path or an S3 URI to
                 a parquet dataset. For partitioned datasets, set to the base directory.
         """
         self.location = location
+        self.config = config or TIMDEXDatasetConfig()
+
         self.filesystem, self.source = self.parse_location(self.location)
         self.dataset: ds.Dataset = None  # type: ignore[assignment]
         self.schema = TIMDEX_DATASET_SCHEMA
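
The default_factory lambdas above resolve the TDA_* environment variables when TIMDEXDatasetConfig is instantiated, so values exported before constructing a TIMDEXDataset become that instance's defaults (the new env var test above exercises this). A minimal sketch of that behavior, with an illustrative dataset path:

    import os

    # Assumption: these are exported before TIMDEXDataset is constructed, since
    # defaults are resolved when TIMDEXDatasetConfig() is instantiated in __init__.
    os.environ["TDA_READ_BATCH_SIZE"] = "10_000"
    os.environ["TDA_MAX_ROWS_PER_FILE"] = "50_000"

    from timdex_dataset_api.dataset import TIMDEXDataset

    timdex_dataset = TIMDEXDataset(location="path/to/dataset")
    assert timdex_dataset.config.read_batch_size == 10_000
    assert timdex_dataset.config.max_rows_per_file == 50_000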
@@ -171,7 +203,7 @@ def _get_filtered_dataset(

         # create filter expressions for element-wise equality comparisons
         expressions = []
-        for field, value in filters.items():
+        for field, value in filters.items():  # noqa: F402
             if isinstance(value, list):
                 expressions.append(ds.field(field).isin(value))
             else:
@@ -207,7 +239,7 @@ def _parse_date_filters(self, run_date: str | date | None) -> DatasetFilters:
             DatasetFilters[dict]: values for run_date, year, month, and day
         """
         if isinstance(run_date, str):
-            run_date_obj = strict_date_parse(run_date)
+            run_date_obj = datetime.strptime(run_date, "%Y-%m-%d").astimezone(UTC).date()
         elif isinstance(run_date, date):
             run_date_obj = run_date
         else:
@@ -286,7 +318,6 @@ def write(
         self,
         records_iter: Iterator["DatasetRecord"],
         *,
-        batch_size: int = DEFAULT_BATCH_SIZE,
         use_threads: bool = True,
     ) -> list[ds.WrittenFile]:
         """Write records to the TIMDEX parquet dataset.
@@ -309,8 +340,6 @@

         Args:
         - records_iter: Iterator of DatasetRecord instances
-        - batch_size: size for batches to yield and write, directly affecting row
-          group size in final parquet files
         - use_threads: boolean if threads should be used for writing
         """
         start_time = time.perf_counter()
@@ -321,10 +350,7 @@
                 "Dataset location must be the root of a single dataset for writing"
             )

-        record_batches_iter = self.create_record_batches(
-            records_iter,
-            batch_size=batch_size,
-        )
+        record_batches_iter = self.create_record_batches(records_iter)

         ds.write_dataset(
             record_batches_iter,
@@ -335,8 +361,8 @@
             file_visitor=lambda written_file: self._written_files.append(written_file),  # type: ignore[arg-type]
             format="parquet",
             max_open_files=500,
-            max_rows_per_file=MAX_ROWS_PER_FILE,
-            max_rows_per_group=MAX_ROWS_PER_GROUP,
+            max_rows_per_file=self.config.max_rows_per_file,
+            max_rows_per_group=self.config.max_rows_per_group,
             partitioning=self.partition_columns,
             partitioning_flavor="hive",
             schema=self.schema,
@@ -349,8 +375,6 @@
     def create_record_batches(
         self,
         records_iter: Iterator["DatasetRecord"],
-        *,
-        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> Iterator[pa.RecordBatch]:
         """Yield pyarrow.RecordBatches for writing.

@@ -361,10 +385,10 @@

         Args:
         - records_iter: Iterator of DatasetRecord instances
-        - batch_size: size for batches to yield and write, directly affecting row
-          group size in final parquet files
         """
-        for i, record_batch in enumerate(itertools.batched(records_iter, batch_size)):
+        for i, record_batch in enumerate(
+            itertools.batched(records_iter, self.config.write_batch_size)
+        ):
             batch = pa.RecordBatch.from_pylist(
                 [record.to_dict() for record in record_batch]
             )
@@ -395,9 +419,6 @@ def log_write_statistics(self, start_time: float) -> None:
     def read_batches_iter(
         self,
         columns: list[str] | None = None,
-        batch_size: int = DEFAULT_BATCH_SIZE,
-        batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
-        fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[pa.RecordBatch]:
         """Yield pyarrow.RecordBatches from the dataset.
@@ -416,7 +437,7 @@
             number will increase RAM usage but could also improve IO utilization.
             Pyarrow default is 4, but this library defaults to 0 to prioritize memory
             footprint.
-        - filter_kwargs: pairs of column:value to filter the dataset
+        - filters: pairs of column:value to filter the dataset
         """
         if not self.dataset:
             raise DatasetNotLoadedError(
@@ -425,19 +446,16 @@
         dataset = self._get_filtered_dataset(**filters)
         for batch in dataset.to_batches(
             columns=columns,
-            batch_size=batch_size,
-            batch_readahead=batch_read_ahead,
-            fragment_readahead=fragment_read_ahead,
+            batch_size=self.config.read_batch_size,
+            batch_readahead=self.config.batch_read_ahead,
+            fragment_readahead=self.config.fragment_read_ahead,
         ):
             if len(batch) > 0:
                 yield batch

     def read_dataframes_iter(
         self,
         columns: list[str] | None = None,
-        batch_size: int = DEFAULT_BATCH_SIZE,
-        batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
-        fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[pd.DataFrame]:
         """Yield record batches as Pandas DataFrames from the dataset.
443461
"""Yield record batches as Pandas DataFrames from the dataset.
@@ -446,19 +464,13 @@ def read_dataframes_iter(
446464
"""
447465
for record_batch in self.read_batches_iter(
448466
columns=columns,
449-
batch_size=batch_size,
450-
batch_read_ahead=batch_read_ahead,
451-
fragment_read_ahead=fragment_read_ahead,
452467
**filters,
453468
):
454469
yield record_batch.to_pandas()
455470

456471
def read_dataframe(
457472
self,
458473
columns: list[str] | None = None,
459-
batch_size: int = DEFAULT_BATCH_SIZE,
460-
batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
461-
fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
462474
**filters: Unpack[DatasetFilters],
463475
) -> pd.DataFrame | None:
464476
"""Yield record batches as Pandas DataFrames and concatenate to single dataframe.
@@ -473,9 +485,6 @@ def read_dataframe(
473485
record_batch.to_pandas()
474486
for record_batch in self.read_batches_iter(
475487
columns=columns,
476-
batch_size=batch_size,
477-
batch_read_ahead=batch_read_ahead,
478-
fragment_read_ahead=fragment_read_ahead,
479488
**filters,
480489
)
481490
]
@@ -486,9 +495,6 @@ def read_dataframe(
486495
def read_dicts_iter(
487496
self,
488497
columns: list[str] | None = None,
489-
batch_size: int = DEFAULT_BATCH_SIZE,
490-
batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
491-
fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
492498
**filters: Unpack[DatasetFilters],
493499
) -> Iterator[dict]:
494500
"""Yield individual record rows as dictionaries from the dataset.
@@ -497,18 +503,12 @@ def read_dicts_iter(
497503
"""
498504
for record_batch in self.read_batches_iter(
499505
columns=columns,
500-
batch_size=batch_size,
501-
batch_read_ahead=batch_read_ahead,
502-
fragment_read_ahead=fragment_read_ahead,
503506
**filters,
504507
):
505508
yield from record_batch.to_pylist()
506509

507510
def read_transformed_records_iter(
508511
self,
509-
batch_size: int = DEFAULT_BATCH_SIZE,
510-
batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
511-
fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
512512
**filters: Unpack[DatasetFilters],
513513
) -> Iterator[dict]:
514514
"""Yield individual transformed records as dictionaries from the dataset.
@@ -520,9 +520,6 @@ def read_transformed_records_iter(
520520
"""
521521
for record_dict in self.read_dicts_iter(
522522
columns=["transformed_record"],
523-
batch_size=batch_size,
524-
batch_read_ahead=batch_read_ahead,
525-
fragment_read_ahead=fragment_read_ahead,
526523
**filters,
527524
):
528525
if transformed_record := record_dict["transformed_record"]:
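
With the per-call keyword arguments removed, the read methods take only columns and filters; batch size and read-ahead behavior come from the shared config. A minimal sketch of a read after this change (the path and filter value are illustrative):

    from timdex_dataset_api.dataset import TIMDEXDataset

    timdex_dataset = TIMDEXDataset(location="path/to/dataset")
    timdex_dataset.load(run_date="2024-12-01")

    # batch_size, batch_readahead, and fragment_readahead now come from
    # timdex_dataset.config rather than per-call keyword arguments.
    for record in timdex_dataset.read_dicts_iter(columns=["timdex_record_id"]):
        print(record["timdex_record_id"])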
