
Commit 184cddc

Merge pull request #158 from MITLibraries/TIMX-527-write-append-deltas

TIMX 527 - write append deltas

2 parents: 3efe8b7 + d4931d5

8 files changed: 416 additions & 233 deletions


Pipfile.lock

Lines changed: 121 additions & 141 deletions

Some generated files are not rendered by default.

README.md

Lines changed: 10 additions & 1 deletion

````diff
@@ -49,7 +49,16 @@ WARNING_ONLY_LOGGERS=# Comma-separated list of logger names to set as WARNING only
 MINIO_S3_ENDPOINT_URL=# If set, informs the library to use this Minio S3 instance. Requires the http(s):// protocol.
 MINIO_USERNAME=# Username / AWS Key for Minio; required when MINIO_S3_ENDPOINT_URL is set
 MINIO_PASSWORD=# Password / AWS Secret for Minio; required when MINIO_S3_ENDPOINT_URL is set
-MINIO_DATA=# Path to persist MinIO data if started via Makefile command
+MINIO_DATA=# Path to persist MinIO data if started via Makefile command
+
+TDA_READ_BATCH_SIZE=# Row size of batches read, affecting memory consumption
+TDA_WRITE_BATCH_SIZE=# Row size of batches written, directly affecting row group size in final parquet files
+TDA_MAX_ROWS_PER_GROUP=# Max number of rows per row group in a parquet file
+TDA_MAX_ROWS_PER_FILE=# Max number of rows in a single parquet file
+TDA_BATCH_READ_AHEAD=# Number of batches to optimistically read ahead when batch reading from a dataset; pyarrow default is 16
+TDA_FRAGMENT_READ_AHEAD=# Number of fragments to optimistically read ahead when batch reading from a dataset; pyarrow default is 4
+TDA_DUCKDB_MEMORY_LIMIT=# Memory limit for DuckDB connection
+TDA_DUCKDB_THREADS=# Thread limit for DuckDB connection
 ```
 
 ## Local S3 via MinIO
````

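These `TDA_*` variables evidently feed `TIMDEXDatasetConfig`: the constructor in this commit falls back to `config or TIMDEXDatasetConfig()`, and the new write tests tune `config.max_rows_per_file` and `config.max_rows_per_group` directly. A minimal sketch of how env-var-driven defaults like these are typically wired, assuming hypothetical field names and defaults for everything not visible in this diff:

```python
# Sketch only: max_rows_per_file / max_rows_per_group appear in this commit;
# the other field names and all default values here are assumptions.
import os
from dataclasses import dataclass, field


def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment, falling back to a default."""
    return int(os.environ.get(name, default))


@dataclass
class TIMDEXDatasetConfig:
    read_batch_size: int = field(
        default_factory=lambda: _env_int("TDA_READ_BATCH_SIZE", 1_000)
    )
    write_batch_size: int = field(
        default_factory=lambda: _env_int("TDA_WRITE_BATCH_SIZE", 1_000)
    )
    max_rows_per_group: int = field(
        default_factory=lambda: _env_int("TDA_MAX_ROWS_PER_GROUP", 1_000)
    )
    max_rows_per_file: int = field(
        default_factory=lambda: _env_int("TDA_MAX_ROWS_PER_FILE", 100_000)
    )
    # TDA_BATCH_READ_AHEAD, TDA_FRAGMENT_READ_AHEAD, TDA_DUCKDB_MEMORY_LIMIT,
    # and TDA_DUCKDB_THREADS would follow the same pattern.
```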
tests/test_dataset.py

Lines changed: 32 additions & 7 deletions

```diff
@@ -1,5 +1,6 @@
 # ruff: noqa: D205, D209, SLF001, PLR2004
 
+import glob
 import os
 from datetime import date
 from unittest.mock import MagicMock, patch
@@ -18,11 +19,24 @@
 @pytest.mark.parametrize(
     ("location", "expected_file_system", "expected_source"),
     [
-        ("path/to/dataset", fs.LocalFileSystem, "path/to/dataset"),
-        ("s3://bucket/path/to/dataset", fs.S3FileSystem, "bucket/path/to/dataset"),
+        (
+            "path/to/dataset",
+            fs.LocalFileSystem,
+            "path/to/dataset/data/records",
+        ),
+        (
+            "s3://timdex/path/to/dataset",
+            fs.S3FileSystem,
+            "timdex/path/to/dataset/data/records",
+        ),
     ],
 )
-def test_dataset_init_success(location, expected_file_system, expected_source):
+def test_dataset_init_success(
+    location,
+    expected_file_system,
+    expected_source,
+    mocked_timdex_bucket,
+):
     timdex_dataset = TIMDEXDataset(location=location)
     assert isinstance(timdex_dataset.filesystem, expected_file_system)
     assert timdex_dataset.paths == expected_source
@@ -58,7 +72,7 @@ def test_dataset_load_local_sets_filesystem_and_dataset_success(
     result = timdex_dataset.load()
 
     mock_pyarrow_ds.assert_called_once_with(
-        "local/path/to/dataset",
+        "local/path/to/dataset/data/records",
         schema=timdex_dataset.schema,
         format="parquet",
         partitioning="hive",
@@ -72,16 +86,16 @@
 @patch("timdex_dataset_api.dataset.TIMDEXDataset.get_s3_filesystem")
 @patch("timdex_dataset_api.dataset.ds.dataset")
 def test_dataset_load_s3_sets_filesystem_and_dataset_success(
-    mock_pyarrow_ds, mock_get_s3_fs
+    mock_pyarrow_ds, mock_get_s3_fs, mocked_timdex_bucket
 ):
     mock_get_s3_fs.return_value = MagicMock()
     mock_pyarrow_ds.return_value = MagicMock()
 
-    timdex_dataset = TIMDEXDataset(location="s3://bucket/path/to/dataset")
+    timdex_dataset = TIMDEXDataset(location="s3://timdex/path/to/dataset")
     result = timdex_dataset.load()
 
     mock_pyarrow_ds.assert_called_with(
-        "bucket/path/to/dataset",
+        "timdex/path/to/dataset/data/records",
         schema=timdex_dataset.schema,
         format="parquet",
         partitioning="hive",
@@ -497,3 +511,14 @@ def test_dataset_load_current_records_gets_correct_same_day_daily_runs_ordering(
 
     assert first_record["run_id"] == "run-5"
     assert first_record["action"] == "delete"
+
+
+def test_dataset_records_data_structure_is_idempotent(dataset_with_runs):
+    assert os.path.exists(dataset_with_runs.data_records_root)
+    start_file_count = glob.glob(f"{dataset_with_runs.data_records_root}/**/*")
+
+    dataset_with_runs.create_data_structure()
+
+    assert os.path.exists(dataset_with_runs.data_records_root)
+    end_file_count = glob.glob(f"{dataset_with_runs.data_records_root}/**/*")
+    assert start_file_count == end_file_count
```

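The reworked parametrization above pins down the new path convention: the pyarrow source is no longer the bare `location` but `location` plus `data/records`. A small local-only sketch of that behavior, using only names visible in this commit (note that instantiation now also creates the `data/records` tree on disk for local locations):

```python
# Path convention asserted by the tests above; instantiating a local dataset
# creates <location>/data/records as a side effect (create_data_structure).
from timdex_dataset_api.dataset import TIMDEXDataset

td = TIMDEXDataset(location="path/to/dataset")
assert td.data_records_root == "path/to/dataset/data/records"
assert td.paths == "path/to/dataset/data/records"
```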
tests/test_metadata.py

Lines changed: 18 additions & 4 deletions

```diff
@@ -1,17 +1,20 @@
+import glob
+import os
+from pathlib import Path
+
 from duckdb import DuckDBPyConnection
 
 from timdex_dataset_api import TIMDEXDatasetMetadata
 
 
 def test_tdm_init_no_metadata_file_warning_success(caplog, dataset_with_runs_location):
-    tdm = TIMDEXDatasetMetadata(dataset_with_runs_location)
+    TIMDEXDatasetMetadata(dataset_with_runs_location)
 
-    assert tdm.conn is None
     assert "Static metadata database not found" in caplog.text
 
 
-def test_tdm_local_dataset_structure_properties():
-    local_root = "/path/to/nothing"
+def test_tdm_local_dataset_structure_properties(tmp_path):
+    local_root = str(Path(tmp_path) / "path/to/nothing")
     tdm_local = TIMDEXDatasetMetadata(local_root)
     assert tdm_local.location == local_root
     assert tdm_local.location_scheme == "file"
@@ -44,3 +47,14 @@ def test_tdm_connection_static_database_records_table_exists(timdex_dataset_meta
         """select * from static_db.records;"""
     ).to_df()
     assert len(records_df) > 0
+
+
+def test_dataset_metadata_structure_is_idempotent(timdex_dataset_metadata):
+    assert os.path.exists(timdex_dataset_metadata.metadata_root)
+    start_file_count = glob.glob(f"{timdex_dataset_metadata.metadata_root}/**/*")
+
+    timdex_dataset_metadata.create_metadata_structure()
+
+    assert os.path.exists(timdex_dataset_metadata.metadata_root)
+    end_file_count = glob.glob(f"{timdex_dataset_metadata.metadata_root}/**/*")
+    assert start_file_count == end_file_count
```

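Both new idempotency tests (here and in `tests/test_dataset.py`) hinge on structure creation being safe to repeat. On the records side this commit implements that with `pathlib`'s `mkdir(parents=True, exist_ok=True)`; a minimal sketch, assuming `create_metadata_structure` mirrors the same approach:

```python
# Why repeated calls are no-ops: mkdir with exist_ok=True leaves existing
# directories (and their contents) untouched. That the metadata side mirrors
# create_data_structure is an assumption here, not shown in this diff.
from pathlib import Path

def create_metadata_structure(metadata_root: str) -> None:
    Path(metadata_root).mkdir(parents=True, exist_ok=True)

create_metadata_structure("/tmp/timdex/metadata")
create_metadata_structure("/tmp/timdex/metadata")  # second call changes nothing
```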
tests/test_write.py

Lines changed: 38 additions & 0 deletions

```diff
@@ -1,16 +1,19 @@
 # ruff: noqa: PLR2004, D209, D205
 import math
 import os
+from pathlib import Path
 from unittest.mock import patch
 
 import pyarrow.dataset as ds
+import pyarrow.parquet as pq
 import pytest
 
 from tests.utils import generate_sample_records
 from timdex_dataset_api.dataset import (
     TIMDEX_DATASET_SCHEMA,
     TIMDEXDataset,
 )
+from timdex_dataset_api.metadata import ORDERED_METADATA_COLUMN_NAMES
 
 
 def test_dataset_write_records_to_new_local_dataset(
@@ -144,3 +147,38 @@ def test_dataset_write_partition_overwrite_files_with_same_name(
     # assert that only the second file exists and overwriting occurs
     assert os.path.exists(written_files_source_a1[0].path)
     assert new_local_dataset.row_count == 7
+
+
+def test_dataset_write_single_append_delta_success(
+    new_local_dataset, sample_records_iter
+):
+    written_files = new_local_dataset.write(sample_records_iter(1_000))
+    append_deltas = os.listdir(new_local_dataset.metadata.append_deltas_path)
+
+    assert len(append_deltas) == len(written_files)
+
+
+def test_dataset_write_multiple_append_deltas_success(
+    new_local_dataset, sample_records_iter
+):
+    """Expecting 10 ETL parquet files written, and so 10 append deltas as well."""
+    new_local_dataset.config.max_rows_per_file = 100
+    new_local_dataset.config.max_rows_per_group = 100
+
+    written_files = new_local_dataset.write(sample_records_iter(1_000))
+    append_deltas = os.listdir(new_local_dataset.metadata.append_deltas_path)
+
+    assert len(written_files) == 10
+    assert len(append_deltas) == len(written_files)
+
+
+def test_dataset_write_append_delta_expected_metadata_columns(
+    new_local_dataset, sample_records_iter
+):
+    new_local_dataset.write(sample_records_iter(1_000))
+    append_delta_filepath = os.listdir(new_local_dataset.metadata.append_deltas_path)[0]
+
+    append_delta = pq.ParquetFile(
+        new_local_dataset.metadata.append_deltas_path / Path(append_delta_filepath)
+    )
+    assert append_delta.schema.names == ORDERED_METADATA_COLUMN_NAMES
```

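The caller-facing change these tests pin down is the new `write_append_deltas` keyword on `TIMDEXDataset.write` (default `True`), which produces one metadata append delta per written ETL parquet file. A hedged usage sketch; `timdex_dataset` and `records_iter` are stand-ins for a real dataset instance and a real `DatasetRecord` iterator:

```python
# Opting out of append deltas; with the default write_append_deltas=True,
# one delta per parquet file is written via write_append_delta_duckdb.
written_files = timdex_dataset.write(records_iter, write_append_deltas=False)
print(f"wrote {len(written_files)} parquet files, no append deltas")
```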
timdex_dataset_api/dataset.py

Lines changed: 53 additions & 20 deletions

```diff
@@ -10,7 +10,9 @@
 from dataclasses import dataclass, field
 from datetime import UTC, date, datetime
 from functools import reduce
-from typing import TYPE_CHECKING, TypedDict, Unpack
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
+from urllib.parse import urlparse
 
 import boto3
 import pandas as pd
@@ -20,6 +22,7 @@
 
 from timdex_dataset_api.config import configure_logger
 from timdex_dataset_api.exceptions import DatasetNotLoadedError
+from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
 
 if TYPE_CHECKING:
     from timdex_dataset_api.record import DatasetRecord  # pragma: nocover
@@ -117,19 +120,38 @@ def __init__(
         self.config = config or TIMDEXDatasetConfig()
         self.location = location
 
+        self.create_data_structure()
+
         # pyarrow dataset
-        self.filesystem, self.paths = self.parse_location(self.location)
+        self.filesystem, self.paths = self.parse_location(self.data_records_root)
         self.dataset: ds.Dataset = None  # type: ignore[assignment]
         self.schema = TIMDEX_DATASET_SCHEMA
         self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
 
-        # writing
-        self._written_files: list[ds.WrittenFile] = None  # type: ignore[assignment]
+        # dataset metadata
+        self.metadata = TIMDEXDatasetMetadata(location)  # type: ignore[arg-type]
+
+    @property
+    def location_scheme(self) -> Literal["file", "s3"]:
+        scheme = urlparse(self.location).scheme  # type: ignore[arg-type]
+        if scheme == "":
+            return "file"
+        if scheme == "s3":
+            return "s3"
+        raise ValueError(f"Location with scheme type '{scheme}' not supported.")
 
     @property
     def data_records_root(self) -> str:
         return f"{self.location.removesuffix('/')}/data/records"  # type: ignore[union-attr]
 
+    def create_data_structure(self) -> None:
+        """Ensure ETL records data structure exists in TIMDEX dataset."""
+        if self.location_scheme == "file":
+            Path(self.data_records_root).mkdir(
+                parents=True,
+                exist_ok=True,
+            )
+
     @property
     def row_count(self) -> int:
         """Get row count from loaded dataset."""
@@ -163,7 +185,7 @@ def load(
         start_time = time.perf_counter()
 
         # reset paths from original location before load
-        _, self.paths = self.parse_location(self.location)
+        _, self.paths = self.parse_location(self.data_records_root)
 
         # perform initial load of full dataset
         self.dataset = self._load_pyarrow_dataset()
@@ -172,7 +194,7 @@
             self.dataset = self._get_filtered_dataset(**filters)
 
         logger.info(
-            f"Dataset successfully loaded: '{self.location}', "
+            f"Dataset successfully loaded: '{self.data_records_root}', "
            f"{round(time.perf_counter()-start_time, 2)}s"
         )
 
@@ -298,6 +320,7 @@ def get_s3_filesystem() -> fs.FileSystem:
             session_token=credentials.token,
         )
 
+    # NOTE: WIP: this will be heavily reworked in upcoming .load() updates
     @classmethod
     def parse_location(
         cls,
@@ -315,6 +338,7 @@ def parse_location(
             case _:
                 raise TypeError("Location type must be str or list[str].")
 
+    # NOTE: WIP: these will be removed in upcoming .load() updates
     @classmethod
     def _parse_single_location(
         cls, location: str
@@ -328,6 +352,7 @@ def _parse_single_location(
             source = location
         return filesystem, source
 
+    # NOTE: WIP: these will be removed in upcoming .load() updates
     @classmethod
     def _parse_multiple_locations(
         cls, location: list[str]
@@ -348,6 +373,7 @@ def write(
         records_iter: Iterator["DatasetRecord"],
         *,
         use_threads: bool = True,
+        write_append_deltas: bool = True,
     ) -> list[ds.WrittenFile]:
         """Write records to the TIMDEX parquet dataset.
 
@@ -370,25 +396,27 @@
         Args:
             - records_iter: Iterator of DatasetRecord instances
             - use_threads: boolean if threads should be used for writing
+            - write_append_deltas: boolean if append deltas should be written for records
+            written during write
         """
         start_time = time.perf_counter()
-        self._written_files = []
+        written_files: list[ds.WrittenFile] = []
 
         dataset_filesystem, dataset_path = self.parse_location(self.data_records_root)
         if isinstance(dataset_path, list):
             raise TypeError(
                 "Dataset location must be the root of a single dataset for writing"
             )
 
+        # write ETL parquet records
         record_batches_iter = self.create_record_batches(records_iter)
-
         ds.write_dataset(
             record_batches_iter,
             base_dir=dataset_path,
             basename_template="%s-{i}.parquet" % (str(uuid.uuid4())),  # noqa: UP031
             existing_data_behavior="overwrite_or_ignore",
             filesystem=dataset_filesystem,
-            file_visitor=lambda written_file: self._written_files.append(written_file),  # type: ignore[arg-type]
+            file_visitor=lambda written_file: written_files.append(written_file),  # type: ignore[arg-type]
             format="parquet",
             max_open_files=500,
             max_rows_per_file=self.config.max_rows_per_file,
@@ -399,8 +427,14 @@
             use_threads=use_threads,
         )
 
-        self.log_write_statistics(start_time)
-        return self._written_files  # type: ignore[return-value]
+        # write metadata append deltas
+        if write_append_deltas:
+            for written_file in written_files:
+                self.metadata.write_append_delta_duckdb(written_file.path)  # type: ignore[attr-defined]
+
+        self.log_write_statistics(start_time, written_files)
+
+        return written_files
 
     def create_record_batches(
         self, records_iter: Iterator["DatasetRecord"]
@@ -423,19 +457,18 @@ def create_record_batches(
             logger.debug(f"Yielding batch {i + 1} for dataset writing.")
             yield batch
 
-    def log_write_statistics(self, start_time: float) -> None:
+    def log_write_statistics(
+        self,
+        start_time: float,
+        written_files: list[ds.WrittenFile],
+    ) -> None:
         """Parse written files from write and log statistics."""
         total_time = round(time.perf_counter() - start_time, 2)
-        total_files = len(self._written_files)
+        total_files = len(written_files)
         total_rows = sum(
-            [
-                wf.metadata.num_rows  # type: ignore[attr-defined]
-                for wf in self._written_files
-            ]
-        )
-        total_size = sum(
-            [wf.size for wf in self._written_files]  # type: ignore[attr-defined]
+            [wf.metadata.num_rows for wf in written_files]  # type: ignore[attr-defined]
         )
+        total_size = sum([wf.size for wf in written_files])  # type: ignore[attr-defined]
         logger.info(
             f"Dataset write complete - elapsed: "
             f"{total_time}s, "
```

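Taken together, the write path now reads: write the ETL parquet files, then (by default) record one metadata append delta per file. A rough end-to-end sketch under this commit's names; `records()` is a hypothetical stand-in, since `DatasetRecord` construction isn't shown in this diff:

```python
# Sketch of the new write flow; records() is hypothetical.
from timdex_dataset_api.dataset import TIMDEXDataset

td = TIMDEXDataset(location="/tmp/timdex-dataset")  # local: data/records created on init
written_files = td.write(records(), write_append_deltas=True)  # the default
# one append delta per parquet file, via
# td.metadata.write_append_delta_duckdb(written_file.path)
print(f"wrote {len(written_files)} parquet files and as many append deltas")
```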