2020
2121from timdex_dataset_api .config import configure_logger
2222from timdex_dataset_api .exceptions import DatasetNotLoadedError
23+ from timdex_dataset_api .run import TIMDEXRunManager
2324
2425if TYPE_CHECKING :
2526 from timdex_dataset_api .record import DatasetRecord # pragma: nocover
@@ -114,11 +115,12 @@ def __init__(
114115 self .location = location
115116 self .config = config or TIMDEXDatasetConfig ()
116117
117- self .filesystem , self .source = self .parse_location (self .location )
118+ self .filesystem , self .paths = self .parse_location (self .location )
118119 self .dataset : ds .Dataset = None # type: ignore[assignment]
119120 self .schema = TIMDEX_DATASET_SCHEMA
120121 self .partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
121122 self ._written_files : list [ds .WrittenFile ] = None # type: ignore[assignment]
123+ self ._dedupe_on_read : bool = False
122124
123125 @property
124126 def row_count (self ) -> int :
@@ -129,6 +131,8 @@ def row_count(self) -> int:
129131
130132 def load (
131133 self ,
134+ * ,
135+ current_records : bool = False ,
132136 ** filters : Unpack [DatasetFilters ],
133137 ) -> None :
134138 """Lazy load a pyarrow.dataset.Dataset and set to self.dataset.
@@ -152,14 +156,24 @@ def load(
152156 """
153157 start_time = time .perf_counter ()
154158
155- # load dataset
156- self .dataset = ds .dataset (
157- self .source ,
158- schema = self .schema ,
159- format = "parquet" ,
160- partitioning = "hive" ,
161- filesystem = self .filesystem ,
162- )
159+ # reset paths from original location before load
160+ _ , self .paths = self .parse_location (self .location )
161+
162+ # perform initial load of full dataset
163+ self ._load_pyarrow_dataset ()
164+
165+ # if current_records flag set, limit to parquet files associated with current runs
166+ self ._dedupe_on_read = current_records
167+ if current_records :
168+ timdex_run_manager = TIMDEXRunManager (timdex_dataset = self )
169+
170+ # update paths, limiting by source if set
171+ self .paths = timdex_run_manager .get_current_parquet_files (
172+ source = filters .get ("source" )
173+ )
174+
175+ # reload pyarrow dataset
176+ self ._load_pyarrow_dataset ()
163177
164178 # filter dataset
165179 self .dataset = self ._get_filtered_dataset (** filters )
@@ -169,6 +183,16 @@ def load(
169183 f"{ round (time .perf_counter ()- start_time , 2 )} s"
170184 )
171185
def _load_pyarrow_dataset(self) -> None:
    """Build a pyarrow dataset from the instance's paths and store it.

    Reads ``self.paths``, ``self.schema`` and ``self.filesystem`` and assigns the
    resulting dataset to ``self.dataset``, replacing whatever was set before.
    Files are read as parquet using hive-style partitioning.
    """
    pyarrow_dataset = ds.dataset(
        self.paths,
        filesystem=self.filesystem,
        format="parquet",
        partitioning="hive",
        schema=self.schema,
    )
    self.dataset = pyarrow_dataset
195+
172196 def _get_filtered_dataset (
173197 self ,
174198 ** filters : Unpack [DatasetFilters ],
@@ -345,7 +369,8 @@ def write(
345369 start_time = time .perf_counter ()
346370 self ._written_files = []
347371
348- if isinstance (self .source , list ):
372+ dataset_filesystem , dataset_path = self .parse_location (self .location )
373+ if isinstance (dataset_path , list ):
349374 raise TypeError (
350375 "Dataset location must be the root of a single dataset for writing"
351376 )
@@ -354,10 +379,10 @@ def write(
354379
355380 ds .write_dataset (
356381 record_batches_iter ,
357- base_dir = self . source ,
382+ base_dir = dataset_path ,
358383 basename_template = "%s-{i}.parquet" % (str (uuid .uuid4 ())), # noqa: UP031
359384 existing_data_behavior = "overwrite_or_ignore" ,
360- filesystem = self . filesystem ,
385+ filesystem = dataset_filesystem ,
361386 file_visitor = lambda written_file : self ._written_files .append (written_file ), # type: ignore[arg-type]
362387 format = "parquet" ,
363388 max_open_files = 500 ,
@@ -444,14 +469,55 @@ def read_batches_iter(
444469 "Dataset is not loaded. Please call the `load` method first."
445470 )
446471 dataset = self ._get_filtered_dataset (** filters )
447- for batch in dataset .to_batches (
472+
473+ batches = dataset .to_batches (
448474 columns = columns ,
449475 batch_size = self .config .read_batch_size ,
450476 batch_readahead = self .config .batch_read_ahead ,
451477 fragment_readahead = self .config .fragment_read_ahead ,
452- ):
478+ )
479+
480+ if self ._dedupe_on_read :
481+ yield from self ._yield_deduped_batches (batches )
482+ else :
483+ for batch in batches :
484+ if len (batch ) > 0 :
485+ yield batch
486+
def _yield_deduped_batches(
    self, batches: Iterator[pa.RecordBatch]
) -> Iterator[pa.RecordBatch]:
    """Yield record batches with already-seen timdex_record_id rows removed.

    Extending the normal behavior of yielding batches untouched, this method keeps
    track of seen timdex_record_id values and yields each one only once: the first
    row encountered for an id is kept and every later row with that id is dropped.
    For this method to yield the most current version of a record -- the most
    common usage -- the batches must be pre-ordered so the most recent record
    version is encountered first.

    Every id seen is held in an in-memory set for the lifetime of the iteration.

    Args:
        batches: iterator of record batches. Each non-empty batch must contain a
            "timdex_record_id" column, because that column is read unconditionally.

    Yields:
        Non-empty record batches holding only rows whose timdex_record_id had not
        been seen earlier in the iteration.
    """
    seen_records: set = set()
    for batch in batches:
        if len(batch) == 0:
            continue

        # collect the row positions in this batch whose record id is new
        timdex_ids = batch.column("timdex_record_id").to_pylist()
        unseen_batch_indices = []
        for i, record_id in enumerate(timdex_ids):
            if record_id not in seen_records:
                seen_records.add(record_id)
                unseen_batch_indices.append(i)

        # all records from this batch were already seen, nothing to yield
        if not unseen_batch_indices:
            continue

        # every row is new, so the batch can be passed through without a copy
        if len(unseen_batch_indices) == len(batch):
            yield batch
            continue

        # the index list is non-empty here, so the taken batch is never empty
        yield batch.take(pa.array(unseen_batch_indices))  # type: ignore[arg-type]
455521
456522 def read_dataframes_iter (
457523 self ,
@@ -513,13 +579,14 @@ def read_transformed_records_iter(
513579 ) -> Iterator [dict ]:
514580 """Yield individual transformed records as dictionaries from the dataset.
515581
516- If 'transformed_record' is None (i.e., action="skip"|"error"), the yield
517- statement will not be executed for the row.
582+ If 'transformed_record' is None (common scenarios are action="skip"|"error"), the
583+ yield statement will not be executed for the row. Note that for action="delete" a
584+ transformed record still may be yielded if present.
518585
519586 Args: see self.read_batches_iter()
520587 """
521588 for record_dict in self .read_dicts_iter (
522- columns = ["transformed_record" ],
589+ columns = ["timdex_record_id" , " transformed_record" ],
523590 ** filters ,
524591 ):
525592 if transformed_record := record_dict ["transformed_record" ]:
0 commit comments