Skip to content

Commit 7e0e795

Browse files
committed
TIMDEXDataset capable of yielding current records only
Why these changes are being introduced: With TIMDEXDataset capable of limiting to only parquet files associated with current runs, the next logical step is providing the ability to yield only the current version of a record. This would support a "full refresh" of a TIMDEX source where an application like TIM could yield only current records for a given source and index those to Opensearch. How this addresses that need: When TIMDEXDataset is loaded with current_records=True, the private attribute TIMDEXDataset._dedupe_on_read is set to True, informing any read methods to dedupe during yielding. Because all read methods use TIMDEXDataset.read_batches_iter() at the lowest level, the deduping logic is required only there. Because the ordering of the parquet files is already handled by the load method, the read methods can be confident they are always seeing the most recent version of a record first, and thus can just maintain a "seen" set of record ids as they are encountered. This keeps the deduplication effectively instant and memory safe; no large in-memory reordering or deduplication is required. Side effects of this change: * Applications like TIM now have the option of yielding only current records for a source, or all sources, supporting new functionality like fully reindexing a source in Opensearch from parquet dataset data alone. Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/TIMX-494
1 parent 50fff12 commit 7e0e795

3 files changed

Lines changed: 110 additions & 6 deletions

File tree

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def dataset_with_runs_location(tmp_path) -> str:
138138
num_records, source, run_date, run_type, action, run_id = params
139139
records = generate_sample_records(
140140
num_records,
141+
timdex_record_id_prefix=source,
141142
source=source,
142143
run_date=run_date,
143144
run_type=run_type,
@@ -147,3 +148,8 @@ def dataset_with_runs_location(tmp_path) -> str:
147148
timdex_dataset.write(records)
148149

149150
return location
151+
152+
153+
@pytest.fixture
def local_dataset_with_runs(dataset_with_runs_location) -> TIMDEXDataset:
    """Provide an unloaded TIMDEXDataset pointed at the multi-run fixture dataset."""
    timdex_dataset = TIMDEXDataset(dataset_with_runs_location)
    return timdex_dataset

tests/test_dataset.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,3 +339,57 @@ def test_dataset_local_dataset_row_count_missing_dataset_raise_error(local_datas
339339
td = TIMDEXDataset(location="path/to/nowhere")
340340
with pytest.raises(DatasetNotLoadedError):
341341
_ = td.row_count
342+
343+
344+
def test_dataset_all_records_not_current_and_not_deduped(local_dataset_with_runs):
    """Without current_records, every record version is returned, no deduping."""
    local_dataset_with_runs.load()
    records_df = local_dataset_with_runs.read_dataframe()

    # counts reflect every record version written across all runs
    expected_counts = {"alma": 254, "dspace": 194}
    assert records_df.source.value_counts().to_dict() == expected_counts

    # run_date spans the min/max of all runs in the dataset
    assert records_df.run_date.min() == date(2024, 12, 1)
    assert records_df.run_date.max() == date(2025, 2, 5)
354+
355+
356+
def test_dataset_all_current_records_deduped(local_dataset_with_runs):
    """With current_records=True, only the newest version of each record remains."""
    local_dataset_with_runs.load(current_records=True)
    current_df = local_dataset_with_runs.read_dataframe()

    # each source reports only its current record count
    assert current_df.source.value_counts().to_dict() == {"dspace": 90, "alma": 100}

    # exactly one "full" run survives per source (two sources -> two run_ids)
    full_runs_df = current_df[current_df.run_type == "full"]
    assert len(full_runs_df.run_id.unique()) == 2

    # run_date bounds match the two sources' current runs
    assert current_df.run_date.min() == date(2025, 1, 1)  # both
    assert current_df.run_date.max() == date(2025, 2, 5)  # dspace
369+
370+
371+
def test_dataset_source_current_records_deduped(local_dataset_with_runs):
    """current_records=True combined with a source filter dedupes that source only."""
    local_dataset_with_runs.load(current_records=True, source="alma")
    alma_df = local_dataset_with_runs.read_dataframe()

    # only alma records are present, with the expected current-record count
    assert alma_df.source.value_counts().to_dict() == {"alma": 100}

    # a single "full" run survives the filtering
    full_runs_df = alma_df[alma_df.run_type == "full"]
    assert len(full_runs_df.run_id.unique()) == 1

    # run_date bounds are correct for this single source
    assert alma_df.run_date.min() == date(2025, 1, 1)
    assert alma_df.run_date.max() == date(2025, 1, 5)
384+
385+
386+
def test_dataset_all_read_methods_get_deduplication(
    local_dataset_with_runs,
):
    """All read methods funnel through the same deduped batch iterator."""
    local_dataset_with_runs.load(current_records=True, source="alma")

    # collect record counts via each public read method
    dataframe_count = len(local_dataset_with_runs.read_dataframe())
    dict_count = len(list(local_dataset_with_runs.read_dicts_iter()))
    transformed_count = len(
        list(local_dataset_with_runs.read_transformed_records_iter())
    )

    assert dataframe_count == dict_count == transformed_count

timdex_dataset_api/dataset.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def __init__(
120120
self.schema = TIMDEX_DATASET_SCHEMA
121121
self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
122122
self._written_files: list[ds.WrittenFile] = None # type: ignore[assignment]
123+
self._dedupe_on_read: bool = False
123124

124125
@property
125126
def row_count(self) -> int:
@@ -162,6 +163,7 @@ def load(
162163
self._load_pyarrow_dataset()
163164

164165
# if current_records flag set, limit to parquet files associated with current runs
166+
self._dedupe_on_read = current_records
165167
if current_records:
166168
timdex_run_manager = TIMDEXRunManager(timdex_dataset=self)
167169

@@ -467,14 +469,55 @@ def read_batches_iter(
467469
"Dataset is not loaded. Please call the `load` method first."
468470
)
469471
dataset = self._get_filtered_dataset(**filters)
470-
for batch in dataset.to_batches(
472+
473+
batches = dataset.to_batches(
471474
columns=columns,
472475
batch_size=self.config.read_batch_size,
473476
batch_readahead=self.config.batch_read_ahead,
474477
fragment_readahead=self.config.fragment_read_ahead,
475-
):
478+
)
479+
480+
if self._dedupe_on_read:
481+
yield from self._yield_deduped_batches(batches)
482+
else:
483+
for batch in batches:
484+
if len(batch) > 0:
485+
yield batch
486+
487+
def _yield_deduped_batches(
    self, batches: Iterator[pa.RecordBatch]
) -> Iterator[pa.RecordBatch]:
    """Yield batches with previously seen records removed.

    Extending the normal behavior of yielding batches untouched, this method keeps
    track of seen timdex_record_id's, yielding each record only once. For this
    method to yield the most current version of a record -- the most common usage --
    it is required that the batches are pre-ordered so the most recent record
    version is encountered first.

    Args:
        batches: iterator of pyarrow RecordBatches, pre-ordered most-recent-first.

    Yields:
        Non-empty RecordBatches containing only first-encountered record ids.
    """
    seen_records: set = set()
    for batch in batches:
        # skip empty batches outright
        if len(batch) == 0:
            continue

        timdex_ids = batch.column("timdex_record_id").to_pylist()

        # collect indices of record ids not yet seen; adding to the seen set
        # inside this loop also removes duplicates within a single batch
        unseen_batch_indices = []
        for i, record_id in enumerate(timdex_ids):
            if record_id not in seen_records:
                unseen_batch_indices.append(i)
                seen_records.add(record_id)

        # every record in this batch was already yielded by an earlier batch
        if not unseen_batch_indices:
            continue

        # take() with a non-empty index list always returns a non-empty batch,
        # so no additional length check is needed before yielding
        yield batch.take(pa.array(unseen_batch_indices))  # type: ignore[arg-type]
478521

479522
def read_dataframes_iter(
480523
self,
@@ -536,13 +579,14 @@ def read_transformed_records_iter(
536579
) -> Iterator[dict]:
537580
"""Yield individual transformed records as dictionaries from the dataset.
538581
539-
If 'transformed_record' is None (i.e., action="skip"|"error"), the yield
540-
statement will not be executed for the row.
582+
If 'transformed_record' is None (common scenarios are action="skip"|"error"), the
583+
yield statement will not be executed for the row. Note that for action="delete" a
584+
transformed record still may be yielded if present.
541585
542586
Args: see self.read_batches_iter()
543587
"""
544588
for record_dict in self.read_dicts_iter(
545-
columns=["transformed_record"],
589+
columns=["timdex_record_id", "transformed_record"],
546590
**filters,
547591
):
548592
if transformed_record := record_dict["transformed_record"]:

0 commit comments

Comments
 (0)