Skip to content

Commit ad3f6bd

Browse files
committed
Surface pyarrow batch reading parameters and adjust defaults
Why these changes are being introduced: During use of this library for batch reading transformed records to index into OpenSearch, the TIMDEX application threw an out-of-memory error. It was observed that, as-is, batch reading of records could hover around 1–2 GB of memory. This was surprising, as we are careful to only yield records as they are batch read from the parquet dataset. It turns out that pyarrow's Dataset.to_batches() has defaults for optimistic read-ahead that improve IO throughput, but at the cost of memory consumption. Tuning these down resulted in much lower memory consumption, which aligns with how our current TIMDEX applications are resourced. How this addresses that need: By surfacing the Dataset.to_batches() arguments 'batch_readahead' and 'fragment_readahead' in this library's read methods, and setting conservative defaults, memory consumption is significantly lower. Side effects of this change: * Per the defaults set, slower IO for dataset reads. Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/TIMX-468
1 parent 41b326f commit ad3f6bd

1 file changed

Lines changed: 48 additions & 5 deletions

File tree

timdex_dataset_api/dataset.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class DatasetFilters(TypedDict, total=False):
6464
DEFAULT_BATCH_SIZE = 1_000
6565
MAX_ROWS_PER_GROUP = DEFAULT_BATCH_SIZE
6666
MAX_ROWS_PER_FILE = 100_000
67+
DEFAULT_BATCH_READ_AHEAD = 0
68+
DEFAULT_FRAGMENT_READ_AHEAD = 0
6769

6870

6971
def strict_date_parse(date_string: str) -> date:
def read_batches_iter(
    self,
    columns: list[str] | None = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
    batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
    fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
    **filters: Unpack[DatasetFilters],
) -> Iterator[pa.RecordBatch]:
    """Yield pyarrow.RecordBatches from the dataset.

    Only non-empty batches are yielded, so callers never see a zero-row batch.

    Args:
        - columns: list[str], list of columns to return from the dataset
        - batch_size: int, max number of rows to yield per batch
        - batch_read_ahead: int, the number of batches to read ahead in a file. This
            might not work for all file formats. Increasing this number will increase
            RAM usage but could also improve IO utilization. Pyarrow default is 16,
            but this library defaults to 0 to prioritize memory footprint.
        - fragment_read_ahead: int, The number of files to read ahead. Increasing this
            number will increase RAM usage but could also improve IO utilization.
            Pyarrow default is 4, but this library defaults to 0 to prioritize memory
            footprint.
        - filter_kwargs: pairs of column:value to filter the dataset
    """
    # Guard: reading requires a prior call to `load`.
    if not self.dataset:
        raise DatasetNotLoadedError(
            "Dataset is not loaded. Please call the `load` method first."
        )

    filtered_dataset = self._get_filtered_dataset(**filters)

    # NOTE: pyarrow spells these arguments without underscores between
    # "read" and "ahead"; this library's keyword names differ deliberately.
    batch_stream = filtered_dataset.to_batches(
        columns=columns,
        batch_size=batch_size,
        batch_readahead=batch_read_ahead,
        fragment_readahead=fragment_read_ahead,
    )
    for record_batch in batch_stream:
        if len(record_batch) > 0:
            yield record_batch
417434

418435
def read_dataframes_iter(
    self,
    columns: list[str] | None = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
    batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
    fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
    **filters: Unpack[DatasetFilters],
) -> Iterator[pd.DataFrame]:
    """Yield record batches as Pandas DataFrames from the dataset.

    Args: see self.read_batches_iter()
    """
    # Delegate batching (and read-ahead tuning) to read_batches_iter, then
    # convert each pyarrow RecordBatch to a DataFrame as it streams through.
    batch_iter = self.read_batches_iter(
        columns=columns,
        batch_size=batch_size,
        batch_read_ahead=batch_read_ahead,
        fragment_read_ahead=fragment_read_ahead,
        **filters,
    )
    for record_batch in batch_iter:
        yield record_batch.to_pandas()
432455

433456
def read_dataframe(
434457
self,
435458
columns: list[str] | None = None,
436459
batch_size: int = DEFAULT_BATCH_SIZE,
460+
batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
461+
fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
437462
**filters: Unpack[DatasetFilters],
438463
) -> pd.DataFrame | None:
439464
"""Yield record batches as Pandas DataFrames and concatenate to single dataframe.
@@ -447,7 +472,11 @@ def read_dataframe(
447472
df_batches = [
448473
record_batch.to_pandas()
449474
for record_batch in self.read_batches_iter(
450-
columns=columns, batch_size=batch_size, **filters
475+
columns=columns,
476+
batch_size=batch_size,
477+
batch_read_ahead=batch_read_ahead,
478+
fragment_read_ahead=fragment_read_ahead,
479+
**filters,
451480
)
452481
]
453482
if not df_batches:
def read_dicts_iter(
    self,
    columns: list[str] | None = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
    batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
    fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
    **filters: Unpack[DatasetFilters],
) -> Iterator[dict]:
    """Yield individual record rows as dictionaries from the dataset.

    Args: see self.read_batches_iter()
    """
    # Stream batches, flattening each one into per-row dictionaries.
    batch_iter = self.read_batches_iter(
        columns=columns,
        batch_size=batch_size,
        batch_read_ahead=batch_read_ahead,
        fragment_read_ahead=fragment_read_ahead,
        **filters,
    )
    for record_batch in batch_iter:
        yield from record_batch.to_pylist()
471506

472507
def read_transformed_records_iter(
    self,
    batch_size: int = DEFAULT_BATCH_SIZE,
    batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
    fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
    **filters: Unpack[DatasetFilters],
) -> Iterator[dict]:
    """Yield individual transformed records as dictionaries from the dataset.

    If 'transformed_record' is None (i.e., action="skip"|"error"), the yield
    statement will not be executed for the row.

    Args: see self.read_batches_iter()
    """
    # Only the 'transformed_record' column is needed; restricting columns
    # keeps batch reads as small as possible.
    row_iter = self.read_dicts_iter(
        columns=["transformed_record"],
        batch_size=batch_size,
        batch_read_ahead=batch_read_ahead,
        fragment_read_ahead=fragment_read_ahead,
        **filters,
    )
    for row in row_iter:
        transformed_record = row["transformed_record"]
        # Rows with a null transformed_record are silently skipped.
        if transformed_record:
            yield json.loads(transformed_record)

0 commit comments

Comments
 (0)