1616import pandas as pd
1717import pyarrow as pa
1818import pyarrow .dataset as ds
19- from duckdb import DuckDBPyConnection
19+ from duckdb_engine import ConnectionWrapper
2020from pyarrow import fs
21+ from sqlalchemy import MetaData , Table , create_engine
22+ from sqlalchemy .types import ARRAY , FLOAT
2123
2224from timdex_dataset_api .config import configure_logger
2325from timdex_dataset_api .embeddings import TIMDEXEmbeddings
2426from timdex_dataset_api .metadata import TIMDEXDatasetMetadata
27+ from timdex_dataset_api .utils import DuckDBConnectionFactory
2528
2629if TYPE_CHECKING :
2730 from timdex_dataset_api .record import DatasetRecord # pragma: nocover
@@ -78,6 +81,10 @@ class TIMDEXDatasetConfig:
7881 from a dataset; pyarrow default is 16
7982 - fragment_read_ahead: number of fragments to optimistically read ahead when batch
8083 reaching from a dataset; pyarrow default is 4
84+ - duckdb_join_batch_size: batch size for keyset pagination when joining metadata
85+
86+ Note: DuckDB connection settings (memory_limit, threads) are handled by
87+ DuckDBConnectionFactory via TDA_DUCKDB_MEMORY_LIMIT and TDA_DUCKDB_THREADS env vars.
8188 """
8289
8390 read_batch_size : int = field (
@@ -132,18 +139,21 @@ def __init__(
132139 self .partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
133140 self .dataset = self .load_pyarrow_dataset ()
134141
135- # dataset metadata
136- self .metadata = TIMDEXDatasetMetadata (
137- location ,
138- preload_current_records = preload_current_records ,
139- )
142+ # create DuckDB connection used by all classes
143+ self .conn_factory = DuckDBConnectionFactory (location_scheme = self .location_scheme )
144+ self .conn = self .conn_factory .create_connection ()
140145
141- # DuckDB context
142- self .conn = self . setup_duckdb_context ()
146+ # create schemas
147+ self ._create_duckdb_schemas ()
143148
144- # dataset embeddings
149+ # composed components receive self
150+ self .metadata = TIMDEXDatasetMetadata (self )
145151 self .embeddings = TIMDEXEmbeddings (self )
146152
153+ # SQLAlchemy (SA) reflection after components have set up their views
154+ self .sa_tables : dict [str , dict [str , Table ]] = {}
155+ self .reflect_sa_tables ()
156+
147157 @property
148158 def location_scheme (self ) -> Literal ["file" , "s3" ]:
149159 scheme = urlparse (self .location ).scheme
@@ -158,7 +168,7 @@ def data_records_root(self) -> str:
158168 return f"{ self .location .removesuffix ('/' )} /data/records" # type: ignore[union-attr]
159169
160170 def refresh (self ) -> None :
161- """Fully reload TIMDEXDataset instance ."""
171+ """Refresh dataset by fully reinitializing ."""
162172 self .__init__ ( # type: ignore[misc]
163173 self .location ,
164174 config = self .config ,
@@ -245,24 +255,54 @@ def get_s3_filesystem() -> fs.FileSystem:
245255 session_token = credentials .token ,
246256 )
247257
248- def setup_duckdb_context (self ) -> DuckDBPyConnection :
249- """Create a DuckDB connection that metadata and data query and retrieval.
258+ def _create_duckdb_schemas (self ) -> None :
259+ """Create DuckDB schemas used by all components."""
260+ self .conn .execute ("create schema metadata;" )
261+ self .conn .execute ("create schema data;" )
250262
251- This method extends TIMDEXDatasetMetadata's pre-existing DuckDB connection, adding
252- a 'data' schema and any other configurations needed.
263+ def reflect_sa_tables (self , schemas : list [str ] | None = None ) -> None :
264+ """Reflect SQLAlchemy metadata for DuckDB schemas.
265+
266+ This centralizes SA reflection for all composed components. Reflected tables
267+ are stored in self.sa_tables as {schema: {table_name: Table}}.
268+
269+ Args:
270+ schemas: list of schemas to reflect; defaults to ["metadata", "data"]
253271 """
254272 start_time = time .perf_counter ()
273+ schemas = schemas or ["metadata" , "data" ]
255274
256- conn = self .metadata .conn
275+ engine = create_engine (
276+ "duckdb://" ,
277+ creator = lambda : ConnectionWrapper (self .conn ),
278+ )
279+
280+ for schema in schemas :
281+ db_metadata = MetaData ()
282+ db_metadata .reflect (bind = engine , schema = schema , views = True )
283+
284+ # store tables in flat dict keyed by table name (without schema prefix)
285+ self .sa_tables [schema ] = {
286+ table_name .removeprefix (f"{ schema } ." ): table
287+ for table_name , table in db_metadata .tables .items ()
288+ }
257289
258- # create data schema
259- conn .execute ("""create schema data;""" )
290+ # type fixup for embedding_vector column (DuckDB LIST -> SA ARRAY)
291+ if "embeddings" in self .sa_tables .get ("data" , {}):
292+ self .sa_tables ["data" ]["embeddings" ].c .embedding_vector .type = ARRAY (FLOAT )
260293
261294 logger .debug (
262- "DuckDB context created for TIMDEXDataset , "
263- f"{ round (time .perf_counter ()- start_time ,2 )} s"
295+ f"SQLAlchemy reflection complete for schemas { schemas } , "
296+ f"{ round (time .perf_counter () - start_time , 3 )} s"
264297 )
265- return conn
298+
299+ def get_sa_table (self , schema : str , table : str ) -> Table :
300+ """Get a reflected SQLAlchemy Table by schema and table name."""
301+ if schema not in self .sa_tables :
302+ raise ValueError (f"Schema '{ schema } ' not found in reflected SA tables." )
303+ if table not in self .sa_tables [schema ]:
304+ raise ValueError (f"Table '{ table } ' not found in schema '{ schema } '." )
305+ return self .sa_tables [schema ][table ]
266306
267307 def write (
268308 self ,
@@ -326,7 +366,7 @@ def write(
326366 if write_append_deltas :
327367 for written_file in written_files :
328368 self .metadata .write_append_delta_duckdb (written_file .path ) # type: ignore[attr-defined]
329- self .metadata . refresh ()
369+ self .refresh ()
330370
331371 self .log_write_statistics (start_time , written_files )
332372
@@ -575,9 +615,7 @@ def _iter_data_chunks(self, data_query: str) -> Iterator[pa.RecordBatch]:
575615 )
576616 finally :
577617 if self .location_scheme == "s3" :
578- self .conn .execute (
579- f"""set threads={ self .metadata .config .duckdb_connection_threads } ;"""
580- )
618+ self .conn .execute (f"""set threads={ self .conn_factory .threads } ;""" )
581619
582620 def read_dataframes_iter (
583621 self ,
0 commit comments