Skip to content

Commit bd408dc

Browse files
authored
Merge pull request #108 from MITLibraries/TIMX-468-read-configs
TIMX 468 - surface pyarrow read configs to address memory footprint
2 parents 41b326f + 3d5d57b commit bd408dc

4 files changed

Lines changed: 110 additions & 48 deletions

File tree

tests/test_dataset.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ruff: noqa: S105, S106, SLF001
1+
# ruff: noqa: S105, S106, SLF001, PLR2004
22
import os
33
from datetime import date
44
from unittest.mock import MagicMock, patch
@@ -7,7 +7,11 @@
77
import pytest
88
from pyarrow import fs
99

10-
from timdex_dataset_api.dataset import DatasetNotLoadedError, TIMDEXDataset
10+
from timdex_dataset_api.dataset import (
11+
DatasetNotLoadedError,
12+
TIMDEXDataset,
13+
TIMDEXDatasetConfig,
14+
)
1115

1216

1317
@pytest.mark.parametrize(
@@ -23,6 +27,24 @@ def test_dataset_init_success(location, expected_file_system, expected_source):
2327
assert timdex_dataset.source == expected_source
2428

2529

30+
def test_dataset_init_env_vars_set_config(monkeypatch, local_dataset_location):
31+
default_timdex_dataset = TIMDEXDataset(location=local_dataset_location)
32+
default_read_batch_config = default_timdex_dataset.config.read_batch_size
33+
assert default_read_batch_config == 1_000
34+
35+
monkeypatch.setenv("TDA_READ_BATCH_SIZE", "100_000")
36+
env_var_timdex_dataset = TIMDEXDataset(location=local_dataset_location)
37+
env_var_read_batch_config = env_var_timdex_dataset.config.read_batch_size
38+
assert env_var_read_batch_config == 100_000
39+
40+
41+
def test_dataset_init_custom_config_object(monkeypatch, local_dataset_location):
42+
config = TIMDEXDatasetConfig()
43+
config.max_rows_per_file = 42
44+
timdex_dataset = TIMDEXDataset(location=local_dataset_location, config=config)
45+
assert timdex_dataset.config.max_rows_per_file == 42
46+
47+
2648
@patch("timdex_dataset_api.dataset.fs.LocalFileSystem")
2749
@patch("timdex_dataset_api.dataset.ds.dataset")
2850
def test_dataset_load_local_sets_filesystem_and_dataset_success(
@@ -73,28 +95,28 @@ def test_dataset_load_without_filters_success(fixed_local_dataset):
7395
fixed_local_dataset.load()
7496

7597
assert os.path.exists(fixed_local_dataset.location)
76-
assert fixed_local_dataset.row_count == 5_000 # noqa: PLR2004
98+
assert fixed_local_dataset.row_count == 5_000
7799

78100

79101
def test_dataset_load_with_run_date_str_filters_success(fixed_local_dataset):
80102
fixed_local_dataset.load(run_date="2024-12-01")
81103

82104
assert os.path.exists(fixed_local_dataset.location)
83-
assert fixed_local_dataset.row_count == 5_000 # noqa: PLR2004
105+
assert fixed_local_dataset.row_count == 5_000
84106

85107

86108
def test_dataset_load_with_run_date_obj_filters_success(fixed_local_dataset):
87109
fixed_local_dataset.load(run_date=date(2024, 12, 1))
88110

89111
assert os.path.exists(fixed_local_dataset.location)
90-
assert fixed_local_dataset.row_count == 5_000 # noqa: PLR2004
112+
assert fixed_local_dataset.row_count == 5_000
91113

92114

93115
def test_dataset_load_with_ymd_filters_success(fixed_local_dataset):
94116
fixed_local_dataset.load(year="2024", month="12", day="01")
95117

96118
assert os.path.exists(fixed_local_dataset.location)
97-
assert fixed_local_dataset.row_count == 5_000 # noqa: PLR2004
119+
assert fixed_local_dataset.row_count == 5_000
98120

99121

100122
def test_dataset_load_with_single_nonpartition_filters_success(fixed_local_dataset):
@@ -158,7 +180,7 @@ def test_dataset_get_filtered_dataset_with_or_nonpartition_filters_success(
158180
timdex_record_id=["alma:0", "alma:1"]
159181
)
160182
filtered_local_df = filtered_local_dataset.to_table().to_pandas()
161-
assert len(filtered_local_df) == 2 # noqa: PLR2004
183+
assert len(filtered_local_df) == 2
162184
assert filtered_local_df["timdex_record_id"].tolist() == ["alma:0", "alma:1"]
163185

164186

tests/test_write.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from tests.utils import generate_sample_records
1010
from timdex_dataset_api.dataset import (
11-
MAX_ROWS_PER_FILE,
1211
TIMDEX_DATASET_SCHEMA,
1312
TIMDEXDataset,
1413
)
@@ -28,28 +27,29 @@ def test_dataset_write_records_to_new_local_dataset(
2827
def test_dataset_write_default_max_rows_per_file(new_local_dataset, sample_records_iter):
2928
"""Default is 100k rows per file, therefore writing 200,033 records should result in
3029
3 files (x2 @ 100k rows, x1 @ 33 rows)."""
30+
default_max_rows_per_file = new_local_dataset.config.max_rows_per_file
3131
total_records = 200_033
3232

3333
new_local_dataset.write(sample_records_iter(total_records))
3434
new_local_dataset.load()
3535

3636
assert new_local_dataset.row_count == total_records
3737
assert len(new_local_dataset.dataset.files) == math.ceil(
38-
total_records / MAX_ROWS_PER_FILE
38+
total_records / default_max_rows_per_file
3939
)
4040

4141

4242
def test_dataset_write_record_batches_uses_batch_size(
4343
new_local_dataset, sample_records_iter
4444
):
4545
total_records = 101
46-
batch_size = 50
46+
new_local_dataset.config.write_batch_size = 50
4747
batches = list(
48-
new_local_dataset.create_record_batches(
49-
sample_records_iter(total_records), batch_size=batch_size
50-
)
48+
new_local_dataset.create_record_batches(sample_records_iter(total_records))
49+
)
50+
assert len(batches) == math.ceil(
51+
total_records / new_local_dataset.config.write_batch_size
5152
)
52-
assert len(batches) == math.ceil(total_records / batch_size)
5353

5454

5555
def test_dataset_write_to_multiple_locations_raise_error(sample_records_iter):

timdex_dataset_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from timdex_dataset_api.dataset import TIMDEXDataset
44
from timdex_dataset_api.record import DatasetRecord
55

6-
__version__ = "0.9.0"
6+
__version__ = "0.10.0"
77

88
__all__ = [
99
"DatasetRecord",

0 commit comments

Comments (0)