Skip to content

Commit d7a1b29

Browse files
committed
Save written files to TIMDEXDataset object and log
1 parent d30c9e2 commit d7a1b29

1 file changed

Lines changed: 26 additions & 4 deletions

File tree

timdex_dataset_api/dataset.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, location: str | list[str]):
5454
self.dataset: ds.Dataset = None # type: ignore[assignment]
5555
self.schema = TIMDEX_DATASET_SCHEMA
5656
self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
57+
self._written_files: list[ds.WrittenFile] = None # type: ignore[assignment]
5758

5859
@classmethod
5960
def load(cls, location: str) -> "TIMDEXDataset":
@@ -197,6 +198,7 @@ def write(
197198
- use_threads: boolean if threads should be used for writing
198199
"""
199200
start_time = time.perf_counter()
201+
self._written_files = []
200202

201203
if isinstance(self.source, list):
202204
raise TypeError(
@@ -209,14 +211,13 @@ def write(
209211
batch_size=batch_size,
210212
)
211213

212-
written_files = []
213214
ds.write_dataset(
214215
record_batches_iter,
215216
base_dir=self.source,
216217
basename_template="%s-{i}.parquet" % (str(uuid.uuid4())), # noqa: UP031
217218
existing_data_behavior="delete_matching",
218219
filesystem=self.filesystem,
219-
file_visitor=lambda written_file: written_files.append(written_file),
220+
file_visitor=lambda written_file: self._written_files.append(written_file), # type: ignore[arg-type]
220221
format="parquet",
221222
max_open_files=500,
222223
max_rows_per_file=MAX_ROWS_PER_FILE,
@@ -227,8 +228,8 @@ def write(
227228
use_threads=use_threads,
228229
)
229230

230-
logger.info(f"write elapsed: {round(time.perf_counter()-start_time, 2)}s")
231-
return written_files # type: ignore[return-value]
231+
self.log_write_statistics(start_time)
232+
return self._written_files # type: ignore[return-value]
232233

233234
def get_dataset_record_batches(
234235
self,
@@ -266,3 +267,24 @@ def get_dataset_record_batches(
266267
f"elapsed: {round(time.perf_counter()-batch_start_time, 6)}s"
267268
)
268269
yield batch
270+
271+
def log_write_statistics(self, start_time: float) -> None:
    """Log summary statistics for the files recorded by the most recent write.

    Reads ``self._written_files`` and emits a single INFO log line with the
    elapsed time, the number of files, the summed row count
    (``metadata.num_rows`` of each file) and the summed ``size`` of each file.

    Args:
        start_time: a ``time.perf_counter()`` reading taken when the write
            began; elapsed time is measured relative to it.
    """
    # _written_files is None until write() replaces it with a list, so treat
    # "nothing written yet" as zero files instead of failing on len(None).
    written_files = self._written_files or []

    total_time = round(time.perf_counter() - start_time, 2)
    total_files = len(written_files)
    total_rows = sum(
        wf.metadata.num_rows  # type: ignore[attr-defined]
        for wf in written_files
    )
    total_size = sum(
        wf.size for wf in written_files  # type: ignore[attr-defined]
    )
    logger.info(
        f"Dataset write complete - elapsed: "
        f"{total_time}s, "
        f"total files: {total_files}, "
        f"total rows: {total_rows}, "
        f"total size: {total_size}"
    )

0 commit comments

Comments
 (0)