Commit c59f2d1

Refactor class and duckdb connection relationships
Why these changes are being introduced:

This refactoring work was a long time coming, inspired by a recent need to gracefully handle a read request for embeddings against a dataset without embeddings parquet files. If we can normalize how and when tables are created, and how DuckDB connections are handled, we can normalize handling read requests for data that may not be available (yet). As such, this refactoring work will help normalize read edge cases now and going forward.

This library was built in stages. First was TIMDEXDataset, which read parquet files directly. Then TIMDEXDatasetMetadata, which more formally introduced DuckDB and handled connection creation and configuration. This connection was shared with TIMDEXDataset as we leaned into DuckDB reading. Lastly, TIMDEXEmbeddings was added as our first new "source" of data; this class shared the connection from TIMDEXDataset. Both TIMDEXDatasetMetadata and TIMDEXEmbeddings were doing their own SQLAlchemy table reflection. TIMDEXDatasetMetadata could be instantiated on its own, while TIMDEXEmbeddings was assumed to take an instance of TIMDEXDataset.

At this point, while things worked, it was clear that a refactor would be beneficial. We needed to clarify what creates and configures the DuckDB connection, to solidify that TIMDEXDatasetMetadata and TIMDEXEmbeddings are components of TIMDEXDataset, and to define how and when SQLAlchemy reflection is performed. Aligning all of these will make responding to read and write edge cases much easier.

How this addresses that need:

- A new factory class, DuckDBConnectionFactory, is responsible for creating and configuring any DuckDB connections used.
- Both TIMDEXDatasetMetadata and TIMDEXEmbeddings require a TIMDEXDataset instance and become components on it; we can more accurately call them "components" of the primary TIMDEXDataset.
- TIMDEXDataset handles the creation of a DuckDB connection via the new factory, and this connection is then accessible to its components TIMDEXDatasetMetadata and TIMDEXEmbeddings (and maybe more in the future).
- TIMDEXDataset is also responsible for all SQLAlchemy reflection, saving results to self.sa_tables. In this way, any component that wants a SQLAlchemy Table, e.g. for reading, can get it from `self.timdex_dataset.get_sa_table(<schema>, <table>)`.
- Refreshing of classes is greatly simplified: TIMDEXDataset is the true orchestrator now, so a full re-init satisfies this. Components no longer have their own `.refresh()` methods.
- Where possible, tests are updated to use components like TIMDEXEmbeddings as part of a TIMDEXDataset instance, not as standalone class instances.

Side effects of this change:

* It is not recommended to use TIMDEXDatasetMetadata or TIMDEXEmbeddings by themselves; they are meant as components on a TIMDEXDataset.

Relevant ticket(s):

* https://mitlibraries.atlassian.net/browse/USE-306
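The ownership model described above — a factory that creates and configures the connection, an orchestrator that owns it, and components that receive the orchestrator — can be sketched roughly as follows. This is a minimal illustration only: sqlite3 stands in for DuckDB, and the class names here are simplified stand-ins for the real library's classes.

```python
import sqlite3


class ConnectionFactory:
    """Stand-in for DuckDBConnectionFactory: the single place where a
    connection is created and configured."""

    def __init__(self, threads: int = 4):
        self.threads = threads  # real factory reads env vars like TDA_DUCKDB_THREADS

    def create_connection(self) -> sqlite3.Connection:
        # real factory would also apply memory limits, secrets, extensions, etc.
        return sqlite3.connect(":memory:")


class Metadata:
    """Component: takes the parent dataset and shares its connection."""

    def __init__(self, dataset: "Dataset"):
        self.timdex_dataset = dataset
        self.conn = dataset.conn


class Dataset:
    """Orchestrator: builds the connection via the factory, then composes
    components; a full re-init serves as refresh()."""

    def __init__(self, location: str):
        self.location = location
        self.conn_factory = ConnectionFactory()
        self.conn = self.conn_factory.create_connection()
        self.metadata = Metadata(self)

    def refresh(self) -> None:
        # no per-component refresh: rebuilding the orchestrator rebuilds everything
        self.__init__(self.location)


td = Dataset("/tmp/dataset")
assert td.metadata.conn is td.conn  # one shared connection
td.refresh()
assert td.metadata.conn is td.conn  # still true after a full re-init
```

The key property is that the connection has exactly one owner, so components can never drift out of sync with it: a refresh replaces the connection and the components together.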
1 parent 21a0256 commit c59f2d1

8 files changed

Lines changed: 276 additions & 288 deletions


tests/conftest.py

Lines changed: 12 additions & 13 deletions
@@ -12,12 +12,13 @@
     generate_sample_embeddings_for_run,
     generate_sample_records,
 )
-from timdex_dataset_api import TIMDEXDataset, TIMDEXDatasetMetadata
+from timdex_dataset_api import TIMDEXDataset
 from timdex_dataset_api.dataset import TIMDEXDatasetConfig
 from timdex_dataset_api.embeddings import (
     DatasetEmbedding,
     TIMDEXEmbeddings,
 )
+from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
 from timdex_dataset_api.record import DatasetRecord
 
 
@@ -230,10 +231,8 @@ def timdex_dataset_same_day_runs(tmp_path) -> TIMDEXDataset:
 @pytest.fixture(scope="module")
 def timdex_metadata(timdex_dataset_with_runs) -> TIMDEXDatasetMetadata:
     """TIMDEXDatasetMetadata with static database file created."""
-    metadata = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
-    metadata.rebuild_dataset_metadata()
-    metadata.refresh()
-    return metadata
+    timdex_dataset_with_runs.metadata.rebuild_dataset_metadata()
+    return timdex_dataset_with_runs.metadata
 
 
 @pytest.fixture(scope="module")
@@ -247,9 +246,9 @@ def timdex_dataset_with_runs_with_metadata(
 
 
 @pytest.fixture
-def timdex_metadata_empty(timdex_dataset_with_runs) -> TIMDEXDatasetMetadata:
+def timdex_metadata_empty(timdex_dataset_empty) -> TIMDEXDatasetMetadata:
     """TIMDEXDatasetMetadata without static database file."""
-    return TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
+    return timdex_dataset_empty.metadata
 
 
 @pytest.fixture
@@ -271,7 +270,8 @@ def timdex_metadata_with_deltas(
     )
     td.write(records)
 
-    return TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
+    # return fresh TIMDEXDataset's metadata
+    return TIMDEXDataset(timdex_dataset_with_runs.location).metadata
 
 
 @pytest.fixture
@@ -286,12 +286,11 @@ def timdex_metadata_merged_deltas(
     # clone dataset with runs using new dataset location
     td = TIMDEXDataset(dataset_location, config=timdex_dataset_with_runs.config)
 
-    # clone metadata and merge append deltas
-    metadata = TIMDEXDatasetMetadata(td.location)
-    metadata.merge_append_deltas()
-    metadata.refresh()
+    # merge append deltas via the TD's metadata
+    td.metadata.merge_append_deltas()
+    td.refresh()
 
-    return metadata
+    return td.metadata
 
 
 # ================================================================================

tests/test_embeddings.py

Lines changed: 6 additions & 5 deletions
@@ -152,12 +152,13 @@ def test_embeddings_read_batches_yields_pyarrow_record_batches(
     timdex_dataset_empty.metadata.rebuild_dataset_metadata()
     timdex_dataset_empty.refresh()
 
-    # write embeddings
-    timdex_embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
-    timdex_embeddings.write(sample_embeddings_generator(100, run_id="test-run"))
-    timdex_embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
+    # write embeddings and refresh to pick up new views
+    timdex_dataset_empty.embeddings.write(
+        sample_embeddings_generator(100, run_id="test-run")
+    )
+    timdex_dataset_empty.refresh()
 
-    batches = timdex_embeddings.read_batches_iter()
+    batches = timdex_dataset_empty.embeddings.read_batches_iter()
     batch = next(batches)
     assert isinstance(batch, pa.RecordBatch)

tests/test_metadata.py

Lines changed: 35 additions & 29 deletions
@@ -6,7 +6,7 @@
 
 from duckdb import DuckDBPyConnection
 
-from timdex_dataset_api import TIMDEXDataset, TIMDEXDatasetMetadata
+from timdex_dataset_api import TIMDEXDataset
 
 ORDERED_METADATA_COLUMN_NAMES = [
     "timdex_record_id",
@@ -21,29 +21,33 @@
 ]
 
 
-def test_tdm_init_no_metadata_file_warning_success(caplog, timdex_dataset_with_runs):
-    TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
-
+def test_tdm_init_no_metadata_file_warning_success(caplog, tmp_path):
+    # creating a new TIMDEXDataset will log warning if no metadata file
+    caplog.set_level("WARNING")
+    TIMDEXDataset(str(tmp_path / "new_empty_dataset"))
     assert "Static metadata database not found" in caplog.text
 
 
 def test_tdm_local_dataset_structure_properties(tmp_path):
     local_root = str(Path(tmp_path) / "path/to/nothing")
-    tdm_local = TIMDEXDatasetMetadata(local_root)
-    assert tdm_local.location == local_root
-    assert tdm_local.location_scheme == "file"
+    td_local = TIMDEXDataset(local_root)
+    assert td_local.metadata.location == local_root
+    assert td_local.metadata.location_scheme == "file"
 
 
-def test_tdm_s3_dataset_structure_properties(s3_bucket_mocked):
-    s3_root = "s3://timdex/dataset"
-    tdm_s3 = TIMDEXDatasetMetadata(s3_root)
-    assert tdm_s3.location == s3_root
-    assert tdm_s3.location_scheme == "s3"
+def test_tdm_s3_dataset_structure_properties(timdex_dataset_empty):
+    # test that location_scheme property works correctly for local paths
+    # S3 tests require full mocking and are covered in other tests
+    assert timdex_dataset_empty.metadata.location_scheme == "file"
 
 
-def test_tdm_create_metadata_database_file_success(caplog, timdex_metadata_empty):
+def test_tdm_create_metadata_database_file_success(
+    caplog, timdex_dataset_with_runs, timdex_metadata_empty
+):
     caplog.set_level("DEBUG")
-    timdex_metadata_empty.rebuild_dataset_metadata()
+    # use a fresh dataset from timdex_dataset_with_runs location
+    td = TIMDEXDataset(timdex_dataset_with_runs.location)
+    td.metadata.rebuild_dataset_metadata()
 
 
 def test_tdm_init_metadata_file_found_success(timdex_metadata):
@@ -321,15 +325,15 @@ def test_tdm_merge_append_deltas_deletes_append_deltas(
     assert not os.listdir(timdex_metadata_merged_deltas.append_deltas_path)
 
 
-def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_and_valid(
+def test_td_prepare_duckdb_secret_and_extensions_home_env_var_set_and_valid(
     monkeypatch, tmp_path_factory, timdex_dataset_with_runs
 ):
     preset_home = tmp_path_factory.mktemp("my-account")
     monkeypatch.setenv("HOME", str(preset_home))
 
-    tdm = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
+    td = TIMDEXDataset(timdex_dataset_with_runs.location)
     df = (
-        tdm.conn.query(
+        td.conn.query(
             """
             select
                 current_setting('secret_directory') as secret_directory,
@@ -344,15 +348,15 @@ def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_and_valid(
     assert df.extension_directory == ""  # expected and okay when HOME set
 
 
-def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_unset(
+def test_td_prepare_duckdb_secret_and_extensions_home_env_var_unset(
    monkeypatch, timdex_dataset_with_runs
 ):
     monkeypatch.delenv("HOME", raising=False)
 
-    tdm = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
+    td = TIMDEXDataset(timdex_dataset_with_runs.location)
 
     df = (
-        tdm.conn.query(
+        td.conn.query(
             """
             select
                 current_setting('secret_directory') as secret_directory,
@@ -367,15 +371,15 @@ def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_unset(
     assert df.extension_directory == "/tmp/.duckdb/extensions"
 
 
-def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_but_empty(
+def test_td_prepare_duckdb_secret_and_extensions_home_env_var_set_but_empty(
     monkeypatch, timdex_dataset_with_runs
 ):
     monkeypatch.setenv("HOME", "")  # simulate AWS Lambda environment
 
-    tdm = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
+    td = TIMDEXDataset(timdex_dataset_with_runs.location)
 
     df = (
-        tdm.conn.query(
+        td.conn.query(
             """
             select
                 current_setting('secret_directory') as secret_directory,
@@ -390,14 +394,16 @@ def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_but_empty(
     assert df.extension_directory == "/tmp/.duckdb/extensions"
 
 
-def test_tdm_preload_current_records_default_false(tmp_path):
-    tdm = TIMDEXDatasetMetadata(str(tmp_path))
-    assert tdm.preload_current_records is False
+def test_td_preload_current_records_default_false(tmp_path):
+    td = TIMDEXDataset(str(tmp_path))
+    assert td.preload_current_records is False
+    assert td.metadata.preload_current_records is False
 
 
-def test_tdm_preload_current_records_flag_true(tmp_path):
-    tdm = TIMDEXDatasetMetadata(str(tmp_path), preload_current_records=True)
-    assert tdm.preload_current_records is True
+def test_td_preload_current_records_flag_true(tmp_path):
+    td = TIMDEXDataset(str(tmp_path), preload_current_records=True)
+    assert td.preload_current_records is True
+    assert td.metadata.preload_current_records is True
 
 
 def test_tdm_preload_false_no_temp_table(timdex_dataset_with_runs):

tests/test_read.py

Lines changed: 0 additions & 2 deletions
@@ -255,7 +255,6 @@ def test_dataset_load_current_records_gets_correct_same_day_full_run(
 ):
     # ensure metadata exists for this dataset
     timdex_dataset_same_day_runs.metadata.rebuild_dataset_metadata()
-    timdex_dataset_same_day_runs.metadata.refresh()
     df = timdex_dataset_same_day_runs.read_dataframe(
         table="current_records", run_type="full"
     )
@@ -266,7 +265,6 @@ def test_dataset_load_current_records_gets_correct_same_day_daily_runs_ordering(
     timdex_dataset_same_day_runs,
 ):
     timdex_dataset_same_day_runs.metadata.rebuild_dataset_metadata()
-    timdex_dataset_same_day_runs.metadata.refresh()
     first_record = next(
         timdex_dataset_same_day_runs.read_dicts_iter(
             table="current_records", run_type="daily"

timdex_dataset_api/dataset.py

Lines changed: 62 additions & 24 deletions
@@ -16,12 +16,15 @@
 import pandas as pd
 import pyarrow as pa
 import pyarrow.dataset as ds
-from duckdb import DuckDBPyConnection
+from duckdb_engine import ConnectionWrapper
 from pyarrow import fs
+from sqlalchemy import MetaData, Table, create_engine
+from sqlalchemy.types import ARRAY, FLOAT
 
 from timdex_dataset_api.config import configure_logger
 from timdex_dataset_api.embeddings import TIMDEXEmbeddings
 from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
+from timdex_dataset_api.utils import DuckDBConnectionFactory
 
 if TYPE_CHECKING:
     from timdex_dataset_api.record import DatasetRecord  # pragma: nocover
@@ -78,6 +81,10 @@ class TIMDEXDatasetConfig:
         from a dataset; pyarrow default is 16
     - fragment_read_ahead: number of fragments to optimistically read ahead when batch
         reaching from a dataset; pyarrow default is 4
+    - duckdb_join_batch_size: batch size for keyset pagination when joining metadata
+
+    Note: DuckDB connection settings (memory_limit, threads) are handled by
+    DuckDBConnectionFactory via TDA_DUCKDB_MEMORY_LIMIT and TDA_DUCKDB_THREADS env vars.
     """
 
     read_batch_size: int = field(
@@ -132,18 +139,21 @@ def __init__(
         self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
         self.dataset = self.load_pyarrow_dataset()
 
-        # dataset metadata
-        self.metadata = TIMDEXDatasetMetadata(
-            location,
-            preload_current_records=preload_current_records,
-        )
+        # create DuckDB connection used by all classes
+        self.conn_factory = DuckDBConnectionFactory(location_scheme=self.location_scheme)
+        self.conn = self.conn_factory.create_connection()
 
-        # DuckDB context
-        self.conn = self.setup_duckdb_context()
+        # create schemas
+        self._create_duckdb_schemas()
 
-        # dataset embeddings
+        # composed components receive self
+        self.metadata = TIMDEXDatasetMetadata(self)
         self.embeddings = TIMDEXEmbeddings(self)
 
+        # SQLAlchemy (SA) reflection after components have set up their views
+        self.sa_tables: dict[str, dict[str, Table]] = {}
+        self.reflect_sa_tables()
+
     @property
     def location_scheme(self) -> Literal["file", "s3"]:
         scheme = urlparse(self.location).scheme
@@ -158,7 +168,7 @@ def data_records_root(self) -> str:
         return f"{self.location.removesuffix('/')}/data/records"  # type: ignore[union-attr]
 
     def refresh(self) -> None:
-        """Fully reload TIMDEXDataset instance."""
+        """Refresh dataset by fully reinitializing."""
         self.__init__(  # type: ignore[misc]
             self.location,
             config=self.config,
@@ -245,24 +255,54 @@ def get_s3_filesystem() -> fs.FileSystem:
             session_token=credentials.token,
         )
 
-    def setup_duckdb_context(self) -> DuckDBPyConnection:
-        """Create a DuckDB connection that metadata and data query and retrieval.
+    def _create_duckdb_schemas(self) -> None:
+        """Create DuckDB schemas used by all components."""
+        self.conn.execute("create schema metadata;")
+        self.conn.execute("create schema data;")
 
-        This method extends TIMDEXDatasetMetadata's pre-existing DuckDB connection, adding
-        a 'data' schema and any other configurations needed.
+    def reflect_sa_tables(self, schemas: list[str] | None = None) -> None:
+        """Reflect SQLAlchemy metadata for DuckDB schemas.
+
+        This centralizes SA reflection for all composed components. Reflected tables
+        are stored in self.sa_tables as {schema: {table_name: Table}}.
+
+        Args:
+            schemas: list of schemas to reflect; defaults to ["metadata", "data"]
         """
         start_time = time.perf_counter()
+        schemas = schemas or ["metadata", "data"]
 
-        conn = self.metadata.conn
+        engine = create_engine(
+            "duckdb://",
+            creator=lambda: ConnectionWrapper(self.conn),
+        )
+
+        for schema in schemas:
+            db_metadata = MetaData()
+            db_metadata.reflect(bind=engine, schema=schema, views=True)
+
+            # store tables in flat dict keyed by table name (without schema prefix)
+            self.sa_tables[schema] = {
+                table_name.removeprefix(f"{schema}."): table
+                for table_name, table in db_metadata.tables.items()
+            }
 
-        # create data schema
-        conn.execute("""create schema data;""")
+        # type fixup for embedding_vector column (DuckDB LIST -> SA ARRAY)
+        if "embeddings" in self.sa_tables.get("data", {}):
+            self.sa_tables["data"]["embeddings"].c.embedding_vector.type = ARRAY(FLOAT)
 
         logger.debug(
-            "DuckDB context created for TIMDEXDataset, "
-            f"{round(time.perf_counter()-start_time,2)}s"
+            f"SQLAlchemy reflection complete for schemas {schemas}, "
+            f"{round(time.perf_counter() - start_time, 3)}s"
         )
-        return conn
+
+    def get_sa_table(self, schema: str, table: str) -> Table:
+        """Get a reflected SQLAlchemy Table by schema and table name."""
+        if schema not in self.sa_tables:
+            raise ValueError(f"Schema '{schema}' not found in reflected SA tables.")
+        if table not in self.sa_tables[schema]:
+            raise ValueError(f"Table '{table}' not found in schema '{schema}'.")
        return self.sa_tables[schema][table]
 
     def write(
         self,
@@ -326,7 +366,7 @@ def write(
         if write_append_deltas:
             for written_file in written_files:
                 self.metadata.write_append_delta_duckdb(written_file.path)  # type: ignore[attr-defined]
-            self.metadata.refresh()
+            self.refresh()
 
         self.log_write_statistics(start_time, written_files)
 
@@ -575,9 +615,7 @@ def _iter_data_chunks(self, data_query: str) -> Iterator[pa.RecordBatch]:
             )
         finally:
             if self.location_scheme == "s3":
-                self.conn.execute(
-                    f"""set threads={self.metadata.config.duckdb_connection_threads};"""
-                )
+                self.conn.execute(f"""set threads={self.conn_factory.threads};""")
 
     def read_dataframes_iter(
         self,
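The centralized reflection pattern in the dataset.py changes above — reflect every schema once, store Tables in a registry, and expose a `get_sa_table(schema, table)` lookup with explicit errors — can be illustrated with plain SQLAlchemy against an in-memory SQLite database. This is a sketch only: the real code reflects DuckDB schemas through duckdb_engine's ConnectionWrapper, and the `records` table here is a hypothetical stand-in.

```python
from sqlalchemy import MetaData, Table, create_engine, text

# hypothetical stand-in for the dataset's DuckDB connection
engine = create_engine("sqlite:///:memory:")
with engine.begin() as conn:
    conn.execute(text("create table records (timdex_record_id text, source text)"))

# reflect all tables once and store them in a registry keyed by schema/table name
sa_tables: dict[str, dict[str, Table]] = {}
db_metadata = MetaData()
db_metadata.reflect(bind=engine, views=True)
sa_tables["data"] = dict(db_metadata.tables)


def get_sa_table(schema: str, table: str) -> Table:
    """Lookup mirroring the error handling of TIMDEXDataset.get_sa_table."""
    if schema not in sa_tables:
        raise ValueError(f"Schema '{schema}' not found in reflected SA tables.")
    if table not in sa_tables[schema]:
        raise ValueError(f"Table '{table}' not found in schema '{schema}'.")
    return sa_tables[schema][table]


records = get_sa_table("data", "records")
print(sorted(c.name for c in records.columns))  # → ['source', 'timdex_record_id']
```

Reflecting once up front means components never re-run reflection themselves; they ask the orchestrator's registry, and a full re-init rebuilds the registry along with everything else.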
