2828 pa .field ("source" , pa .string ()),
2929 pa .field ("run_date" , pa .date32 ()),
3030 pa .field ("run_type" , pa .string ()),
31- pa .field ("action" , pa .string ()),
3231 pa .field ("run_id" , pa .string ()),
32+ pa .field ("action" , pa .string ()),
3333 )
3434)
3535
# Columns used to partition the parquet dataset on write.  Order matters:
# pyarrow creates one directory level per column, so "run_id" nests under
# "run_type" and "action" is the innermost partition (this ordering matches
# the trailing fields of TIMDEX_DATASET_SCHEMA above).
TIMDEX_DATASET_PARTITION_COLUMNS = [
    "source",
    "run_date",
    "run_type",
    "run_id",
    "action",
]
4343
# Default number of records per pyarrow RecordBatch when batching for writes
# (see the batch_size argument threaded through write()).
DEFAULT_BATCH_SIZE = 1_000
@@ -54,6 +54,7 @@ def __init__(self, location: str | list[str]):
5454 self .dataset : ds .Dataset = None # type: ignore[assignment]
5555 self .schema = TIMDEX_DATASET_SCHEMA
5656 self .partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
57+ self ._written_files : list [ds .WrittenFile ] = None # type: ignore[assignment]
5758
5859 @classmethod
5960 def load (cls , location : str ) -> "TIMDEXDataset" :
@@ -197,6 +198,7 @@ def write(
197198 - use_threads: boolean if threads should be used for writing
198199 """
199200 start_time = time .perf_counter ()
201+ self ._written_files = []
200202
201203 if isinstance (self .source , list ):
202204 raise TypeError (
@@ -209,14 +211,13 @@ def write(
209211 batch_size = batch_size ,
210212 )
211213
212- written_files = []
213214 ds .write_dataset (
214215 record_batches_iter ,
215216 base_dir = self .source ,
216217 basename_template = "%s-{i}.parquet" % (str (uuid .uuid4 ())), # noqa: UP031
217218 existing_data_behavior = "delete_matching" ,
218219 filesystem = self .filesystem ,
219- file_visitor = lambda written_file : written_files . append (written_file ),
220+ file_visitor = lambda written_file : self . _written_files . append (written_file ), # type: ignore[arg-type]
220221 format = "parquet" ,
221222 max_open_files = 500 ,
222223 max_rows_per_file = MAX_ROWS_PER_FILE ,
@@ -227,8 +228,8 @@ def write(
227228 use_threads = use_threads ,
228229 )
229230
230- logger . info ( f"write elapsed: { round ( time . perf_counter () - start_time , 2 ) } s" )
231- return written_files # type: ignore[return-value]
231+ self . log_write_statistics ( start_time )
232+ return self . _written_files # type: ignore[return-value]
232233
233234 def get_dataset_record_batches (
234235 self ,
@@ -266,3 +267,24 @@ def get_dataset_record_batches(
266267 f"elapsed: { round (time .perf_counter ()- batch_start_time , 6 )} s"
267268 )
268269 yield batch
270+
271+ def log_write_statistics (self , start_time : float ) -> None :
272+ """Parse written files from write and log statistics."""
273+ total_time = round (time .perf_counter () - start_time , 2 )
274+ total_files = len (self ._written_files )
275+ total_rows = sum (
276+ [
277+ wf .metadata .num_rows # type: ignore[attr-defined]
278+ for wf in self ._written_files
279+ ]
280+ )
281+ total_size = sum (
282+ [wf .size for wf in self ._written_files ] # type: ignore[attr-defined]
283+ )
284+ logger .info (
285+ f"Dataset write complete - elapsed: "
286+ f"{ total_time } s, "
287+ f"total files: { total_files } , "
288+ f"total rows: { total_rows } , "
289+ f"total size: { total_size } "
290+ )
0 commit comments