Skip to content

Commit 1b99132

Browse files
committed
Current records functionality uses TIMDEXDatasetMetadata
Why these changes are being introduced: Previously, when asked to yield only current records the class TIMDEXRunManager was used to crawl the dataset and produce a dataframe of ETL runs that informed a second iterator to consult during the yielding of records. This class was superseded by TIMDEXDatasetMetadata which can provide the same information, but also more granular record-level metadata which simplifies the yielding of current records. How this addresses that need: The new TIMDEXDatasetMetadata class is used to provide a mapping of timdex_record_id to run_id, which allows a record to be yielded only if it is associated with the correct run_id. The net effect here is a simpler, and ultimately more accurate, way to yield current records while also utilizing a class that provides more functionality in other contexts. Side effects of this change: * None Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/TIMX-507
1 parent 9298e7e commit 1b99132

1 file changed

Lines changed: 42 additions & 63 deletions

File tree

timdex_dataset_api/dataset.py

Lines changed: 42 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from timdex_dataset_api.config import configure_logger
2222
from timdex_dataset_api.exceptions import DatasetNotLoadedError
23-
from timdex_dataset_api.run import TIMDEXRunManager
23+
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
2424

2525
if TYPE_CHECKING:
2626
from timdex_dataset_api.record import DatasetRecord # pragma: nocover
@@ -128,7 +128,7 @@ def __init__(
128128

129129
# reading
130130
self._current_records: bool = False
131-
self._current_records_dataset: ds.Dataset = None # type: ignore[assignment]
131+
self.timdex_dataset_metadata: TIMDEXDatasetMetadata = None # type: ignore[assignment]
132132

133133
@property
134134
def row_count(self) -> int:
@@ -162,31 +162,22 @@ def load(
162162
- Filters passed directly in method call, e.g. source="alma",
163163
run_date="2024-12-20", etc., but are typed according to DatasetFilters.
164164
- current_records: bool
165-
- if True, the TIMDEXRunManager will be used to retrieve a list of parquet
166-
files associated with current runs, some internal flags will be set, all
167-
ensuring that only current records are yielded for any read methods
165+
- if True, all records yielded from this instance will be the current
166+
version of the record in the dataset.
168167
"""
169168
start_time = time.perf_counter()
170169

171170
# reset paths from original location before load
172171
_, self.paths = self.parse_location(self.location)
173172

174-
# perform initial load of full dataset
175-
self.dataset = self._load_pyarrow_dataset()
176-
173+
# read dataset metadata if only current records are requested
177174
self._current_records = current_records
178175
if current_records:
176+
self.timdex_dataset_metadata = TIMDEXDatasetMetadata(timdex_dataset=self)
177+
self.paths = self.timdex_dataset_metadata.get_current_parquet_files(**filters)
179178

180-
timdex_run_manager = TIMDEXRunManager(dataset=self.dataset)
181-
self.paths = timdex_run_manager.get_current_parquet_files(
182-
source=filters.get("source")
183-
)
184-
185-
# reload pyarrow dataset, filtered now to an explicit list of parquet files
186-
# also save an instance of the dataset before any additional filtering
187-
dataset = self._load_pyarrow_dataset()
188-
self.dataset = dataset
189-
self._current_records_dataset = dataset
179+
# perform initial load of full dataset
180+
self.dataset = self._load_pyarrow_dataset()
190181

191182
# filter dataset
192183
self.dataset = self._get_filtered_dataset(**filters)
@@ -476,6 +467,13 @@ def read_batches_iter(
476467
)
477468
dataset = self._get_filtered_dataset(**filters)
478469

470+
# if current records, add required columns for deduplication
471+
if self._current_records:
472+
if not columns:
473+
columns = TIMDEX_DATASET_SCHEMA.names
474+
columns.extend(["timdex_record_id", "run_id"])
475+
columns = list(set(columns))
476+
479477
batches = dataset.to_batches(
480478
columns=columns,
481479
batch_size=self.config.read_batch_size,
@@ -484,7 +482,7 @@ def read_batches_iter(
484482
)
485483

486484
if self._current_records:
487-
yield from self._yield_current_record_batches(batches)
485+
yield from self._yield_current_record_batches(batches, **filters)
488486
else:
489487
for batch in batches:
490488
if len(batch) > 0:
@@ -493,6 +491,7 @@ def read_batches_iter(
493491
def _yield_current_record_batches(
494492
self,
495493
batches: Iterator[pa.RecordBatch],
494+
**filters: Unpack[DatasetFilters],
496495
) -> Iterator[pa.RecordBatch]:
497496
"""Method to yield only the most recent version of each record.
498497
@@ -501,60 +500,40 @@ def _yield_current_record_batches(
501500
applied that removes this most recent version of a record, that timdex_record_id
502501
will not be yielded at all.
503502
504-
This is achieved by iterating over TWO record batch iterators in parallel:
505-
506-
1. "batches" - the RecordBatch iterator passed to this method which
507-
contains the actual records and columns we are interested in, and may contain
508-
filtering
509-
510-
2. "unfiltered_batches" - a lightweight RecordBatch iterator that only
511-
contains the 'timdex_record_id' column from a pre-filtering dataset saved
512-
during .load()
513-
514-
These two iterators are guaranteed to have the same number of total batches based
515-
on how pyarrow.Dataset.to_batches() reads from parquet files. Even if dataset
516-
filtering is applied, this does not affect the batch count; you may just end up
517-
with smaller or empty batches.
518-
519-
As we move through the record batches we use unfiltered batches to keep a list of
520-
seen timdex_record_ids. Even if a timdex_record_is not in the record batch --
521-
likely due to filtering -- we will mark that timdex_record_id as "seen" and not
522-
yield it from any future batches.
503+
# TODO: update docstring....again...
523504
524505
Args:
525506
- batches: batches of records to actually yield from
526-
- current_record_id_batches: batches of timdex_record_id's that inform when
527-
to yield or skip a record for a batch
507+
- filters: pairs of column:value to filter the dataset metadata required
528508
"""
529-
unfiltered_batches = self._current_records_dataset.to_batches(
530-
columns=["timdex_record_id"],
531-
batch_size=self.config.read_batch_size,
532-
batch_readahead=self.config.batch_read_ahead,
533-
fragment_readahead=self.config.fragment_read_ahead,
509+
# get map of timdex_record_id to run_id for current version of that record
510+
record_to_run_map = self.timdex_dataset_metadata.get_current_record_to_run_map(
511+
**filters
534512
)
535513

536-
seen_records = set()
537-
for unfiltered_batch, batch in zip(unfiltered_batches, batches, strict=True):
538-
# init list of indices from the batch for records we have never yielded
539-
unseen_batch_indices = []
514+
# loop through batches, yielding only current records
515+
for batch in batches:
540516

541-
# check each record id and track unseen ones
542-
for i, record_id in enumerate(batch.column("timdex_record_id").to_pylist()):
543-
if record_id not in seen_records:
544-
unseen_batch_indices.append(i)
517+
if batch.num_rows == 0:
518+
continue
545519

546-
# even if not a record to yield, update our list of seen records from all
547-
# records in the unfiltered batch
548-
seen_records.update(unfiltered_batch.column("timdex_record_id").to_pylist())
520+
to_yield_indices = []
549521

550-
# if no unseen records from this batch, skip yielding entirely
551-
if not unseen_batch_indices:
552-
continue
522+
record_ids = batch.column("timdex_record_id").to_pylist()
523+
run_ids = batch.column("run_id").to_pylist()
524+
525+
for i, (record_id, run_id) in enumerate(
526+
zip(
527+
record_ids,
528+
run_ids,
529+
strict=True,
530+
)
531+
):
532+
if record_to_run_map.get(record_id) == run_id:
533+
to_yield_indices.append(i)
553534

554-
# create a new RecordBatch using the unseen indices of the batch
555-
_batch = batch.take(pa.array(unseen_batch_indices)) # type: ignore[arg-type]
556-
if len(_batch) > 0:
557-
yield _batch
535+
if to_yield_indices:
536+
yield batch.take(pa.array(to_yield_indices)) # type: ignore[arg-type]
558537

559538
def read_dataframes_iter(
560539
self,

0 commit comments

Comments
 (0)