Commit ed2ce7b
Support filtering for current_records
Why these changes are being introduced:

Unexpected behavior was possible when using load(current_records=True) and then applying additional filtering to the dataset before reading. In short, a non-current record could be yielded if filtering removed the truly current version of the record. This happened because the reverse chronological marking of "seen" records would not "see" the current version and would happily yield an older one.

How this addresses that need:

When load(current_records=True) is used, a clone of the dataset is saved to the TIMDEXDataset object before any additional filtering is applied. This dataset is just metadata, so it is not expensive to store. Then, during any read method, this dataset is used to provide an exhaustive and ordered list of timdex_record_ids. Even if a record has been filtered out by the read method (e.g. limiting records to only action="index"), this secondary list of timdex_record_ids is used as the authoritative list of "seen" timdex_record_ids. There is a bit of network overhead to this parallel batch reading, but it is fairly minimal since we only retrieve the 'timdex_record_id' column; perhaps 1-2 MB of IO per million records.

Side effects of this change:

* Applications like TIM that will likely use this new functionality to yield only "current" records can do so confidently, optionally with additional filtering, knowing they will only encounter current versions of a record from the dataset.

Relevant ticket(s):

* https://mitlibraries.atlassian.net/browse/TIMX-497
1 parent 9f0d74b commit ed2ce7b
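For context, a minimal usage sketch of the behavior this commit guarantees. The load() and read_dataframe() calls mirror the tests below; the import path, constructor signature, and dataset location are assumptions for illustration only:

    from timdex_dataset_api import TIMDEXDataset  # import path assumed

    # hypothetical dataset location (local path or S3 URI)
    td = TIMDEXDataset("s3://timdex/dataset")

    # a pre-filtering snapshot of the dataset is saved internally during load()
    td.load(current_records=True, source="alma")

    # additional read-time filtering is now safe: only the current version of
    # each record is ever yielded
    df = td.read_dataframe(action="index")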

2 files changed: 148 additions & 52 deletions

tests/test_dataset.py
Lines changed: 60 additions & 1 deletion
@@ -1,4 +1,5 @@
-# ruff: noqa: S105, S106, SLF001, PLR2004
+# ruff: noqa: D205, S105, S106, SLF001, PD901, PLR2004
+
 import os
 from datetime import date
 from unittest.mock import MagicMock, patch
@@ -397,3 +398,61 @@ def test_dataset_all_read_methods_get_deduplication(
     transformed_records = list(local_dataset_with_runs.read_transformed_records_iter())
 
     assert len(full_df) == len(all_records) == len(transformed_records)
+
+
+def test_dataset_current_records_no_additional_filtering_accurate_records_yielded(
+    local_dataset_with_runs,
+):
+    local_dataset_with_runs.load(current_records=True, source="alma")
+    df = local_dataset_with_runs.read_dataframe()
+    assert df.action.value_counts().to_dict() == {"index": 99, "delete": 1}
+
+
+def test_dataset_current_records_action_filtering_accurate_records_yielded(
+    local_dataset_with_runs,
+):
+    local_dataset_with_runs.load(current_records=True, source="alma")
+    df = local_dataset_with_runs.read_dataframe(action="index")
+    assert df.action.value_counts().to_dict() == {"index": 99}
+
+
+def test_dataset_current_records_index_filtering_accurate_records_yielded(
+    local_dataset_with_runs,
+):
+    """This is a somewhat complex test, but it demonstrates that only 'current' records
+    are yielded when .load(current_records=True) is applied.
+
+    Given these runs from the fixture:
+        [
+            ...
+            (25, "alma", "2025-01-03", "daily", "index", "run-5"), <---- filtered to
+            (10, "alma", "2025-01-04", "daily", "delete", "run-6"), <---- influences current
+            ...
+        ]
+
+    Though we are filtering to run-5, which has 25 total records to index, we see only
+    15 records yielded. Why? While we have filtered to only yield from run-5, run-6 had
+    10 deletes which made records alma:0-alma:9 no longer "current" in run-5. As we
+    yield records reverse chronologically, the deletes from run-6 (alma:0-alma:9)
+    "influence" what records we see as we continue backwards in time.
+    """
+    local_dataset_with_runs.load(current_records=True, source="alma")
+    df = local_dataset_with_runs.read_dataframe(run_id="run-5")
+    assert df.action.value_counts().to_dict() == {"index": 15}
+    assert list(df.timdex_record_id) == [
+        "alma:10",
+        "alma:11",
+        "alma:12",
+        "alma:13",
+        "alma:14",
+        "alma:15",
+        "alma:16",
+        "alma:17",
+        "alma:18",
+        "alma:19",
+        "alma:20",
+        "alma:21",
+        "alma:22",
+        "alma:23",
+        "alma:24",
+    ]

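The arithmetic in the test docstring above can be checked with a small standalone simulation of the "seen" logic. This is an illustrative sketch, not library code, with the fixture runs paraphrased from the docstring:

    # newest run first, mirroring the reverse chronological read order
    runs = [
        ("run-6", "delete", [f"alma:{i}" for i in range(10)]),
        ("run-5", "index", [f"alma:{i}" for i in range(25)]),
    ]

    seen: set[str] = set()
    yielded: list[str] = []
    for run_id, _action, record_ids in runs:
        for record_id in record_ids:
            # the read filter run_id="run-5" drops run-6 rows from the yield...
            if record_id not in seen and run_id == "run-5":
                yielded.append(record_id)
            # ...but every id is still marked as seen
            seen.add(record_id)

    assert yielded == [f"alma:{i}" for i in range(10, 25)]  # 15 records

Run-6's ten deletes mark alma:0 through alma:9 as seen before run-5 is reached, leaving exactly the 15 records the test asserts.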
timdex_dataset_api/dataset.py
Lines changed: 88 additions & 51 deletions
@@ -120,7 +120,9 @@ def __init__(
         self.schema = TIMDEX_DATASET_SCHEMA
         self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
         self._written_files: list[ds.WrittenFile] = None  # type: ignore[assignment]
-        self._dedupe_on_read: bool = False
+
+        self._current_records: bool = False
+        self._current_records_dataset: ds.Dataset = None  # type: ignore[assignment]
 
     @property
     def row_count(self) -> int:
@@ -153,27 +155,32 @@ def load(
             - filters: kwargs typed via DatasetFilters TypedDict
                 - Filters passed directly in method call, e.g. source="alma",
                 run_date="2024-12-20", etc., but are typed according to DatasetFilters.
+            - current_records: bool
+                - if True, the TIMDEXRunManager will be used to retrieve a list of
+                parquet files associated with current runs, and internal flags will be
+                set, ensuring that only current records are yielded by any read method
         """
         start_time = time.perf_counter()
 
         # reset paths from original location before load
         _, self.paths = self.parse_location(self.location)
 
         # perform initial load of full dataset
-        self._load_pyarrow_dataset()
+        self.dataset = self._load_pyarrow_dataset()
 
-        # if current_records flag set, limit to parquet files associated with current runs
-        self._dedupe_on_read = current_records
+        self._current_records = current_records
         if current_records:
-            timdex_run_manager = TIMDEXRunManager(timdex_dataset=self)
 
             timdex_run_manager = TIMDEXRunManager(dataset=self.dataset)
             self.paths = timdex_run_manager.get_current_parquet_files(
                 source=filters.get("source")
             )
 
-            # reload pyarrow dataset
-            self._load_pyarrow_dataset()
+            # reload pyarrow dataset, now filtered to an explicit list of parquet files;
+            # also save an instance of the dataset before any additional filtering
+            dataset = self._load_pyarrow_dataset()
+            self.dataset = dataset
+            self._current_records_dataset = dataset
 
         # filter dataset
         self.dataset = self._get_filtered_dataset(**filters)
@@ -183,9 +190,9 @@ def load(
             f"{round(time.perf_counter()-start_time, 2)}s"
         )
 
-    def _load_pyarrow_dataset(self) -> None:
+    def _load_pyarrow_dataset(self) -> ds.Dataset:
         """Load the pyarrow dataset per local filesystem and paths attributes."""
-        self.dataset = ds.dataset(
+        return ds.dataset(
             self.paths,
             schema=self.schema,
             format="parquet",
@@ -449,19 +456,14 @@ def read_batches_iter(
         """Yield pyarrow.RecordBatches from the dataset.
 
         While batch_size will limit the max rows per batch, filtering may result in some
-        batches have less than this limit.
+        batches having less than this limit.
+
+        If the flag self._current_records is set, this method leans on
+        self._yield_current_record_deduped_batches() to apply deduplication of records,
+        ensuring that only current versions of a record are ever yielded.
 
         Args:
             - columns: list[str], list of columns to return from the dataset
-            - batch_size: int, max number of rows to yield per batch
-            - batch_read_ahead: int, the number of batches to read ahead in a file. This
-                might not work for all file formats. Increasing this number will increase
-                RAM usage but could also improve IO utilization. Pyarrow default is 16,
-                but this library defaults to 0 to prioritize memory footprint.
-            - fragment_read_ahead: int, the number of files to read ahead. Increasing this
-                number will increase RAM usage but could also improve IO utilization.
-                Pyarrow default is 4, but this library defaults to 0 to prioritize memory
-                footprint.
             - filters: pairs of column:value to filter the dataset
         """
         if not self.dataset:
@@ -477,47 +479,82 @@ def read_batches_iter(
             fragment_readahead=self.config.fragment_read_ahead,
         )
 
-        if self._dedupe_on_read:
-            yield from self._yield_deduped_batches(batches)
+        if self._current_records:
+            yield from self._yield_current_record_deduped_batches(batches)
         else:
             for batch in batches:
                 if len(batch) > 0:
                     yield batch
 
-    def _yield_deduped_batches(
-        self, batches: Iterator[pa.RecordBatch]
+    def _yield_current_record_deduped_batches(
+        self,
+        batches: Iterator[pa.RecordBatch],
     ) -> Iterator[pa.RecordBatch]:
-        """Method to yield record deduped batches.
+        """Method to yield only the most recent version of each record.
+
+        When multiple versions of a record (same timdex_record_id) exist in the dataset,
+        this method ensures only the most recent version is returned. If filtering is
+        applied that removes this most recent version of a record, that timdex_record_id
+        will not be yielded at all.
+
+        This is achieved by iterating over TWO record batch iterators in parallel:
+
+        1. "batches" - the RecordBatch iterator passed to this method, which contains
+        the actual records and columns we are interested in, and may have filtering
+        applied
+
+        2. "id_batches" - a lightweight RecordBatch iterator that only contains the
+        'timdex_record_id' column from a pre-filtering dataset saved during .load()
+
+        These two iterators are guaranteed to have the same number of total batches,
+        based on how pyarrow.Dataset.to_batches() reads from parquet files. Even if
+        dataset filtering is applied, this does not affect the batch count; you may just
+        end up with smaller or empty batches.
 
-        Extending the normal behavior of yielding batches untouched, this method keeps
-        track of seen timdex_record_id's, yielding them only once. For this method to
-        yield the most current version of a record -- most common usage -- it is required
-        that the batches are pre-ordered so the most recent record version is encountered
-        first.
+        As such, as we move through the batches we use batches from "id_batches" to
+        keep a set of seen timdex_record_id's. Even if a timdex_record_id is not in
+        "batches", likely due to filtering, we will mark the truly most current version
+        as "seen" and not yield it from any future batches.
+
+        Args:
+            - batches: batches of records to actually yield from; a parallel, internally
+            created iterator of timdex_record_id batches informs when to yield or skip
+            a record from a batch
         """
+        # create a RecordBatch iterator from self._current_records_dataset, which was
+        # saved during .load() before any filtering was applied
+        id_batches = self._current_records_dataset.to_batches(
+            columns=["timdex_record_id"],
+            batch_size=self.config.read_batch_size,
+            batch_readahead=self.config.batch_read_ahead,
+            fragment_readahead=self.config.fragment_read_ahead,
+        )
+
         seen_records = set()
-        for batch in batches:
-            if len(batch) > 0:
-                # init list of batch indices for records unseen
-                unseen_batch_indices = []
-
-                # get list of timdex ids from batch
-                timdex_ids = batch.column("timdex_record_id").to_pylist()
-
-                # check each record id and track unseen ones
-                for i, record_id in enumerate(timdex_ids):
-                    if record_id not in seen_records:
-                        unseen_batch_indices.append(i)
-                        seen_records.add(record_id)
-
-                # if all records from batch were seen, continue
-                if not unseen_batch_indices:
-                    continue
-
-                # else, yield unseen records from batch
-                deduped_batch = batch.take(pa.array(unseen_batch_indices))  # type: ignore[arg-type]
-                if len(deduped_batch) > 0:
-                    yield deduped_batch
+        for id_batch, batch in zip(id_batches, batches, strict=True):
+            dedupe_ids = id_batch.column("timdex_record_id").to_pylist()
+            batch_ids = batch.column("timdex_record_id").to_pylist()
+
+            # init list of indices from the batch for records we have never yielded
+            unseen_batch_indices = []
+
+            # check each record id and track unseen ones
+            for i, record_id in enumerate(batch_ids):
+                if record_id not in seen_records:
+                    unseen_batch_indices.append(i)
+
+            # even if a record will not be yielded, update our set of seen records from
+            # all records in the id_batch
+            seen_records.update(dedupe_ids)
+
+            # if no records are unseen this batch, skip yielding
+            if not unseen_batch_indices:
+                continue
+
+            # use the unseen indices to create a subset of the batch and yield it
+            _batch = batch.take(pa.array(unseen_batch_indices))  # type: ignore[arg-type]
+            if len(_batch) > 0:
+                yield _batch
 
     def read_dataframes_iter(
         self,