Commit b6043f4

Merge pull request #9 from MITLibraries/TIMX-415-write-to-dataset

TIMX 415 - write to dataset

2 parents 9f9477a + b346fb3, commit b6043f4

9 files changed: 582 additions & 6 deletions


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ ignore = [
   "D103",
   "D104",
   "D415",
+  "D417",
   "EM102",
   "G004",
   "PLR0912",

tests/conftest.py

Lines changed: 38 additions & 0 deletions
@@ -1,7 +1,11 @@
 """tests/conftest.py"""

+# ruff: noqa: D205, D209
+
+
 import pytest

+from tests.utils import generate_sample_records
 from timdex_dataset_api import TIMDEXDataset


@@ -22,3 +26,37 @@ def local_dataset_location():
 @pytest.fixture
 def local_dataset(local_dataset_location):
     return TIMDEXDataset.load(local_dataset_location)
+
+
+@pytest.fixture
+def new_dataset(tmp_path) -> TIMDEXDataset:
+    location = str(tmp_path / "new_dataset")
+    return TIMDEXDataset(location=location)
+
+
+@pytest.fixture
+def sample_records_iter():
+    """Simulates an iterator of X number of valid DatasetRecord instances."""
+
+    def _records_iter(num_records):
+        return generate_sample_records(num_records)
+
+    return _records_iter
+
+
+@pytest.fixture
+def sample_records_iter_without_partitions():
+    """Simulates an iterator of X number of DatasetRecord instances WITHOUT partition
+    values included."""
+
+    def _records_iter(num_records):
+        return generate_sample_records(
+            num_records,
+            source=None,
+            run_date=None,
+            run_type=None,
+            action=None,
+            run_id=None,
+        )
+
+    return _records_iter
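
tests/utils.py itself is not shown in this commit view, but the fixtures above pin down its interface: generate_sample_records takes a record count plus optional partition fields, where None means "omit that partition value". A minimal sketch of what the helper might look like, assuming DatasetRecord accepts the fields exercised in the tests below (the helper body here is hypothetical):

from collections.abc import Iterator

from timdex_dataset_api.record import DatasetRecord


def generate_sample_records(
    num_records: int,
    source: str | None = "alma",
    run_date: str | None = "2024-12-01",
    run_type: str | None = "daily",
    action: str | None = "index",
    run_id: str | None = "000-111-aaa-bbb",
) -> Iterator[DatasetRecord]:
    # Yield num_records valid records; passing None for a partition field
    # produces records WITHOUT that partition value, as required by the
    # sample_records_iter_without_partitions fixture.
    for i in range(num_records):
        yield DatasetRecord(
            timdex_record_id=f"alma:{i}",
            source_record=b"<record><title>Hello World.</title></record>",
            transformed_record=b'{"title":["Hello World."]}',
            source=source,
            run_date=run_date,
            run_type=run_type,
            action=action,
            run_id=run_id,
        )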

tests/test_dataset.py

Lines changed: 23 additions & 2 deletions
@@ -1,5 +1,4 @@
 # ruff: noqa: S105, S106, SLF001
-
 from unittest.mock import MagicMock, patch

 import pyarrow as pa
@@ -33,7 +32,7 @@
     ],
 )
 @patch("timdex_dataset_api.dataset.TIMDEXDataset.get_s3_filesystem")
-def test_parse_location_single_local_directory(
+def test_parse_location_success_scenarios(
     get_s3_filesystem,
     location,
     expected_filesystem,
@@ -45,6 +44,28 @@ def test_parse_location_single_local_directory(
     assert source == expected_source


+@pytest.mark.parametrize(
+    ("location", "expected_exception"),
+    [
+        # None is invalid location type
+        (None, TypeError),
+        # mixed local and S3 locations
+        (
+            [
+                "/local/path/to/dataset/records.parquet",
+                "s3://path/to/dataset/records.parquet",
+            ],
+            ValueError,
+        ),
+    ],
+)
+@patch("timdex_dataset_api.dataset.TIMDEXDataset.get_s3_filesystem")
+def test_parse_location_error_scenarios(get_s3_filesystem, location, expected_exception):
+    get_s3_filesystem.return_value = fs.S3FileSystem()
+    with pytest.raises(expected_exception):
+        _ = TIMDEXDataset.parse_location(location)
+
+
 def test_get_s3_filesystem_success(mocker):
     mocked_s3_filesystem = mocker.spy(fs, "S3FileSystem")
     s3_filesystem = TIMDEXDataset.get_s3_filesystem()
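
Taken together, the success and error scenarios sketch parse_location's contract: it appears to return a (filesystem, source) pair, choosing between a local and an S3 pyarrow filesystem based on the location, and rejecting a None location (TypeError) or a list mixing local and S3 paths (ValueError). A hedged usage sketch, using a local path so no S3 credentials are involved:

from pyarrow import fs

from timdex_dataset_api import TIMDEXDataset

# Return shape (filesystem, source) is inferred from the tests above, not
# from library documentation; a local path should resolve to LocalFileSystem.
filesystem, source = TIMDEXDataset.parse_location("/local/path/to/dataset")
assert isinstance(filesystem, fs.LocalFileSystem)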

tests/test_dataset_write.py

Lines changed: 260 additions & 0 deletions
@@ -0,0 +1,260 @@
+# ruff: noqa: S105, S106, SLF001, PLR2004, PD901, D209, D205
+
+import datetime
+import math
+import os
+
+import pyarrow.dataset as ds
+import pytest
+
+from timdex_dataset_api.dataset import (
+    MAX_ROWS_PER_FILE,
+    TIMDEX_DATASET_SCHEMA,
+    DatasetNotLoadedError,
+    TIMDEXDataset,
+)
+from timdex_dataset_api.exceptions import InvalidDatasetRecordError
+from timdex_dataset_api.record import DatasetRecord
+
+
+def test_dataset_record_serialization():
+    values = {
+        "timdex_record_id": "alma:123",
+        "source_record": b"<record><title>Hello World.</title></record>",
+        "transformed_record": b"""{"title":["Hello World."]}""",
+        "source": "libguides",
+        "run_date": "2024-12-01",
+        "run_type": "full",
+        "action": "index",
+        "run_id": "abc123",
+    }
+    dataset_record = DatasetRecord(**values)
+    assert dataset_record.to_dict() == values
+
+
+def test_dataset_record_serialization_with_partition_values_provided():
+    dataset_record = DatasetRecord(
+        timdex_record_id="alma:123",
+        source_record=b"<record><title>Hello World.</title></record>",
+        transformed_record=b"""{"title":["Hello World."]}""",
+    )
+    partition_values = {
+        "source": "alma",
+        "run_date": "2024-12-01",
+        "run_type": "daily",
+        "action": "index",
+        "run_id": "000-111-aaa-bbb",
+    }
+    assert dataset_record.to_dict(partition_values=partition_values) == {
+        "timdex_record_id": "alma:123",
+        "source_record": b"<record><title>Hello World.</title></record>",
+        "transformed_record": b"""{"title":["Hello World."]}""",
+        "source": "alma",
+        "run_date": "2024-12-01",
+        "run_type": "daily",
+        "action": "index",
+        "run_id": "000-111-aaa-bbb",
+    }
+
+
+def test_dataset_record_serialization_missing_partition_raise_error():
+    values = {
+        "timdex_record_id": "alma:123",
+        "source_record": b"<record><title>Hello World.</title></record>",
+        "transformed_record": b"""{"title":["Hello World."]}""",
+        "source": "libguides",
+        "run_date": "2024-12-01",
+        "run_type": "full",
+        "action": "index",
+        "run_id": None,  # <------ missing partition here
+    }
+    dataset_record = DatasetRecord(**values)
+    with pytest.raises(
+        InvalidDatasetRecordError,
+        match="Partition values are missing: run_id",
+    ):
+        assert dataset_record.to_dict() == values
+
+
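The three serialization tests above pin down DatasetRecord.to_dict's contract: partition_values fill in (or override) the record's partition fields, and any partition still missing raises InvalidDatasetRecordError naming the absent fields. A minimal sketch of that behavior — an assumption about the implementation, not the library source, hence the Sketch name:

import dataclasses

from timdex_dataset_api.exceptions import InvalidDatasetRecordError

PARTITION_FIELDS = ("source", "run_date", "run_type", "action", "run_id")


@dataclasses.dataclass
class DatasetRecordSketch:
    timdex_record_id: str
    source_record: bytes
    transformed_record: bytes
    source: str | None = None
    run_date: str | None = None
    run_type: str | None = None
    action: str | None = None
    run_id: str | None = None

    def to_dict(self, partition_values: dict | None = None) -> dict:
        # Serialize all fields, letting explicit partition_values fill or
        # override the partition columns.
        values = dataclasses.asdict(self)
        if partition_values:
            values.update(partition_values)
        # Any partition column still None is an error, matching the message
        # the third test asserts against.
        missing = [f for f in PARTITION_FIELDS if values.get(f) is None]
        if missing:
            raise InvalidDatasetRecordError(
                f"Partition values are missing: {', '.join(missing)}"
            )
        return values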
+def test_dataset_write_records_to_new_dataset(new_dataset, sample_records_iter):
+    files_written = new_dataset.write(sample_records_iter(10_000))
+    assert len(files_written) == 1
+    assert os.path.exists(new_dataset.location)
+
+    # load newly created dataset as new TIMDEXDataset instance
+    dataset = TIMDEXDataset.load(new_dataset.location)
+    assert dataset.row_count == 10_000
+
+
+def test_dataset_reload_after_write(new_dataset, sample_records_iter):
+    files_written = new_dataset.write(sample_records_iter(10_000))
+    assert len(files_written) == 1
+    assert os.path.exists(new_dataset.location)
+
+    # attempt row count before reload
+    with pytest.raises(DatasetNotLoadedError):
+        _ = new_dataset.row_count
+
+    # attempt row count after reload
+    new_dataset.reload()
+    assert new_dataset.row_count == 10_000
+
+
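The two tests above establish the write-then-read workflow: write() does not refresh the in-memory pyarrow dataset, so reads require an explicit reload() or a fresh TIMDEXDataset.load(). Roughly, with an illustrative location and the hypothetical helper sketched earlier:

from timdex_dataset_api import TIMDEXDataset
from tests.utils import generate_sample_records  # hypothetical helper sketched above

td = TIMDEXDataset(location="/tmp/example_dataset")  # illustrative path
td.write(generate_sample_records(100))

# reading row_count before reload() raises DatasetNotLoadedError
td.reload()
assert td.row_count == 100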
+def test_dataset_write_default_max_rows_per_file(new_dataset, sample_records_iter):
+    """Default is 100k rows per file, therefore writing 200,033 records should result in
+    3 files (x2 @ 100k rows, x1 @ 33 rows)."""
+    total_records = 200_033
+
+    new_dataset.write(sample_records_iter(total_records))
+    new_dataset.reload()
+
+    assert new_dataset.row_count == total_records
+    assert len(new_dataset.dataset.files) == math.ceil(total_records / MAX_ROWS_PER_FILE)
+
+
+def test_dataset_write_record_batches_uses_batch_size(new_dataset, sample_records_iter):
+    total_records = 101
+    batch_size = 50
+    batches = list(
+        new_dataset.get_dataset_record_batches(
+            sample_records_iter(total_records), batch_size=batch_size
+        )
+    )
+    assert len(batches) == math.ceil(total_records / batch_size)
+
+
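get_dataset_record_batches is exercised here only through its batch count: 101 records at batch_size=50 yields ceil(101 / 50) = 3 batches (50, 50, and 1). One plausible shape for the method, sketched with itertools.islice and pyarrow — an assumption, not the library source:

import itertools

import pyarrow as pa


def get_dataset_record_batches(records, batch_size=1000):
    # Consume the record iterator in batch_size chunks, yielding one pyarrow
    # RecordBatch per chunk; a final partial chunk yields a smaller batch.
    while chunk := list(itertools.islice(records, batch_size)):
        yield pa.RecordBatch.from_pylist([record.to_dict() for record in chunk])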
+def test_dataset_write_to_multiple_locations_raise_error(sample_records_iter):
+    timdex_dataset = TIMDEXDataset(
+        location=["/path/to/records-1.parquet", "/path/to/records-2.parquet"]
+    )
+    with pytest.raises(
+        TypeError,
+        match="Dataset location must be the root of a single dataset for writing",
+    ):
+        timdex_dataset.write(sample_records_iter(10))
+
+
+def test_dataset_write_mixin_partition_values_used(
+    new_dataset, sample_records_iter_without_partitions
+):
+    partition_values = {
+        "source": "alma",
+        "run_date": "2024-12-01",
+        "run_type": "daily",
+        "action": "index",
+        "run_id": "000-111-aaa-bbb",
+    }
+    _written_files = new_dataset.write(
+        sample_records_iter_without_partitions(10),
+        partition_values=partition_values,
+    )
+    new_dataset.reload()
+
+    # load as pandas dataframe and assert column values
+    df = new_dataset.dataset.to_table().to_pandas()
+    row = df.iloc[0]
+    assert row.source == partition_values["source"]
+    assert row.run_date == datetime.date(2024, 12, 1)
+    assert row.run_type == partition_values["run_type"]
+    assert row.action == partition_values["action"]
+    assert row.run_id == partition_values["run_id"]
+
+
+def test_dataset_write_schema_partitions_correctly_ordered(
+    new_dataset, sample_records_iter
+):
+    written_files = new_dataset.write(
+        sample_records_iter(10),
+        partition_values={
+            "source": "alma",
+            "run_date": "2024-12-01",
+            "run_type": "daily",
+            "action": "index",
+            "run_id": "000-111-aaa-bbb",
+        },
+    )
+    file = written_files[0]
+    assert (
+        "/source=alma/run_date=2024-12-01/run_type=daily"
+        "/action=index/run_id=000-111-aaa-bbb" in file.path
+    )
+
+
+def test_dataset_write_schema_applied_to_dataset(new_dataset, sample_records_iter):
+    new_dataset.write(sample_records_iter(10))
+
+    # manually load dataset to confirm schema without TIMDEXDataset projecting schema
+    # during load
+    dataset = ds.dataset(
+        new_dataset.location,
+        format="parquet",
+        partitioning="hive",
+    )
+
+    assert set(dataset.schema.names) == set(TIMDEX_DATASET_SCHEMA.names)
+
+
+def test_dataset_write_partition_deleted_when_written_to_again(
+    new_dataset, sample_records_iter
+):
+    """This tests the existing_data_behavior="delete_matching" configuration when writing
+    to a dataset."""
+    partition_values = {
+        "source": "alma",
+        "run_date": "2024-12-01",
+        "run_type": "daily",
+        "action": "index",
+        "run_id": "000-111-aaa-bbb",
+    }
+
+    # perform FIRST write to run_date="2024-12-01"
+    written_files_1 = new_dataset.write(
+        sample_records_iter(10),
+        partition_values=partition_values,
+    )
+
+    # assert that files from first write are present at this time
+    assert os.path.exists(written_files_1[0].path)
+
+    # perform unrelated write with new run_date to confirm this is untouched during delete
+    new_partition_values = partition_values.copy()
+    new_partition_values["run_date"] = "2024-12-15"
+    new_partition_values["run_id"] = "222-333-ccc-ddd"
+    written_files_x = new_dataset.write(
+        sample_records_iter(7),
+        partition_values=new_partition_values,
+    )
+
+    # perform SECOND write to run_date="2024-12-01", expecting this to delete everything
+    # under this combination of partitions (i.e. the first write)
+    written_files_2 = new_dataset.write(
+        sample_records_iter(10),
+        partition_values=partition_values,
+    )
+
+    new_dataset.reload()
+
+    # assert 17 rows: second write for run_date="2024-12-01" @ 10 rows +
+    # run_date="2024-12-15" @ 7 rows
+    assert new_dataset.row_count == 17
+
+    # assert that files from the first run_date="2024-12-01" write are gone, files from
+    # the second write exist, and files from run_date="2024-12-15" also exist
+    assert not os.path.exists(written_files_1[0].path)
+    assert os.path.exists(written_files_2[0].path)
+    assert os.path.exists(written_files_x[0].path)
+
+
+def test_dataset_write_missing_partitions_raise_error(new_dataset, sample_records_iter):
+    missing_partition_values = {
+        "source": "libguides",
+        "run_date": None,
+        "run_type": None,
+        "action": None,
+        "run_id": None,
+    }
+    with pytest.raises(InvalidDatasetRecordError, match="Partition values are missing"):
+        _ = new_dataset.write(
+            sample_records_iter(10),
+            partition_values=missing_partition_values,
+        )
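
The delete-matching test above rests on pyarrow's existing_data_behavior="delete_matching", which deletes only the hive partitions being rewritten and leaves sibling partitions untouched. A self-contained illustration of that pyarrow call (the path, columns, and values are illustrative; how TIMDEXDataset.write wires this up is an assumption):

import pyarrow as pa
import pyarrow.dataset as ds

table = pa.table(
    {
        "timdex_record_id": ["alma:1"],
        "source": ["alma"],
        "run_date": ["2024-12-01"],
    }
)

# Rewriting source=alma/run_date=2024-12-01 deletes prior files under that
# partition path only; e.g. a run_date=2024-12-15 partition is left intact.
ds.write_dataset(
    table,
    base_dir="/tmp/example_dataset",
    format="parquet",
    partitioning=["source", "run_date"],
    partitioning_flavor="hive",
    existing_data_behavior="delete_matching",
)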
