Skip to content

Commit e0504fe

Browse files
authored
get-dataset metastore: exclude preview and versions by default (#1661)
* get-dataset metastore: exclude preview * add filter to exclude versions * more cleanup after review * fix listing dataset fetching * fix tests: update datasets and tests * fix listing requires versions * adjust test fixtures * fix more places where we need versions after all * fix tests: explicitily require version where it is needed * add guardrails around fields not loaded * more tests and dataset records cleanup * refactor a bit more data classes * more fixes * refactor dataclass serializers * more review comments addressed * add versions setter * fix falkey listing tests * add more tests for DatasetListRecord * cleanup from_dict a bit more * address review * cleanup dead code * more cleanup after reviews * address review * mark preview loaded on update dataset * drop serialized schema from the models
1 parent 9a67fc3 commit e0504fe

40 files changed

+1053
-395
lines changed

src/datachain/catalog/catalog.py

Lines changed: 53 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import posixpath
66
import sys
77
import time
8-
import traceback
98
from collections.abc import Callable, Iterable, Iterator, Sequence
109
from contextlib import contextmanager, suppress
1110
from copy import copy
@@ -66,6 +65,7 @@
6665
from datachain.data_storage import AbstractMetastore, AbstractWarehouse
6766
from datachain.dataset import DatasetListVersion
6867
from datachain.job import Job
68+
from datachain.lib.dc.datachain import DataChain
6969
from datachain.lib.listing_info import ListingInfo
7070
from datachain.listing import Listing
7171
from datachain.remote.studio import StudioClient
@@ -77,7 +77,6 @@
7777
CHECKPOINTS_TTL = 4 * 60 * 60
7878

7979
INDEX_INTERNAL_ERROR_MESSAGE = "Internal error on indexing"
80-
DATASET_INTERNAL_ERROR_MESSAGE = "Internal error on creating dataset"
8180
# exit code we use if query script was canceled
8281
QUERY_SCRIPT_CANCELED_EXIT_CODE = 11
8382
# exit code we use if the job is already in a terminal state (failed/canceled elsewhere)
@@ -618,17 +617,6 @@ def enlist_source(
618617

619618
return lst, client, list_path
620619

621-
def _remove_dataset_rows_and_warehouse_info(
622-
self, dataset: DatasetRecord, version: str, **kwargs
623-
):
624-
self.warehouse.drop_dataset_rows_table(dataset, version)
625-
self.update_dataset_version_with_warehouse_info(
626-
dataset,
627-
version,
628-
rows_dropped=True,
629-
**kwargs,
630-
)
631-
632620
@contextmanager
633621
def enlist_sources(
634622
self,
@@ -697,6 +685,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
697685
ds_name,
698686
namespace_name=ds_namespace,
699687
project_name=ds_project,
688+
versions=[ds_version] if ds_version else None,
700689
include_incomplete=False,
701690
)
702691
if not ds_version:
@@ -802,6 +791,7 @@ def create_dataset(
802791
columns: Sequence[Column],
803792
feature_schema: dict | None = None,
804793
query_script: str = "",
794+
sources: str = "",
805795
validate_version: bool | None = True,
806796
listing: bool | None = False,
807797
uuid: str | None = None,
@@ -827,6 +817,7 @@ def create_dataset(
827817
name,
828818
namespace_name=project.namespace.name if project else None,
829819
project_name=project.name if project else None,
820+
versions=None,
830821
)
831822

832823
if (description or attrs) and (
@@ -864,6 +855,7 @@ def create_dataset(
864855
project=project,
865856
feature_schema=feature_schema,
866857
query_script=query_script,
858+
sources=sources,
867859
columns=columns,
868860
uuid=uuid,
869861
job_id=job_id,
@@ -873,13 +865,8 @@ def create_dataset(
873865

874866
@staticmethod
875867
def _next_auto_version(dataset: "DatasetRecord", update_version: str | None) -> str:
876-
"""Compute the next version for a dataset based on the update strategy.
877-
878-
Handles brand-new datasets whose versions list may contain a single
879-
phantom entry with ``version=None`` (artifact of the LEFT JOIN used
880-
by ``get_dataset``).
881-
"""
882-
if not any(v.version for v in dataset.versions):
868+
"""Compute the next version for a dataset based on the update strategy."""
869+
if not dataset.versions:
883870
return DEFAULT_DATASET_VERSION
884871
if update_version == "major":
885872
return dataset.next_version_major
@@ -895,6 +882,7 @@ def _try_claim_version(
895882
project: Project | None,
896883
feature_schema: dict | None,
897884
query_script: str,
885+
sources: str,
898886
columns: Sequence[Column],
899887
uuid: str | None,
900888
job_id: str | None,
@@ -929,6 +917,7 @@ def _try_claim_version(
929917
target_version,
930918
feature_schema=feature_schema,
931919
query_script=query_script,
920+
sources=sources,
932921
columns=columns,
933922
uuid=uuid,
934923
job_id=job_id,
@@ -953,6 +942,7 @@ def _try_claim_version(
953942
name,
954943
namespace_name=project.namespace.name if project else None,
955944
project_name=project.name if project else None,
945+
versions=None,
956946
)
957947
target_version = self._next_auto_version(dataset, update_version)
958948

@@ -1013,21 +1003,19 @@ def create_dataset_version(
10131003
return dataset, version_created
10141004

10151005
def update_dataset_version_with_warehouse_info(
1016-
self, dataset: DatasetRecord, version: str, rows_dropped=False, **kwargs
1006+
self, dataset: DatasetRecord, version: str, **kwargs
10171007
) -> None:
10181008
from datachain.query.dataset import DatasetQuery
10191009

10201010
dataset_version = dataset.get_version(version)
1011+
if dataset_version._preview_loaded:
1012+
raise RuntimeError(
1013+
"update_dataset_version_with_warehouse_info expects preview to be "
1014+
"unloaded and regenerates it from warehouse rows"
1015+
)
10211016

10221017
values = {**kwargs}
10231018

1024-
if rows_dropped:
1025-
values["num_objects"] = None
1026-
values["size"] = None
1027-
values["preview"] = None
1028-
self.metastore.update_dataset_version(dataset, version, **values)
1029-
return
1030-
10311019
stats_num_objects = None
10321020
stats_size = None
10331021
if not dataset_version.num_objects:
@@ -1039,22 +1027,20 @@ def update_dataset_version_with_warehouse_info(
10391027
if size != dataset_version.size:
10401028
values["size"] = size
10411029

1042-
preview_rows = None
1043-
if not dataset_version.preview:
1044-
preview = (
1045-
DatasetQuery(
1046-
name=dataset.name,
1047-
namespace_name=dataset.project.namespace.name,
1048-
project_name=dataset.project.name,
1049-
version=version,
1050-
catalog=self,
1051-
include_incomplete=True, # Allow reading CREATED version
1052-
)
1053-
.limit(20)
1054-
.to_db_records()
1030+
preview = (
1031+
DatasetQuery(
1032+
name=dataset.name,
1033+
namespace_name=dataset.project.namespace.name,
1034+
project_name=dataset.project.name,
1035+
version=version,
1036+
catalog=self,
1037+
include_incomplete=True, # Allow reading CREATED version
10551038
)
1056-
preview_rows = len(preview)
1057-
values["preview"] = preview
1039+
.limit(20)
1040+
.to_db_records()
1041+
)
1042+
preview_rows = len(preview)
1043+
values["preview"] = preview
10581044

10591045
# Log anomaly: dataset_stats returned 0 but preview has data
10601046
if stats_num_objects == 0 and preview_rows and preview_rows > 0:
@@ -1068,7 +1054,7 @@ def update_dataset_version_with_warehouse_info(
10681054
version,
10691055
dataset_version.num_objects,
10701056
dataset_version.size,
1071-
bool(dataset_version.preview),
1057+
False,
10721058
stats_num_objects,
10731059
stats_size,
10741060
preview_rows,
@@ -1175,7 +1161,7 @@ def create_dataset_from_sources(
11751161
project: Project | None = None,
11761162
client_config=None,
11771163
recursive=False,
1178-
) -> DatasetRecord:
1164+
) -> "DataChain":
11791165
if not sources:
11801166
raise ValueError("Sources needs to be non empty list")
11811167

@@ -1188,54 +1174,20 @@ def create_dataset_from_sources(
11881174
if source.startswith(DATASET_PREFIX):
11891175
dc = read_dataset(source[len(DATASET_PREFIX) :], session=self.session)
11901176
else:
1191-
dc = read_storage(source, session=self.session, recursive=recursive)
1177+
dc = read_storage(
1178+
source,
1179+
session=self.session,
1180+
recursive=recursive,
1181+
client_config=client_config,
1182+
)
11921183

11931184
chains.append(dc)
11941185

11951186
# create union of all dataset queries created from sources
1196-
dc = reduce(lambda dc1, dc2: dc1.union(dc2), chains)
1197-
try:
1198-
dc = dc.settings(project=project.name, namespace=project.namespace.name)
1199-
dc.save(name, query_script="")
1200-
except Exception as e: # noqa: BLE001
1201-
try:
1202-
ds = self.get_dataset(
1203-
name,
1204-
namespace_name=project.namespace.name,
1205-
project_name=project.name,
1206-
)
1207-
self.metastore.update_dataset_status(
1208-
ds,
1209-
DatasetStatus.FAILED,
1210-
version=ds.latest_version,
1211-
error_message=DATASET_INTERNAL_ERROR_MESSAGE,
1212-
error_stack=traceback.format_exc(),
1213-
)
1214-
self._remove_dataset_rows_and_warehouse_info(
1215-
ds,
1216-
ds.latest_version,
1217-
sources="\n".join(sources),
1218-
)
1219-
raise
1220-
except DatasetNotFoundError:
1221-
raise e from None
1222-
1223-
ds = self.get_dataset(
1224-
name,
1225-
namespace_name=project.namespace.name,
1226-
project_name=project.name,
1227-
)
1228-
1229-
self.update_dataset_version_with_warehouse_info(
1230-
ds,
1231-
ds.latest_version,
1232-
sources="\n".join(sources),
1233-
)
1234-
1235-
return self.get_dataset(
1236-
name,
1237-
namespace_name=project.namespace.name,
1238-
project_name=project.name,
1187+
return (
1188+
reduce(lambda dc1, dc2: dc1.union(dc2), chains)
1189+
.settings(project=project.name, namespace=project.namespace.name)
1190+
.save(name, sources="\n".join(sources), query_script="")
12391191
)
12401192

12411193
def get_full_dataset_name(
@@ -1299,7 +1251,10 @@ def get_dataset(
12991251
name: str,
13001252
namespace_name: str | None = None,
13011253
project_name: str | None = None,
1254+
*,
1255+
versions: Sequence[str] | None = (),
13021256
include_incomplete: bool = True,
1257+
include_preview: bool = False,
13031258
) -> DatasetRecord:
13041259
from datachain.lib.listing import is_listing_dataset
13051260

@@ -1314,7 +1269,9 @@ def get_dataset(
13141269
name,
13151270
namespace_name=namespace_name,
13161271
project_name=project_name,
1272+
versions=versions,
13171273
include_incomplete=include_incomplete,
1274+
include_preview=include_preview,
13181275
)
13191276

13201277
def get_dataset_with_remote_fallback(
@@ -1348,6 +1305,7 @@ def get_dataset_with_remote_fallback(
13481305
name,
13491306
namespace_name=namespace_name,
13501307
project_name=project_name,
1308+
versions=None,
13511309
include_incomplete=include_incomplete,
13521310
)
13531311
if not version or ds.has_version(version):
@@ -1377,6 +1335,7 @@ def get_dataset_with_remote_fallback(
13771335
name,
13781336
namespace_name=namespace_name,
13791337
project_name=project_name,
1338+
versions=None,
13801339
include_incomplete=include_incomplete,
13811340
)
13821341

@@ -1445,6 +1404,7 @@ def get_dataset_dependencies(
14451404
name,
14461405
namespace_name=namespace_name,
14471406
project_name=project_name,
1407+
versions=[version],
14481408
include_incomplete=False,
14491409
)
14501410
dataset_version = dataset.get_version(version)
@@ -1618,6 +1578,7 @@ def export_dataset_table(
16181578
name,
16191579
namespace_name=project.namespace.name if project else None,
16201580
project_name=project.name if project else None,
1581+
versions=[version],
16211582
)
16221583

16231584
self.warehouse.export_dataset_table(
@@ -1640,6 +1601,7 @@ def remove_dataset(
16401601
name,
16411602
namespace_name=project.namespace.name if project else None,
16421603
project_name=project.name if project else None,
1604+
versions=None,
16431605
)
16441606
if not version and not force:
16451607
raise ValueError(f"Missing dataset version from input for dataset {name}")
@@ -1680,6 +1642,7 @@ def edit_dataset(
16801642
name,
16811643
namespace_name=project.namespace.name if project else None,
16821644
project_name=project.name if project else None,
1645+
versions=None,
16831646
)
16841647
return self.update_dataset(dataset, **update_data)
16851648

@@ -1857,6 +1820,7 @@ def pull_dataset( # noqa: C901, PLR0915, PLR0912
18571820
local_ds_name,
18581821
namespace_name=namespace.name,
18591822
project_name=project.name,
1823+
versions=None,
18601824
include_incomplete=True,
18611825
)
18621826
if local_dataset.has_version(local_ds_version):

src/datachain/cli/commands/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def list_datasets_local_versions(
127127
name,
128128
namespace_name=namespace_name,
129129
project_name=project_name,
130+
versions=None,
130131
include_incomplete=False,
131132
)
132133
for v in ds.versions:

src/datachain/cli/commands/show.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ def show(
2525
version = name_version
2626

2727
if script:
28-
dataset = catalog.get_dataset(name, include_incomplete=False)
28+
dataset = catalog.get_dataset(
29+
name,
30+
versions=[version] if version else None,
31+
include_incomplete=False,
32+
)
2933
dataset_version = dataset.get_version(version or dataset.latest_version)
3034
print(dataset_version.query_script)
3135
return

0 commit comments

Comments
 (0)