 import itertools
 import json
 import operator
+import os
 import time
 import uuid
 from collections.abc import Iterator
+from dataclasses import dataclass, field
 from datetime import UTC, date, datetime
 from functools import reduce
 from typing import TYPE_CHECKING, TypedDict, Unpack
@@ -61,27 +63,57 @@ class DatasetFilters(TypedDict, total=False):
     day: str | None
 
 
-DEFAULT_BATCH_SIZE = 1_000
-MAX_ROWS_PER_GROUP = DEFAULT_BATCH_SIZE
-MAX_ROWS_PER_FILE = 100_000
-DEFAULT_BATCH_READ_AHEAD = 0
-DEFAULT_FRAGMENT_READ_AHEAD = 0
+@dataclass
+class TIMDEXDatasetConfig:
+    """Configurations for dataset operations.
 
+    - read_batch_size: row size of batches read, affecting memory consumption
+    - write_batch_size: row size of batches written, directly affecting row group size in
+        final parquet files
+    - max_rows_per_group: max number of rows per row group in a parquet file
+    - max_rows_per_file: max number of rows in a single parquet file
+    - batch_read_ahead: number of batches to optimistically read ahead when batch reading
+        from a dataset; pyarrow default is 16
+    - fragment_read_ahead: number of fragments to optimistically read ahead when batch
+        reading from a dataset; pyarrow default is 4
79+ """
7080
71- def strict_date_parse (date_string : str ) -> date :
72- return datetime .strptime (date_string , "%Y-%m-%d" ).astimezone (UTC ).date ()
81+ read_batch_size : int = field (
82+ default_factory = lambda : int (os .getenv ("TDA_READ_BATCH_SIZE" , "1_000" ))
83+ )
84+ write_batch_size : int = field (
85+ default_factory = lambda : int (os .getenv ("TDA_WRITE_BATCH_SIZE" , "1_000" ))
86+ )
87+ max_rows_per_group : int = field (
88+ default_factory = lambda : int (os .getenv ("TDA_MAX_ROWS_PER_GROUP" , "1_000" ))
89+ )
90+ max_rows_per_file : int = field (
91+ default_factory = lambda : int (os .getenv ("TDA_MAX_ROWS_PER_FILE" , "100_000" ))
92+ )
93+ batch_read_ahead : int = field (
94+ default_factory = lambda : int (os .getenv ("TDA_BATCH_READ_AHEAD" , "0" ))
95+ )
96+ fragment_read_ahead : int = field (
97+ default_factory = lambda : int (os .getenv ("TDA_FRAGMENT_READ_AHEAD" , "0" ))
98+ )
7399
74100
75101class TIMDEXDataset :
76102
77- def __init__ (self , location : str | list [str ]):
103+ def __init__ (
104+ self ,
105+ location : str | list [str ],
106+ config : TIMDEXDatasetConfig | None = None ,
107+ ):
78108 """Initialize TIMDEXDataset object.
79109
80110 Args:
81111 location (str | list[str]): Local filesystem path or an S3 URI to
82112 a parquet dataset. For partitioned datasets, set to the base directory.
83113 """
84114 self .location = location
115+ self .config = config or TIMDEXDatasetConfig ()
116+
85117 self .filesystem , self .source = self .parse_location (self .location )
86118 self .dataset : ds .Dataset = None # type: ignore[assignment]
87119 self .schema = TIMDEX_DATASET_SCHEMA
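A minimal usage sketch of the new configuration hook, for review context. The import path and the override values are assumptions for illustration and are not part of this diff:

```python
# Hypothetical usage sketch: the module path "timdex_dataset_api" and the
# override values are assumptions for illustration only.
from timdex_dataset_api import TIMDEXDataset, TIMDEXDatasetConfig

# Explicit keyword overrides take precedence over the TDA_* environment
# variable defaults baked into TIMDEXDatasetConfig's default_factory lambdas.
config = TIMDEXDatasetConfig(read_batch_size=500, fragment_read_ahead=2)
timdex_dataset = TIMDEXDataset("s3://example-bucket/dataset", config=config)

# Omitting config falls back to TIMDEXDatasetConfig(), which reads the TDA_*
# environment variables (or their hard-coded defaults) at instantiation time.
timdex_dataset_default = TIMDEXDataset("s3://example-bucket/dataset")
```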
@@ -171,7 +203,7 @@ def _get_filtered_dataset(
 
         # create filter expressions for element-wise equality comparisons
         expressions = []
-        for field, value in filters.items():
+        for field, value in filters.items():  # noqa: F402
             if isinstance(value, list):
                 expressions.append(ds.field(field).isin(value))
             else:
@@ -207,7 +239,7 @@ def _parse_date_filters(self, run_date: str | date | None) -> DatasetFilters:
             DatasetFilters[dict]: values for run_date, year, month, and day
         """
         if isinstance(run_date, str):
-            run_date_obj = strict_date_parse(run_date)
+            run_date_obj = datetime.strptime(run_date, "%Y-%m-%d").astimezone(UTC).date()
         elif isinstance(run_date, date):
             run_date_obj = run_date
         else:
@@ -286,7 +318,6 @@ def write(
         self,
         records_iter: Iterator["DatasetRecord"],
         *,
-        batch_size: int = DEFAULT_BATCH_SIZE,
         use_threads: bool = True,
     ) -> list[ds.WrittenFile]:
         """Write records to the TIMDEX parquet dataset.
@@ -309,8 +340,6 @@ def write(
 
         Args:
         - records_iter: Iterator of DatasetRecord instances
-        - batch_size: size for batches to yield and write, directly affecting row
-            group size in final parquet files
         - use_threads: boolean if threads should be used for writing
         """
         start_time = time.perf_counter()
@@ -321,10 +350,7 @@ def write(
                 "Dataset location must be the root of a single dataset for writing"
             )
 
-        record_batches_iter = self.create_record_batches(
-            records_iter,
-            batch_size=batch_size,
-        )
+        record_batches_iter = self.create_record_batches(records_iter)
 
         ds.write_dataset(
             record_batches_iter,
@@ -335,8 +361,8 @@ def write(
             file_visitor=lambda written_file: self._written_files.append(written_file),  # type: ignore[arg-type]
             format="parquet",
             max_open_files=500,
-            max_rows_per_file=MAX_ROWS_PER_FILE,
-            max_rows_per_group=MAX_ROWS_PER_GROUP,
+            max_rows_per_file=self.config.max_rows_per_file,
+            max_rows_per_group=self.config.max_rows_per_group,
             partitioning=self.partition_columns,
             partitioning_flavor="hive",
             schema=self.schema,
@@ -349,8 +375,6 @@ def write(
     def create_record_batches(
         self,
         records_iter: Iterator["DatasetRecord"],
-        *,
-        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> Iterator[pa.RecordBatch]:
         """Yield pyarrow.RecordBatches for writing.
 
@@ -361,10 +385,10 @@ def create_record_batches(
 
         Args:
         - records_iter: Iterator of DatasetRecord instances
-        - batch_size: size for batches to yield and write, directly affecting row
-            group size in final parquet files
         """
-        for i, record_batch in enumerate(itertools.batched(records_iter, batch_size)):
+        for i, record_batch in enumerate(
+            itertools.batched(records_iter, self.config.write_batch_size)
+        ):
             batch = pa.RecordBatch.from_pylist(
                 [record.to_dict() for record in record_batch]
             )
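As a sketch of how the write path is tuned after this change, now that the batch_size keyword is gone: the environment variable route below is inferred from the default_factory lambdas above, and the values and import path are arbitrary assumptions:

```python
import os

# Hypothetical tuning via the TDA_* environment variables read by
# TIMDEXDatasetConfig; the values here are arbitrary examples.
os.environ["TDA_WRITE_BATCH_SIZE"] = "2_000"    # rows per yielded RecordBatch / row group
os.environ["TDA_MAX_ROWS_PER_FILE"] = "50_000"  # rows per written parquet file

from timdex_dataset_api import TIMDEXDataset  # import path assumed for illustration

timdex_dataset = TIMDEXDataset("/path/to/dataset")
# timdex_dataset.write(records_iter) now sizes batches and files from
# timdex_dataset.config rather than a per-call batch_size argument.
```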
@@ -395,9 +419,6 @@ def log_write_statistics(self, start_time: float) -> None:
     def read_batches_iter(
         self,
         columns: list[str] | None = None,
-        batch_size: int = DEFAULT_BATCH_SIZE,
-        batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
-        fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[pa.RecordBatch]:
         """Yield pyarrow.RecordBatches from the dataset.
@@ -416,7 +437,7 @@ def read_batches_iter(
             number will increase RAM usage but could also improve IO utilization.
             Pyarrow default is 4, but this library defaults to 0 to prioritize memory
             footprint.
-        - filter_kwargs: pairs of column:value to filter the dataset
+        - filters: pairs of column:value to filter the dataset
         """
         if not self.dataset:
             raise DatasetNotLoadedError(
@@ -425,19 +446,16 @@ def read_batches_iter(
         dataset = self._get_filtered_dataset(**filters)
         for batch in dataset.to_batches(
             columns=columns,
-            batch_size=batch_size,
-            batch_readahead=batch_read_ahead,
-            fragment_readahead=fragment_read_ahead,
+            batch_size=self.config.read_batch_size,
+            batch_readahead=self.config.batch_read_ahead,
+            fragment_readahead=self.config.fragment_read_ahead,
         ):
             if len(batch) > 0:
                 yield batch
 
     def read_dataframes_iter(
         self,
         columns: list[str] | None = None,
-        batch_size: int = DEFAULT_BATCH_SIZE,
-        batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
-        fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[pd.DataFrame]:
         """Yield record batches as Pandas DataFrames from the dataset.
@@ -446,19 +464,13 @@ def read_dataframes_iter(
         """
         for record_batch in self.read_batches_iter(
             columns=columns,
-            batch_size=batch_size,
-            batch_read_ahead=batch_read_ahead,
-            fragment_read_ahead=fragment_read_ahead,
             **filters,
         ):
             yield record_batch.to_pandas()
 
     def read_dataframe(
         self,
         columns: list[str] | None = None,
-        batch_size: int = DEFAULT_BATCH_SIZE,
-        batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
-        fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
         **filters: Unpack[DatasetFilters],
     ) -> pd.DataFrame | None:
         """Yield record batches as Pandas DataFrames and concatenate to single dataframe.
@@ -473,9 +485,6 @@ def read_dataframe(
             record_batch.to_pandas()
             for record_batch in self.read_batches_iter(
                 columns=columns,
-                batch_size=batch_size,
-                batch_read_ahead=batch_read_ahead,
-                fragment_read_ahead=fragment_read_ahead,
                 **filters,
             )
         ]
@@ -486,9 +495,6 @@ def read_dataframe(
     def read_dicts_iter(
         self,
         columns: list[str] | None = None,
-        batch_size: int = DEFAULT_BATCH_SIZE,
-        batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
-        fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[dict]:
         """Yield individual record rows as dictionaries from the dataset.
@@ -497,18 +503,12 @@ def read_dicts_iter(
         """
         for record_batch in self.read_batches_iter(
             columns=columns,
-            batch_size=batch_size,
-            batch_read_ahead=batch_read_ahead,
-            fragment_read_ahead=fragment_read_ahead,
             **filters,
         ):
             yield from record_batch.to_pylist()
 
     def read_transformed_records_iter(
         self,
-        batch_size: int = DEFAULT_BATCH_SIZE,
-        batch_read_ahead: int = DEFAULT_BATCH_READ_AHEAD,
-        fragment_read_ahead: int = DEFAULT_FRAGMENT_READ_AHEAD,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[dict]:
         """Yield individual transformed records as dictionaries from the dataset.
@@ -520,9 +520,6 @@ def read_transformed_records_iter(
 
         for record_dict in self.read_dicts_iter(
             columns=["transformed_record"],
-            batch_size=batch_size,
-            batch_read_ahead=batch_read_ahead,
-            fragment_read_ahead=fragment_read_ahead,
             **filters,
         ):
             if transformed_record := record_dict["transformed_record"]:
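And the read side, for completeness: batch size and readahead now come from the dataset's config rather than per-call arguments. The loading step and the filter value below are illustrative assumptions, not taken from this diff:

```python
# Hypothetical read usage; the loader call and filter value are illustrative.
timdex_dataset = TIMDEXDataset("s3://example-bucket/dataset")
timdex_dataset.load()  # assumed loader that populates .dataset; exact method name may differ

# read_batches_iter, and the methods built on it, now pull batch_size,
# batch_readahead, and fragment_readahead from timdex_dataset.config.
for record in timdex_dataset.read_transformed_records_iter(run_date="2025-01-01"):
    ...
```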