Commit 23b42d8

Add baseline read methods to TIMDEXDataset
Why these changes are being introduced:

A primary responsibility of the TIMDEXDataset class is to provide performant and memory-efficient reading of a dataset. Additional read methods may be required later for specific or niche situations, but some simple baseline methods are needed at this time.

How this addresses that need:

* Adds methods for reading pyarrow batches, pandas DataFrames, and python dictionaries from a dataset.

Side effects of this change:

* Applications like timdex lambdas or TIM can now read records from the dataset.

Relevant ticket(s):

* https://mitlibraries.atlassian.net/browse/TIMX-417
1 parent 7251258 commit 23b42d8
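
As a quick orientation, a minimal usage sketch of the three baseline read methods follows. The import path and dataset location are assumptions inferred from the files changed in this commit, not confirmed by it:

# a minimal sketch, assuming the package exposes TIMDEXDataset from
# timdex_dataset_api/dataset.py and that a dataset exists at this path
from timdex_dataset_api.dataset import TIMDEXDataset

timdex_dataset = TIMDEXDataset("path/to/dataset")
timdex_dataset.load()  # required before any read method

# stream pyarrow RecordBatches, optionally filtered by column values
for batch in timdex_dataset.read_batches_iter(source="libguides"):
    print(batch.num_rows)

# pull all filtered rows into memory as a single pandas DataFrame
df = timdex_dataset.read_dataframe(run_date="2024-12-01")

# iterate individual records as python dictionaries
for record in timdex_dataset.read_dicts_iter(columns=["timdex_record_id"]):
    print(record["timdex_record_id"])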

3 files changed: 188 additions & 1 deletion

tests/conftest.py

Lines changed: 17 additions & 1 deletion
@@ -49,7 +49,23 @@ def fixed_local_dataset(tmp_path) -> TIMDEXDataset:
     method.
     """
     timdex_dataset = TIMDEXDataset(str(tmp_path / "fixed_local_dataset/"))
-    timdex_dataset.write(generate_sample_records(num_records=5_000, run_id="abc123"))
+    for source, run_id in [
+        ("alma", "abc123"),
+        ("dspace", "def456"),
+        ("aspace", "ghi789"),
+        ("libguides", "jkl123"),
+        ("gismit", "mno456"),
+    ]:
+        timdex_dataset.write(
+            generate_sample_records(
+                num_records=1_000,
+                timdex_record_id_prefix=source,
+                source=source,
+                run_date="2024-12-01",
+                run_id=run_id,
+            )
+        )
+    timdex_dataset.load()
     return timdex_dataset

tests/test_dataset_read.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+# ruff: noqa: PLR2004, PD901
+
+import pandas as pd
+import pyarrow as pa
+import pytest
+
+DATASET_COLUMNS_SET = {
+    "timdex_record_id",
+    "source_record",
+    "transformed_record",
+    "source",
+    "run_date",
+    "run_type",
+    "run_id",
+    "action",
+    "year",
+    "month",
+    "day",
+}
+
+
+def test_read_batches_yields_pyarrow_record_batches(fixed_local_dataset):
+    batches = fixed_local_dataset.read_batches_iter()
+    batch = next(batches)
+    assert isinstance(batch, pa.RecordBatch)
+
+
+def test_read_batches_all_columns_by_default(fixed_local_dataset):
+    batches = fixed_local_dataset.read_batches_iter()
+    batch = next(batches)
+    assert set(batch.column_names) == DATASET_COLUMNS_SET
+
+
+def test_read_batches_filter_columns(fixed_local_dataset):
+    columns_subset = ["source", "transformed_record"]
+    batches = fixed_local_dataset.read_batches_iter(columns=columns_subset)
+    batch = next(batches)
+    assert set(batch.column_names) == set(columns_subset)
+
+
+def test_read_batches_no_filters_gets_full_dataset(fixed_local_dataset):
+    batches = fixed_local_dataset.read_batches_iter()
+    table = pa.Table.from_batches(batches)
+    assert len(table) == fixed_local_dataset.row_count
+
+
+def test_read_batches_with_filters_gets_subset_of_dataset(fixed_local_dataset):
+    batches = fixed_local_dataset.read_batches_iter(
+        source="libguides",
+        run_date="2024-12-01",
+        run_type="daily",
+        action="index",
+    )
+    table = pa.Table.from_batches(batches)
+    assert len(table) == 1_000
+    assert len(table) < fixed_local_dataset.row_count
+
+    # assert loaded dataset is unchanged by filtering for a read method
+    assert fixed_local_dataset.row_count == 5_000
+
+
+def test_read_dataframe_batches_yields_dataframes(fixed_local_dataset):
+    df_iter = fixed_local_dataset.read_dataframes_iter()
+    df_batch = next(df_iter)
+    assert isinstance(df_batch, pd.DataFrame)
+    assert len(df_batch) == 1_000
+
+
+def test_read_dataframe_reads_all_dataset_rows_after_filtering(fixed_local_dataset):
+    df = fixed_local_dataset.read_dataframe()
+    assert isinstance(df, pd.DataFrame)
+    assert len(df) == fixed_local_dataset.row_count
+
+
+def test_read_dicts_yields_dictionary_for_each_dataset_record(fixed_local_dataset):
+    records = fixed_local_dataset.read_dicts_iter()
+    record = next(records)
+    assert isinstance(record, dict)
+    assert set(record.keys()) == DATASET_COLUMNS_SET
+
+
+def test_read_batches_filter_to_none_returns_empty_list(fixed_local_dataset):
+    batches = fixed_local_dataset.read_batches_iter(source="not-gonna-find-me")
+    assert list(batches) == []
+
+
+def test_read_dicts_filter_to_none_stopiteration_immediately(fixed_local_dataset):
+    batches = fixed_local_dataset.read_dicts_iter(source="not-gonna-find-me")
+    with pytest.raises(StopIteration):
+        next(batches)

timdex_dataset_api/dataset.py

Lines changed: 80 additions & 0 deletions
@@ -10,6 +10,7 @@
 from typing import TYPE_CHECKING, TypedDict, Unpack
 
 import boto3
+import pandas as pd
 import pyarrow as pa
 import pyarrow.compute as pc
 import pyarrow.dataset as ds
@@ -388,3 +389,82 @@ def log_write_statistics(self, start_time: float) -> None:
             f"total rows: {total_rows}, "
             f"total size: {total_size}"
         )
+
+    def read_batches_iter(
+        self,
+        columns: list[str] | None = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+        **filters: Unpack[DatasetFilters],
+    ) -> Iterator[pa.RecordBatch]:
+        """Yield pyarrow.RecordBatches from the dataset.
+
+        While batch_size limits the max rows per batch, filtering may result in
+        some batches having fewer rows than this limit.
+
+        Args:
+            - columns: list[str], list of columns to return from the dataset
+            - batch_size: int, max number of rows to yield per batch
+            - filters: pairs of column:value to filter the dataset
+        """
+        if not self.dataset:
+            raise DatasetNotLoadedError(
+                "Dataset is not loaded. Please call the `load` method first."
+            )
+        dataset = self._get_filtered_dataset(**filters)
+        for batch in dataset.to_batches(columns=columns, batch_size=batch_size):
+            if len(batch) > 0:
+                yield batch
+
+    def read_dataframes_iter(
+        self,
+        columns: list[str] | None = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+        **filters: Unpack[DatasetFilters],
+    ) -> Iterator[pd.DataFrame]:
+        """Yield record batches as Pandas DataFrames from the dataset.
+
+        Args: see self.read_batches_iter()
+        """
+        for record_batch in self.read_batches_iter(
+            columns=columns, batch_size=batch_size, **filters
+        ):
+            yield record_batch.to_pandas()
+
+    def read_dataframe(
+        self,
+        columns: list[str] | None = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+        **filters: Unpack[DatasetFilters],
+    ) -> pd.DataFrame | None:
+        """Read record batches and concatenate them into a single Pandas DataFrame.
+
+        WARNING: this will pull all records from the currently filtered dataset
+        into memory.
+
+        If no batches are found in the filtered dataset, None is returned.
+
+        Args: see self.read_batches_iter()
+        """
+        df_batches = [
+            record_batch.to_pandas()
+            for record_batch in self.read_batches_iter(
+                columns=columns, batch_size=batch_size, **filters
+            )
+        ]
+        if not df_batches:
+            return None
+        return pd.concat(df_batches)
+
+    def read_dicts_iter(
+        self,
+        columns: list[str] | None = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+        **filters: Unpack[DatasetFilters],
+    ) -> Iterator[dict]:
+        """Yield individual record rows as dictionaries from the dataset.
+
+        Args: see self.read_batches_iter()
+        """
+        for record_batch in self.read_batches_iter(
+            columns=columns, batch_size=batch_size, **filters
+        ):
+            yield from record_batch.to_pylist()
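
Since read_dataframe concatenates every filtered batch into memory, the iterator variants are the safer choice for large reads. A sketch of bounded-memory chunked processing, assuming a loaded TIMDEXDataset instance named timdex_dataset:

# memory stays bounded by batch_size: each DataFrame chunk is processed and
# then discarded; `timdex_dataset` is assumed to be loaded already
total_rows = 0
for df_batch in timdex_dataset.read_dataframes_iter(
    columns=["timdex_record_id", "action"],
    batch_size=1_000,
    source="alma",
):
    total_rows += len(df_batch)
print(f"processed {total_rows} alma rows")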
