
Commit 29d1e0e

Add LIMIT clause to read methods
Why these changes are being introduced:

Sometimes it can be helpful to limit the results from a read method.

How this addresses that need:

Adds an optional limit= arg to all read methods, which is passed along to the metadata query. By limiting the metadata results, we limit the data records retrieved.

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/TIMX-543
1 parent 4b2f2d2 · 3 files changed · 42 additions & 6 deletions
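In caller terms the change reads like the sketch below. This is minimal and hypothetical: td stands in for an already-loaded dataset instance, whose construction is outside this diff; only the read methods and their new limit= keyword appear in the changes.

import pyarrow as pa

# td is an assumed, already-loaded dataset instance (construction not shown here)
batches = td.read_batches_iter(limit=10)  # metadata query gets LIMIT 10
table = pa.Table.from_batches(batches)
assert len(table) <= 10                   # never more rows than the limit

df = td.read_dataframe(limit=100)         # at most 100 rows as a pandas DataFrame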


tests/test_read.py

Lines changed: 6 additions & 0 deletions
@@ -276,3 +276,9 @@ def test_dataset_load_current_records_gets_correct_same_day_daily_runs_ordering(
     # just assert it's one of the daily runs
     assert first_record["run_id"] in {"run-4", "run-5"}
     assert first_record["action"] in {"index", "delete"}
+
+
+def test_read_batches_iter_limit_returns_n_rows(timdex_dataset_multi_source):
+    batches = timdex_dataset_multi_source.read_batches_iter(limit=10)
+    table = pa.Table.from_batches(batches)
+    assert len(table) == 10

timdex_dataset_api/dataset.py

Lines changed: 27 additions & 5 deletions
@@ -358,6 +358,7 @@ def read_batches_iter(
         self,
         table: str = "records",
         columns: list[str] | None = None,
+        limit: int | None = None,
         where: str | None = None,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[pa.RecordBatch]:
@@ -375,6 +376,7 @@ def read_batches_iter(
         Args:
             - table: an available DuckDB view or table
             - columns: list of columns to return
+            - limit: limit number of records yielded
             - where: raw SQL WHERE clause that can be used alone, or in combination with
                 key/value DatasetFilters
             - filters: simple filtering based on key/value pairs from DatasetFilters
@@ -383,7 +385,7 @@
 
         # build and execute metadata query
         metadata_time = time.perf_counter()
-        meta_query = self.metadata.build_meta_query(table, where, **filters)
+        meta_query = self.metadata.build_meta_query(table, limit, where, **filters)
         meta_df = self.metadata.conn.query(meta_query).to_df()
         meta_df = meta_df.sort_values(by=["filename", "run_record_offset"])
         logger.debug(
@@ -472,25 +474,35 @@ def read_dataframes_iter(
         self,
         table: str = "records",
         columns: list[str] | None = None,
+        limit: int | None = None,
         where: str | None = None,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[pd.DataFrame]:
         for record_batch in self.read_batches_iter(
-            table=table, columns=columns, where=where, **filters
+            table=table,
+            columns=columns,
+            limit=limit,
+            where=where,
+            **filters,
         ):
             yield record_batch.to_pandas()
 
     def read_dataframe(
         self,
         table: str = "records",
         columns: list[str] | None = None,
+        limit: int | None = None,
         where: str | None = None,
         **filters: Unpack[DatasetFilters],
     ) -> pd.DataFrame | None:
         df_batches = [
             record_batch.to_pandas()
             for record_batch in self.read_batches_iter(
-                table=table, columns=columns, where=where, **filters
+                table=table,
+                columns=columns,
+                limit=limit,
+                where=where,
+                **filters,
             )
         ]
         if not df_batches:
@@ -501,22 +513,32 @@ def read_dicts_iter(
         self,
         table: str = "records",
         columns: list[str] | None = None,
+        limit: int | None = None,
         where: str | None = None,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[dict]:
         for record_batch in self.read_batches_iter(
-            table=table, columns=columns, where=where, **filters
+            table=table,
+            columns=columns,
+            limit=limit,
+            where=where,
+            **filters,
         ):
             yield from record_batch.to_pylist()
 
     def read_transformed_records_iter(
         self,
         table: str = "records",
+        limit: int | None = None,
         where: str | None = None,
         **filters: Unpack[DatasetFilters],
     ) -> Iterator[dict]:
         for record_dict in self.read_dicts_iter(
-            table=table, columns=["transformed_record"], where=where, **filters
+            table=table,
+            columns=["transformed_record"],
+            limit=limit,
+            where=where,
+            **filters,
         ):
             if transformed_record := record_dict["transformed_record"]:
                 yield json.loads(transformed_record)
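One behavior implied by the diff above is worth noting: the limit is applied to the metadata query, while read_transformed_records_iter drops rows whose transformed_record is null after that, so it can yield fewer than limit records. A minimal sketch (td is again an assumed loaded dataset instance):

# the limit caps metadata rows; null transformed_record rows are filtered
# out afterwards, so fewer than `limit` records may be yielded
records = list(td.read_transformed_records_iter(limit=10))
assert len(records) <= 10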

timdex_dataset_api/metadata.py

Lines changed: 9 additions & 1 deletion
@@ -612,7 +612,11 @@ def write_append_delta_duckdb(self, filepath: str) -> None:
         )
 
     def build_meta_query(
-        self, table: str, where: str | None, **filters: Unpack["DatasetFilters"]
+        self,
+        table: str,
+        limit: int | None,
+        where: str | None,
+        **filters: Unpack["DatasetFilters"],
     ) -> str:
         """Build SQL query using SQLAlchemy against metadata schema tables and views."""
         sa_table = self.get_sa_table(table)
@@ -638,6 +642,10 @@
         if combined is not None:
             stmt = stmt.where(combined)
 
+        # apply limit if present
+        if limit:
+            stmt = stmt.limit(limit)
+
         # using DuckDB dialect, compile to SQL string
         compiled = stmt.compile(
             dialect=DuckDBDialect(),
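To see what this produces, here is a minimal standalone sketch of the same SQLAlchemy pattern. The table and column names are illustrative, and it assumes the DuckDBDialect above is duckdb_engine's Dialect (that import is not shown in this diff):

import sqlalchemy as sa
from duckdb_engine import Dialect as DuckDBDialect  # assumed source of the dialect

metadata = sa.MetaData()
records = sa.Table(  # illustrative stand-in for the real metadata table
    "records",
    metadata,
    sa.Column("filename", sa.String),
    sa.Column("source", sa.String),
)

stmt = sa.select(records).where(records.c.source == "alma")

# apply limit if present, mirroring build_meta_query
limit = 10
if limit:
    stmt = stmt.limit(limit)

compiled = stmt.compile(
    dialect=DuckDBDialect(),
    compile_kwargs={"literal_binds": True},
)
print(compiled)  # SELECT ... FROM records WHERE source = 'alma' LIMIT 10

Note that the truthiness check (if limit:) means limit=0 skips the LIMIT clause and returns all rows rather than none; a limit is not None check would be needed to honor an explicit zero.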
