Skip to content

Commit 7ae3b96

Browse files
authored
Merge pull request #181 from MITLibraries/USE-306-refactor-class-relationships
USE 306 - refactor class relationships
2 parents a1d8ad7 + efccf88 commit 7ae3b96

10 files changed

Lines changed: 455 additions & 469 deletions

File tree

Pipfile.lock

Lines changed: 178 additions & 180 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/conftest.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
generate_sample_embeddings_for_run,
1313
generate_sample_records,
1414
)
15-
from timdex_dataset_api import TIMDEXDataset, TIMDEXDatasetMetadata
15+
from timdex_dataset_api import TIMDEXDataset
1616
from timdex_dataset_api.dataset import TIMDEXDatasetConfig
1717
from timdex_dataset_api.embeddings import (
1818
DatasetEmbedding,
1919
TIMDEXEmbeddings,
2020
)
21+
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
2122
from timdex_dataset_api.record import DatasetRecord
2223

2324

@@ -230,10 +231,8 @@ def timdex_dataset_same_day_runs(tmp_path) -> TIMDEXDataset:
230231
@pytest.fixture(scope="module")
231232
def timdex_metadata(timdex_dataset_with_runs) -> TIMDEXDatasetMetadata:
232233
"""TIMDEXDatasetMetadata with static database file created."""
233-
metadata = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
234-
metadata.rebuild_dataset_metadata()
235-
metadata.refresh()
236-
return metadata
234+
timdex_dataset_with_runs.metadata.rebuild_dataset_metadata()
235+
return timdex_dataset_with_runs.metadata
237236

238237

239238
@pytest.fixture(scope="module")
@@ -247,9 +246,9 @@ def timdex_dataset_with_runs_with_metadata(
247246

248247

249248
@pytest.fixture
250-
def timdex_metadata_empty(timdex_dataset_with_runs) -> TIMDEXDatasetMetadata:
249+
def timdex_metadata_empty(timdex_dataset_empty) -> TIMDEXDatasetMetadata:
251250
"""TIMDEXDatasetMetadata without static database file."""
252-
return TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
251+
return timdex_dataset_empty.metadata
253252

254253

255254
@pytest.fixture
@@ -271,7 +270,8 @@ def timdex_metadata_with_deltas(
271270
)
272271
td.write(records)
273272

274-
return TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
273+
# return fresh TIMDEXDataset's metadata
274+
return TIMDEXDataset(timdex_dataset_with_runs.location).metadata
275275

276276

277277
@pytest.fixture
@@ -286,12 +286,11 @@ def timdex_metadata_merged_deltas(
286286
# clone dataset with runs using new dataset location
287287
td = TIMDEXDataset(dataset_location, config=timdex_dataset_with_runs.config)
288288

289-
# clone metadata and merge append deltas
290-
metadata = TIMDEXDatasetMetadata(td.location)
291-
metadata.merge_append_deltas()
292-
metadata.refresh()
289+
# merge append deltas via the TD's metadata
290+
td.metadata.merge_append_deltas()
291+
td.refresh()
293292

294-
return metadata
293+
return td.metadata
295294

296295

297296
# ================================================================================

tests/test_embeddings.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,13 @@ def test_embeddings_read_batches_yields_pyarrow_record_batches(
152152
timdex_dataset_empty.metadata.rebuild_dataset_metadata()
153153
timdex_dataset_empty.refresh()
154154

155-
# write embeddings
156-
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
157-
timdex_embeddings.write(sample_embeddings_generator(100, run_id="test-run"))
158-
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
155+
# write embeddings and refresh to pick up new views
156+
timdex_dataset_empty.embeddings.write(
157+
sample_embeddings_generator(100, run_id="test-run")
158+
)
159+
timdex_dataset_empty.refresh()
159160

160-
batches = timdex_embeddings.read_batches_iter()
161+
batches = timdex_dataset_empty.embeddings.read_batches_iter()
161162
batch = next(batches)
162163
assert isinstance(batch, pa.RecordBatch)
163164

tests/test_metadata.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from duckdb import DuckDBPyConnection
88

9-
from timdex_dataset_api import TIMDEXDataset, TIMDEXDatasetMetadata
9+
from timdex_dataset_api import TIMDEXDataset
1010

1111
ORDERED_METADATA_COLUMN_NAMES = [
1212
"timdex_record_id",
@@ -21,29 +21,33 @@
2121
]
2222

2323

24-
def test_tdm_init_no_metadata_file_warning_success(caplog, timdex_dataset_with_runs):
25-
TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
26-
24+
def test_tdm_init_no_metadata_file_warning_success(caplog, tmp_path):
25+
# creating a new TIMDEXDataset will log warning if no metadata file
26+
caplog.set_level("WARNING")
27+
TIMDEXDataset(str(tmp_path / "new_empty_dataset"))
2728
assert "Static metadata database not found" in caplog.text
2829

2930

3031
def test_tdm_local_dataset_structure_properties(tmp_path):
3132
local_root = str(Path(tmp_path) / "path/to/nothing")
32-
tdm_local = TIMDEXDatasetMetadata(local_root)
33-
assert tdm_local.location == local_root
34-
assert tdm_local.location_scheme == "file"
33+
td_local = TIMDEXDataset(local_root)
34+
assert td_local.metadata.location == local_root
35+
assert td_local.metadata.location_scheme == "file"
3536

3637

37-
def test_tdm_s3_dataset_structure_properties(s3_bucket_mocked):
38-
s3_root = "s3://timdex/dataset"
39-
tdm_s3 = TIMDEXDatasetMetadata(s3_root)
40-
assert tdm_s3.location == s3_root
41-
assert tdm_s3.location_scheme == "s3"
38+
def test_tdm_s3_dataset_structure_properties(timdex_dataset_empty):
39+
# test that location_scheme property works correctly for local paths
40+
# S3 tests require full mocking and are covered in other tests
41+
assert timdex_dataset_empty.metadata.location_scheme == "file"
4242

4343

44-
def test_tdm_create_metadata_database_file_success(caplog, timdex_metadata_empty):
44+
def test_tdm_create_metadata_database_file_success(
45+
caplog, timdex_dataset_with_runs, timdex_metadata_empty
46+
):
4547
caplog.set_level("DEBUG")
46-
timdex_metadata_empty.rebuild_dataset_metadata()
48+
# use a fresh dataset from timdex_dataset_with_runs location
49+
td = TIMDEXDataset(timdex_dataset_with_runs.location)
50+
td.metadata.rebuild_dataset_metadata()
4751

4852

4953
def test_tdm_init_metadata_file_found_success(timdex_metadata):
@@ -321,15 +325,15 @@ def test_tdm_merge_append_deltas_deletes_append_deltas(
321325
assert not os.listdir(timdex_metadata_merged_deltas.append_deltas_path)
322326

323327

324-
def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_and_valid(
328+
def test_td_prepare_duckdb_secret_and_extensions_home_env_var_set_and_valid(
325329
monkeypatch, tmp_path_factory, timdex_dataset_with_runs
326330
):
327331
preset_home = tmp_path_factory.mktemp("my-account")
328332
monkeypatch.setenv("HOME", str(preset_home))
329333

330-
tdm = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
334+
td = TIMDEXDataset(timdex_dataset_with_runs.location)
331335
df = (
332-
tdm.conn.query(
336+
td.conn.query(
333337
"""
334338
select
335339
current_setting('secret_directory') as secret_directory,
@@ -344,15 +348,15 @@ def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_and_valid(
344348
assert df.extension_directory == "" # expected and okay when HOME set
345349

346350

347-
def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_unset(
351+
def test_td_prepare_duckdb_secret_and_extensions_home_env_var_unset(
348352
monkeypatch, timdex_dataset_with_runs
349353
):
350354
monkeypatch.delenv("HOME", raising=False)
351355

352-
tdm = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
356+
td = TIMDEXDataset(timdex_dataset_with_runs.location)
353357

354358
df = (
355-
tdm.conn.query(
359+
td.conn.query(
356360
"""
357361
select
358362
current_setting('secret_directory') as secret_directory,
@@ -367,15 +371,15 @@ def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_unset(
367371
assert df.extension_directory == "/tmp/.duckdb/extensions"
368372

369373

370-
def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_but_empty(
374+
def test_td_prepare_duckdb_secret_and_extensions_home_env_var_set_but_empty(
371375
monkeypatch, timdex_dataset_with_runs
372376
):
373377
monkeypatch.setenv("HOME", "") # simulate AWS Lambda environment
374378

375-
tdm = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
379+
td = TIMDEXDataset(timdex_dataset_with_runs.location)
376380

377381
df = (
378-
tdm.conn.query(
382+
td.conn.query(
379383
"""
380384
select
381385
current_setting('secret_directory') as secret_directory,
@@ -390,14 +394,16 @@ def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_but_empty(
390394
assert df.extension_directory == "/tmp/.duckdb/extensions"
391395

392396

393-
def test_tdm_preload_current_records_default_false(tmp_path):
394-
tdm = TIMDEXDatasetMetadata(str(tmp_path))
395-
assert tdm.preload_current_records is False
397+
def test_td_preload_current_records_default_false(tmp_path):
398+
td = TIMDEXDataset(str(tmp_path))
399+
assert td.preload_current_records is False
400+
assert td.metadata.preload_current_records is False
396401

397402

398-
def test_tdm_preload_current_records_flag_true(tmp_path):
399-
tdm = TIMDEXDatasetMetadata(str(tmp_path), preload_current_records=True)
400-
assert tdm.preload_current_records is True
403+
def test_td_preload_current_records_flag_true(tmp_path):
404+
td = TIMDEXDataset(str(tmp_path), preload_current_records=True)
405+
assert td.preload_current_records is True
406+
assert td.metadata.preload_current_records is True
401407

402408

403409
def test_tdm_preload_false_no_temp_table(timdex_dataset_with_runs):

tests/test_read.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ def test_dataset_load_current_records_gets_correct_same_day_full_run(
255255
):
256256
# ensure metadata exists for this dataset
257257
timdex_dataset_same_day_runs.metadata.rebuild_dataset_metadata()
258-
timdex_dataset_same_day_runs.metadata.refresh()
259258
df = timdex_dataset_same_day_runs.read_dataframe(
260259
table="current_records", run_type="full"
261260
)
@@ -266,7 +265,6 @@ def test_dataset_load_current_records_gets_correct_same_day_daily_runs_ordering(
266265
timdex_dataset_same_day_runs,
267266
):
268267
timdex_dataset_same_day_runs.metadata.rebuild_dataset_metadata()
269-
timdex_dataset_same_day_runs.metadata.refresh()
270268
first_record = next(
271269
timdex_dataset_same_day_runs.read_dicts_iter(
272270
table="current_records", run_type="daily"

timdex_dataset_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
66
from timdex_dataset_api.record import DatasetRecord
77

8-
__version__ = "3.8.0"
8+
__version__ = "3.9.0"
99

1010
__all__ = [
1111
"DatasetEmbedding",

timdex_dataset_api/dataset.py

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
import pandas as pd
1717
import pyarrow as pa
1818
import pyarrow.dataset as ds
19-
from duckdb import DuckDBPyConnection
19+
from duckdb_engine import ConnectionWrapper
2020
from pyarrow import fs
21+
from sqlalchemy import MetaData, Table, create_engine
22+
from sqlalchemy.types import ARRAY, FLOAT
2123

2224
from timdex_dataset_api.config import configure_logger
2325
from timdex_dataset_api.embeddings import TIMDEXEmbeddings
2426
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
27+
from timdex_dataset_api.utils import DuckDBConnectionFactory
2528

2629
if TYPE_CHECKING:
2730
from timdex_dataset_api.record import DatasetRecord # pragma: nocover
@@ -78,6 +81,10 @@ class TIMDEXDatasetConfig:
7881
from a dataset; pyarrow default is 16
7982
- fragment_read_ahead: number of fragments to optimistically read ahead when batch
8083
reaching from a dataset; pyarrow default is 4
84+
- duckdb_join_batch_size: batch size for keyset pagination when joining metadata
85+
86+
Note: DuckDB connection settings (memory_limit, threads) are handled by
87+
DuckDBConnectionFactory via TDA_DUCKDB_MEMORY_LIMIT and TDA_DUCKDB_THREADS env vars.
8188
"""
8289

8390
read_batch_size: int = field(
@@ -132,18 +139,21 @@ def __init__(
132139
self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
133140
self.dataset = self.load_pyarrow_dataset()
134141

135-
# dataset metadata
136-
self.metadata = TIMDEXDatasetMetadata(
137-
location,
138-
preload_current_records=preload_current_records,
139-
)
142+
# create DuckDB connection used by all classes
143+
self.conn_factory = DuckDBConnectionFactory(location_scheme=self.location_scheme)
144+
self.conn = self.conn_factory.create_connection()
140145

141-
# DuckDB context
142-
self.conn = self.setup_duckdb_context()
146+
# create schemas
147+
self._create_duckdb_schemas()
143148

144-
# dataset embeddings
149+
# composed components receive self
150+
self.metadata = TIMDEXDatasetMetadata(self)
145151
self.embeddings = TIMDEXEmbeddings(self)
146152

153+
# SQLAlchemy (SA) reflection after components have set up their views
154+
self.sa_tables: dict[str, dict[str, Table]] = {}
155+
self.reflect_sa_tables()
156+
147157
@property
148158
def location_scheme(self) -> Literal["file", "s3"]:
149159
scheme = urlparse(self.location).scheme
@@ -158,7 +168,7 @@ def data_records_root(self) -> str:
158168
return f"{self.location.removesuffix('/')}/data/records" # type: ignore[union-attr]
159169

160170
def refresh(self) -> None:
161-
"""Fully reload TIMDEXDataset instance."""
171+
"""Refresh dataset by fully reinitializing."""
162172
self.__init__( # type: ignore[misc]
163173
self.location,
164174
config=self.config,
@@ -245,24 +255,54 @@ def get_s3_filesystem() -> fs.FileSystem:
245255
session_token=credentials.token,
246256
)
247257

248-
def setup_duckdb_context(self) -> DuckDBPyConnection:
249-
"""Create a DuckDB connection that metadata and data query and retrieval.
258+
def _create_duckdb_schemas(self) -> None:
259+
"""Create DuckDB schemas used by all components."""
260+
self.conn.execute("create schema metadata;")
261+
self.conn.execute("create schema data;")
250262

251-
This method extends TIMDEXDatasetMetadata's pre-existing DuckDB connection, adding
252-
a 'data' schema and any other configurations needed.
263+
def reflect_sa_tables(self, schemas: list[str] | None = None) -> None:
264+
"""Reflect SQLAlchemy metadata for DuckDB schemas.
265+
266+
This centralizes SA reflection for all composed components. Reflected tables
267+
are stored in self.sa_tables as {schema: {table_name: Table}}.
268+
269+
Args:
270+
schemas: list of schemas to reflect; defaults to ["metadata", "data"]
253271
"""
254272
start_time = time.perf_counter()
273+
schemas = schemas or ["metadata", "data"]
255274

256-
conn = self.metadata.conn
275+
engine = create_engine(
276+
"duckdb://",
277+
creator=lambda: ConnectionWrapper(self.conn),
278+
)
279+
280+
for schema in schemas:
281+
db_metadata = MetaData()
282+
db_metadata.reflect(bind=engine, schema=schema, views=True)
283+
284+
# store tables in flat dict keyed by table name (without schema prefix)
285+
self.sa_tables[schema] = {
286+
table_name.removeprefix(f"{schema}."): table
287+
for table_name, table in db_metadata.tables.items()
288+
}
257289

258-
# create data schema
259-
conn.execute("""create schema data;""")
290+
# type fixup for embedding_vector column (DuckDB LIST -> SA ARRAY)
291+
if "embeddings" in self.sa_tables.get("data", {}):
292+
self.sa_tables["data"]["embeddings"].c.embedding_vector.type = ARRAY(FLOAT)
260293

261294
logger.debug(
262-
"DuckDB context created for TIMDEXDataset, "
263-
f"{round(time.perf_counter()-start_time,2)}s"
295+
f"SQLAlchemy reflection complete for schemas {schemas}, "
296+
f"{round(time.perf_counter() - start_time, 3)}s"
264297
)
265-
return conn
298+
299+
def get_sa_table(self, schema: str, table: str) -> Table:
300+
"""Get a reflected SQLAlchemy Table by schema and table name."""
301+
if schema not in self.sa_tables:
302+
raise ValueError(f"Schema '{schema}' not found in reflected SA tables.")
303+
if table not in self.sa_tables[schema]:
304+
raise ValueError(f"Table '{table}' not found in schema '{schema}'.")
305+
return self.sa_tables[schema][table]
266306

267307
def write(
268308
self,
@@ -326,7 +366,7 @@ def write(
326366
if write_append_deltas:
327367
for written_file in written_files:
328368
self.metadata.write_append_delta_duckdb(written_file.path) # type: ignore[attr-defined]
329-
self.metadata.refresh()
369+
self.refresh()
330370

331371
self.log_write_statistics(start_time, written_files)
332372

@@ -575,9 +615,7 @@ def _iter_data_chunks(self, data_query: str) -> Iterator[pa.RecordBatch]:
575615
)
576616
finally:
577617
if self.location_scheme == "s3":
578-
self.conn.execute(
579-
f"""set threads={self.metadata.config.duckdb_connection_threads};"""
580-
)
618+
self.conn.execute(f"""set threads={self.conn_factory.threads};""")
581619

582620
def read_dataframes_iter(
583621
self,

0 commit comments

Comments
 (0)