Skip to content

Commit 00b8d2a

Browse files
authored
Merge pull request #144 from MITLibraries/TIMX-494-source-current-runs-and-records
TIMX 494 - yield deduped, most recent records
2 parents 0004a65 + 66477f7 commit 00b8d2a

6 files changed

Lines changed: 202 additions & 27 deletions

File tree

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def dataset_with_runs_location(tmp_path) -> str:
138138
num_records, source, run_date, run_type, action, run_id = params
139139
records = generate_sample_records(
140140
num_records,
141+
timdex_record_id_prefix=source,
141142
source=source,
142143
run_date=run_date,
143144
run_type=run_type,
@@ -147,3 +148,8 @@ def dataset_with_runs_location(tmp_path) -> str:
147148
timdex_dataset.write(records)
148149

149150
return location
151+
152+
153+
@pytest.fixture
154+
def local_dataset_with_runs(dataset_with_runs_location) -> TIMDEXDataset:
155+
return TIMDEXDataset(dataset_with_runs_location)

tests/test_dataset.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
def test_dataset_init_success(location, expected_file_system, expected_source):
2525
timdex_dataset = TIMDEXDataset(location=location)
2626
assert isinstance(timdex_dataset.filesystem, expected_file_system)
27-
assert timdex_dataset.source == expected_source
27+
assert timdex_dataset.paths == expected_source
2828

2929

3030
def test_dataset_init_env_vars_set_config(monkeypatch, local_dataset_location):
@@ -79,8 +79,7 @@ def test_dataset_load_s3_sets_filesystem_and_dataset_success(
7979
timdex_dataset = TIMDEXDataset(location="s3://bucket/path/to/dataset")
8080
result = timdex_dataset.load()
8181

82-
mock_get_s3_fs.assert_called_once()
83-
mock_pyarrow_ds.assert_called_once_with(
82+
mock_pyarrow_ds.assert_called_with(
8483
"bucket/path/to/dataset",
8584
schema=timdex_dataset.schema,
8685
format="parquet",
@@ -137,6 +136,26 @@ def test_dataset_load_with_multi_nonpartition_filters_success(fixed_local_datase
137136
assert fixed_local_dataset.row_count == 1
138137

139138

139+
def test_dataset_load_current_records_all_sources_success(dataset_with_runs_location):
140+
timdex_dataset = TIMDEXDataset(dataset_with_runs_location)
141+
142+
# 16 total parquet files, with current_records=False we get them all
143+
timdex_dataset.load(current_records=False)
144+
assert len(timdex_dataset.dataset.files) == 16
145+
146+
# 16 total parquet files, with current_records=True we only get 12 for current runs
147+
timdex_dataset.load(current_records=True)
148+
assert len(timdex_dataset.dataset.files) == 12
149+
150+
151+
def test_dataset_load_current_records_one_source_success(dataset_with_runs_location):
152+
timdex_dataset = TIMDEXDataset(dataset_with_runs_location)
153+
timdex_dataset.load(current_records=True, source="alma")
154+
155+
# 7 total parquet files for source, only 6 related to current runs
156+
assert len(timdex_dataset.dataset.files) == 6
157+
158+
140159
def test_dataset_get_filtered_dataset_with_single_nonpartition_success(
141160
fixed_local_dataset,
142161
):
@@ -324,3 +343,57 @@ def test_dataset_local_dataset_row_count_missing_dataset_raise_error(local_datas
324343
td = TIMDEXDataset(location="path/to/nowhere")
325344
with pytest.raises(DatasetNotLoadedError):
326345
_ = td.row_count
346+
347+
348+
def test_dataset_all_records_not_current_and_not_deduped(local_dataset_with_runs):
349+
local_dataset_with_runs.load()
350+
all_records_df = local_dataset_with_runs.read_dataframe()
351+
352+
# assert counts reflect all records from dataset, no deduping
353+
assert all_records_df.source.value_counts().to_dict() == {"alma": 254, "dspace": 194}
354+
355+
# assert run_date min/max dates align with min/max for all runs
356+
assert all_records_df.run_date.min() == date(2024, 12, 1)
357+
assert all_records_df.run_date.max() == date(2025, 2, 5)
358+
359+
360+
def test_dataset_all_current_records_deduped(local_dataset_with_runs):
361+
local_dataset_with_runs.load(current_records=True)
362+
all_records_df = local_dataset_with_runs.read_dataframe()
363+
364+
# assert both sources have accurate record counts for current records only
365+
assert all_records_df.source.value_counts().to_dict() == {"dspace": 90, "alma": 100}
366+
367+
# assert only one "full" run, per source
368+
assert len(all_records_df[all_records_df.run_type == "full"].run_id.unique()) == 2
369+
370+
# assert run_date min/max dates align with both sources min/max dates
371+
assert all_records_df.run_date.min() == date(2025, 1, 1) # both
372+
assert all_records_df.run_date.max() == date(2025, 2, 5) # dspace
373+
374+
375+
def test_dataset_source_current_records_deduped(local_dataset_with_runs):
376+
local_dataset_with_runs.load(current_records=True, source="alma")
377+
alma_records_df = local_dataset_with_runs.read_dataframe()
378+
379+
# assert only alma records present and correct count
380+
assert alma_records_df.source.value_counts().to_dict() == {"alma": 100}
381+
382+
# assert only one "full" run
383+
assert len(alma_records_df[alma_records_df.run_type == "full"].run_id.unique()) == 1
384+
385+
# assert run_date min/max dates are correct for single source
386+
assert alma_records_df.run_date.min() == date(2025, 1, 1)
387+
assert alma_records_df.run_date.max() == date(2025, 1, 5)
388+
389+
390+
def test_dataset_all_read_methods_get_deduplication(
391+
local_dataset_with_runs,
392+
):
393+
local_dataset_with_runs.load(current_records=True, source="alma")
394+
395+
full_df = local_dataset_with_runs.read_dataframe()
396+
all_records = list(local_dataset_with_runs.read_dicts_iter())
397+
transformed_records = list(local_dataset_with_runs.read_transformed_records_iter())
398+
399+
assert len(full_df) == len(all_records) == len(transformed_records)

tests/test_runs.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,27 @@ def test_timdex_run_manager_get_runs_df(timdex_run_manager):
5656
assert runs_df.source.value_counts().to_dict() == {"alma": 7, "dspace": 7}
5757

5858

59+
def test_timdex_run_manager_get_all_current_run_parquet_files_success(
60+
timdex_run_manager,
61+
):
62+
ordered_parquet_files = timdex_run_manager.get_current_parquet_files()
63+
64+
# assert 12 parquet files, despite being 14 total for ALL sources
65+
# this represents the last full run and all daily since
66+
assert len(ordered_parquet_files) == 12
67+
68+
# assert sorted reverse chronologically
69+
assert "year=2025/month=01/day=01" in ordered_parquet_files[-1]
70+
71+
5972
def test_timdex_run_manager_get_source_current_run_parquet_files_success(
6073
timdex_run_manager,
6174
):
62-
ordered_parquet_files = timdex_run_manager.get_current_source_parquet_files("alma")
75+
ordered_parquet_files = timdex_run_manager._get_current_source_parquet_files("alma")
6376

64-
# assert 6 parquet files, despite being 8 total for alma
77+
# assert 6 parquet files, despite being 8 total for 'alma' source
6578
# this represents the last full run and all daily since
66-
assert len(ordered_parquet_files)
79+
assert len(ordered_parquet_files) == 6
6780

6881
# assert sorted reverse chronologically
6982
assert "year=2025/month=01/day=05" in ordered_parquet_files[0]

timdex_dataset_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from timdex_dataset_api.dataset import TIMDEXDataset
44
from timdex_dataset_api.record import DatasetRecord
55

6-
__version__ = "1.0.0"
6+
__version__ = "2.0.0"
77

88
__all__ = [
99
"DatasetRecord",

timdex_dataset_api/dataset.py

Lines changed: 85 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from timdex_dataset_api.config import configure_logger
2222
from timdex_dataset_api.exceptions import DatasetNotLoadedError
23+
from timdex_dataset_api.run import TIMDEXRunManager
2324

2425
if TYPE_CHECKING:
2526
from timdex_dataset_api.record import DatasetRecord # pragma: nocover
@@ -114,11 +115,12 @@ def __init__(
114115
self.location = location
115116
self.config = config or TIMDEXDatasetConfig()
116117

117-
self.filesystem, self.source = self.parse_location(self.location)
118+
self.filesystem, self.paths = self.parse_location(self.location)
118119
self.dataset: ds.Dataset = None # type: ignore[assignment]
119120
self.schema = TIMDEX_DATASET_SCHEMA
120121
self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
121122
self._written_files: list[ds.WrittenFile] = None # type: ignore[assignment]
123+
self._dedupe_on_read: bool = False
122124

123125
@property
124126
def row_count(self) -> int:
@@ -129,6 +131,8 @@ def row_count(self) -> int:
129131

130132
def load(
131133
self,
134+
*,
135+
current_records: bool = False,
132136
**filters: Unpack[DatasetFilters],
133137
) -> None:
134138
"""Lazy load a pyarrow.dataset.Dataset and set to self.dataset.
@@ -152,14 +156,24 @@ def load(
152156
"""
153157
start_time = time.perf_counter()
154158

155-
# load dataset
156-
self.dataset = ds.dataset(
157-
self.source,
158-
schema=self.schema,
159-
format="parquet",
160-
partitioning="hive",
161-
filesystem=self.filesystem,
162-
)
159+
# reset paths from original location before load
160+
_, self.paths = self.parse_location(self.location)
161+
162+
# perform initial load of full dataset
163+
self._load_pyarrow_dataset()
164+
165+
# if current_records flag set, limit to parquet files associated with current runs
166+
self._dedupe_on_read = current_records
167+
if current_records:
168+
timdex_run_manager = TIMDEXRunManager(timdex_dataset=self)
169+
170+
# update paths, limiting by source if set
171+
self.paths = timdex_run_manager.get_current_parquet_files(
172+
source=filters.get("source")
173+
)
174+
175+
# reload pyarrow dataset
176+
self._load_pyarrow_dataset()
163177

164178
# filter dataset
165179
self.dataset = self._get_filtered_dataset(**filters)
@@ -169,6 +183,16 @@ def load(
169183
f"{round(time.perf_counter()-start_time, 2)}s"
170184
)
171185

186+
def _load_pyarrow_dataset(self) -> None:
187+
"""Load the pyarrow dataset per local filesystem and paths attributes."""
188+
self.dataset = ds.dataset(
189+
self.paths,
190+
schema=self.schema,
191+
format="parquet",
192+
partitioning="hive",
193+
filesystem=self.filesystem,
194+
)
195+
172196
def _get_filtered_dataset(
173197
self,
174198
**filters: Unpack[DatasetFilters],
@@ -345,7 +369,8 @@ def write(
345369
start_time = time.perf_counter()
346370
self._written_files = []
347371

348-
if isinstance(self.source, list):
372+
dataset_filesystem, dataset_path = self.parse_location(self.location)
373+
if isinstance(dataset_path, list):
349374
raise TypeError(
350375
"Dataset location must be the root of a single dataset for writing"
351376
)
@@ -354,10 +379,10 @@ def write(
354379

355380
ds.write_dataset(
356381
record_batches_iter,
357-
base_dir=self.source,
382+
base_dir=dataset_path,
358383
basename_template="%s-{i}.parquet" % (str(uuid.uuid4())), # noqa: UP031
359384
existing_data_behavior="overwrite_or_ignore",
360-
filesystem=self.filesystem,
385+
filesystem=dataset_filesystem,
361386
file_visitor=lambda written_file: self._written_files.append(written_file), # type: ignore[arg-type]
362387
format="parquet",
363388
max_open_files=500,
@@ -444,14 +469,55 @@ def read_batches_iter(
444469
"Dataset is not loaded. Please call the `load` method first."
445470
)
446471
dataset = self._get_filtered_dataset(**filters)
447-
for batch in dataset.to_batches(
472+
473+
batches = dataset.to_batches(
448474
columns=columns,
449475
batch_size=self.config.read_batch_size,
450476
batch_readahead=self.config.batch_read_ahead,
451477
fragment_readahead=self.config.fragment_read_ahead,
452-
):
478+
)
479+
480+
if self._dedupe_on_read:
481+
yield from self._yield_deduped_batches(batches)
482+
else:
483+
for batch in batches:
484+
if len(batch) > 0:
485+
yield batch
486+
487+
def _yield_deduped_batches(
488+
self, batches: Iterator[pa.RecordBatch]
489+
) -> Iterator[pa.RecordBatch]:
490+
"""Method to yield record deduped batches.
491+
492+
Extending the normal behavior of yielding batches untouched, this method keeps
493+
track of seen timdex_record_id's, yielding them only once. For this method to
494+
yield the most current version of a record -- most common usage -- it is required
495+
that the batches are pre-ordered so the most recent record version is encountered
496+
first.
497+
"""
498+
seen_records = set()
499+
for batch in batches:
453500
if len(batch) > 0:
454-
yield batch
501+
# init list of batch indices for records unseen
502+
unseen_batch_indices = []
503+
504+
# get list of timdex ids from batch
505+
timdex_ids = batch.column("timdex_record_id").to_pylist()
506+
507+
# check each record id and track unseen ones
508+
for i, record_id in enumerate(timdex_ids):
509+
if record_id not in seen_records:
510+
unseen_batch_indices.append(i)
511+
seen_records.add(record_id)
512+
513+
# if all records from batch were seen, continue
514+
if not unseen_batch_indices:
515+
continue
516+
517+
# else, yield unseen records from batch
518+
deduped_batch = batch.take(pa.array(unseen_batch_indices)) # type: ignore[arg-type]
519+
if len(deduped_batch) > 0:
520+
yield deduped_batch
455521

456522
def read_dataframes_iter(
457523
self,
@@ -513,13 +579,14 @@ def read_transformed_records_iter(
513579
) -> Iterator[dict]:
514580
"""Yield individual transformed records as dictionaries from the dataset.
515581
516-
If 'transformed_record' is None (i.e., action="skip"|"error"), the yield
517-
statement will not be executed for the row.
582+
If 'transformed_record' is None (common scenarios are action="skip"|"error"), the
583+
yield statement will not be executed for the row. Note that for action="delete" a
584+
transformed record still may be yielded if present.
518585
519586
Args: see self.read_batches_iter()
520587
"""
521588
for record_dict in self.read_dicts_iter(
522-
columns=["transformed_record"],
589+
columns=["timdex_record_id", "transformed_record"],
523590
**filters,
524591
):
525592
if transformed_record := record_dict["transformed_record"]:

timdex_dataset_api/run.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,22 @@ def get_runs_metadata(self, *, refresh: bool = False) -> pd.DataFrame:
8585
)
8686
return grouped_runs_df
8787

88-
def get_current_source_parquet_files(self, source: str) -> list[str]:
88+
def get_current_parquet_files(self, source: str | None = None) -> list[str]:
89+
"""Get reverse chronological list of parquet files associated with current runs.
90+
91+
Args:
92+
source: if provided, limits parquet files to only that source
93+
"""
94+
runs_df = self.get_runs_metadata() # run metadata is cached for future calls
95+
sources = [source] if source else list(runs_df.source.unique())
96+
97+
source_parquet_files = []
98+
for _source in sources:
99+
source_parquet_files.extend(self._get_current_source_parquet_files(_source))
100+
101+
return source_parquet_files
102+
103+
def _get_current_source_parquet_files(self, source: str) -> list[str]:
89104
"""Get reverse chronological list of current parquet files for a source.
90105
91106
Args:
@@ -166,8 +181,9 @@ def _parse_run_metadata_from_parquet_file(self, parquet_filepath: str) -> dict:
166181
"""
167182
parquet_file = pq.ParquetFile(
168183
parquet_filepath,
169-
filesystem=self.timdex_dataset.filesystem, # type: ignore[union-attr]
184+
filesystem=self.timdex_dataset.filesystem,
170185
)
186+
171187
file_meta = parquet_file.metadata.to_dict()
172188
num_rows = file_meta["num_rows"]
173189
columns_meta = file_meta["row_groups"][0]["columns"] # type: ignore[typeddict-item]

0 commit comments

Comments
 (0)