@@ -5,7 +5,6 @@
 import posixpath
 import sys
 import time
-import traceback
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from contextlib import contextmanager, suppress
 from copy import copy
@@ -66,6 +65,7 @@
     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
     from datachain.dataset import DatasetListVersion
     from datachain.job import Job
+    from datachain.lib.dc.datachain import DataChain
     from datachain.lib.listing_info import ListingInfo
     from datachain.listing import Listing
     from datachain.remote.studio import StudioClient
@@ -77,7 +77,6 @@
 CHECKPOINTS_TTL = 4 * 60 * 60
 
 INDEX_INTERNAL_ERROR_MESSAGE = "Internal error on indexing"
-DATASET_INTERNAL_ERROR_MESSAGE = "Internal error on creating dataset"
 # exit code we use if query script was canceled
 QUERY_SCRIPT_CANCELED_EXIT_CODE = 11
 # exit code we use if the job is already in a terminal state (failed/canceled elsewhere)
@@ -618,17 +617,6 @@ def enlist_source(
 
         return lst, client, list_path
 
-    def _remove_dataset_rows_and_warehouse_info(
-        self, dataset: DatasetRecord, version: str, **kwargs
-    ):
-        self.warehouse.drop_dataset_rows_table(dataset, version)
-        self.update_dataset_version_with_warehouse_info(
-            dataset,
-            version,
-            rows_dropped=True,
-            **kwargs,
-        )
-
     @contextmanager
     def enlist_sources(
         self,
@@ -697,6 +685,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
             ds_name,
             namespace_name=ds_namespace,
             project_name=ds_project,
+            versions=[ds_version] if ds_version else None,
             include_incomplete=False,
         )
         if not ds_version:
@@ -802,6 +791,7 @@ def create_dataset(
         columns: Sequence[Column],
         feature_schema: dict | None = None,
         query_script: str = "",
+        sources: str = "",
         validate_version: bool | None = True,
         listing: bool | None = False,
         uuid: str | None = None,
@@ -827,6 +817,7 @@ def create_dataset(
             name,
             namespace_name=project.namespace.name if project else None,
             project_name=project.name if project else None,
+            versions=None,
         )
 
         if (description or attrs) and (
@@ -864,6 +855,7 @@ def create_dataset(
             project=project,
             feature_schema=feature_schema,
             query_script=query_script,
+            sources=sources,
             columns=columns,
             uuid=uuid,
             job_id=job_id,
@@ -873,13 +865,8 @@ def create_dataset(
 
     @staticmethod
     def _next_auto_version(dataset: "DatasetRecord", update_version: str | None) -> str:
-        """Compute the next version for a dataset based on the update strategy.
-
-        Handles brand-new datasets whose versions list may contain a single
-        phantom entry with ``version=None`` (artifact of the LEFT JOIN used
-        by ``get_dataset``).
-        """
-        if not any(v.version for v in dataset.versions):
+        """Compute the next version for a dataset based on the update strategy."""
+        if not dataset.versions:
             return DEFAULT_DATASET_VERSION
         if update_version == "major":
             return dataset.next_version_major
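
Note: with the LEFT JOIN phantom row gone from `get_dataset` (see the new `versions` parameter below), a plain truthiness check on `dataset.versions` suffices. A standalone sketch of the bump semantics this method appears to implement, assuming semver strings, a "1.0.0" default, and patch as the fall-through strategy:

    def next_auto_version(versions: list[str], update_version: str | None) -> str:
        if not versions:
            return "1.0.0"  # assumed value of DEFAULT_DATASET_VERSION
        latest = max(versions, key=lambda v: tuple(map(int, v.split("."))))
        major, minor, patch = map(int, latest.split("."))
        if update_version == "major":
            return f"{major + 1}.0.0"
        if update_version == "minor":
            return f"{major}.{minor + 1}.0"
        return f"{major}.{minor}.{patch + 1}"  # assumed: anything else bumps patch
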
@@ -895,6 +882,7 @@ def _try_claim_version(
         project: Project | None,
         feature_schema: dict | None,
         query_script: str,
+        sources: str,
         columns: Sequence[Column],
         uuid: str | None,
         job_id: str | None,
@@ -929,6 +917,7 @@ def _try_claim_version(
             target_version,
             feature_schema=feature_schema,
             query_script=query_script,
+            sources=sources,
             columns=columns,
             uuid=uuid,
             job_id=job_id,
@@ -953,6 +942,7 @@ def _try_claim_version(
                 name,
                 namespace_name=project.namespace.name if project else None,
                 project_name=project.name if project else None,
+                versions=None,
             )
             target_version = self._next_auto_version(dataset, update_version)
 
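
Note: `sources` now travels with the version from the start: `create_dataset` passes it to `_try_claim_version`, which records it via `create_dataset_version`, replacing the old post-save `update_dataset_version_with_warehouse_info(..., sources=...)` round-trip. A hedged sketch of the caller's side (URIs are placeholders; the newline-joined format is taken from `create_dataset_from_sources` below):

    catalog.create_dataset(
        "animals",
        project,
        columns=columns,
        sources="s3://bucket/cats/\ns3://bucket/dogs/",  # newline-joined source URIs
    )
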
@@ -1013,21 +1003,19 @@ def create_dataset_version(
         return dataset, version_created
 
     def update_dataset_version_with_warehouse_info(
-        self, dataset: DatasetRecord, version: str, rows_dropped=False, **kwargs
+        self, dataset: DatasetRecord, version: str, **kwargs
     ) -> None:
         from datachain.query.dataset import DatasetQuery
 
         dataset_version = dataset.get_version(version)
+        if dataset_version._preview_loaded:
+            raise RuntimeError(
+                "update_dataset_version_with_warehouse_info expects preview to be "
+                "unloaded and regenerates it from warehouse rows"
+            )
 
         values = {**kwargs}
 
-        if rows_dropped:
-            values["num_objects"] = None
-            values["size"] = None
-            values["preview"] = None
-            self.metastore.update_dataset_version(dataset, version, **values)
-            return
-
         stats_num_objects = None
         stats_size = None
         if not dataset_version.num_objects:
@@ -1039,22 +1027,20 @@ def update_dataset_version_with_warehouse_info(
             if size != dataset_version.size:
                 values["size"] = size
 
-        preview_rows = None
-        if not dataset_version.preview:
-            preview = (
-                DatasetQuery(
-                    name=dataset.name,
-                    namespace_name=dataset.project.namespace.name,
-                    project_name=dataset.project.name,
-                    version=version,
-                    catalog=self,
-                    include_incomplete=True,  # Allow reading CREATED version
-                )
-                .limit(20)
-                .to_db_records()
+        preview = (
+            DatasetQuery(
+                name=dataset.name,
+                namespace_name=dataset.project.namespace.name,
+                project_name=dataset.project.name,
+                version=version,
+                catalog=self,
+                include_incomplete=True,  # Allow reading CREATED version
             )
-            preview_rows = len(preview)
-            values["preview"] = preview
+            .limit(20)
+            .to_db_records()
+        )
+        preview_rows = len(preview)
+        values["preview"] = preview
 
         # Log anomaly: dataset_stats returned 0 but preview has data
         if stats_num_objects == 0 and preview_rows and preview_rows > 0:
@@ -1068,7 +1054,7 @@ def update_dataset_version_with_warehouse_info(
                 version,
                 dataset_version.num_objects,
                 dataset_version.size,
-                bool(dataset_version.preview),
+                False,
                 stats_num_objects,
                 stats_size,
                 preview_rows,
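
Note: without the `rows_dropped` branch, this method always recomputes `num_objects`, `size`, and a 20-row preview from warehouse rows, and the `_preview_loaded` guard turns a stale in-memory preview into a hard error rather than a silent skip; the log argument accordingly changes from `bool(dataset_version.preview)` to a literal `False`, the only value reachable past the guard. A hedged usage sketch (extra keyword arguments are assumed to land in the metastore update, as before):

    # Persist warehouse-derived stats once the rows table is finalized.
    catalog.update_dataset_version_with_warehouse_info(dataset, "1.0.1")
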
@@ -1175,7 +1161,7 @@ def create_dataset_from_sources(
         project: Project | None = None,
         client_config=None,
         recursive=False,
-    ) -> DatasetRecord:
+    ) -> "DataChain":
         if not sources:
             raise ValueError("Sources needs to be non empty list")
 
@@ -1188,54 +1174,20 @@ def create_dataset_from_sources(
             if source.startswith(DATASET_PREFIX):
                 dc = read_dataset(source[len(DATASET_PREFIX) :], session=self.session)
             else:
-                dc = read_storage(source, session=self.session, recursive=recursive)
+                dc = read_storage(
+                    source,
+                    session=self.session,
+                    recursive=recursive,
+                    client_config=client_config,
+                )
 
             chains.append(dc)
 
         # create union of all dataset queries created from sources
-        dc = reduce(lambda dc1, dc2: dc1.union(dc2), chains)
-        try:
-            dc = dc.settings(project=project.name, namespace=project.namespace.name)
-            dc.save(name, query_script="")
-        except Exception as e:  # noqa: BLE001
-            try:
-                ds = self.get_dataset(
-                    name,
-                    namespace_name=project.namespace.name,
-                    project_name=project.name,
-                )
-                self.metastore.update_dataset_status(
-                    ds,
-                    DatasetStatus.FAILED,
-                    version=ds.latest_version,
-                    error_message=DATASET_INTERNAL_ERROR_MESSAGE,
-                    error_stack=traceback.format_exc(),
-                )
-                self._remove_dataset_rows_and_warehouse_info(
-                    ds,
-                    ds.latest_version,
-                    sources="\n".join(sources),
-                )
-                raise
-            except DatasetNotFoundError:
-                raise e from None
-
-        ds = self.get_dataset(
-            name,
-            namespace_name=project.namespace.name,
-            project_name=project.name,
-        )
-
-        self.update_dataset_version_with_warehouse_info(
-            ds,
-            ds.latest_version,
-            sources="\n".join(sources),
-        )
-
-        return self.get_dataset(
-            name,
-            namespace_name=project.namespace.name,
-            project_name=project.name,
+        return (
+            reduce(lambda dc1, dc2: dc1.union(dc2), chains)
+            .settings(project=project.name, namespace=project.namespace.name)
+            .save(name, sources="\n".join(sources), query_script="")
         )
 
     def get_full_dataset_name(
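
Note: the failure-handling block (marking the dataset FAILED, dropping the rows table via the removed `_remove_dataset_rows_and_warehouse_info`, re-raising) moves out of this method, leaving `save` responsible for failures, and the saved chain is returned directly instead of a re-fetched `DatasetRecord`. A hedged usage sketch (bucket URIs and the dataset name are placeholders):

    chain = catalog.create_dataset_from_sources(
        "animals",
        ["s3://bucket/cats/", "s3://bucket/dogs/"],
        project=project,
        recursive=True,
    )
    chain.show()  # the return value is the saved DataChain, not a DatasetRecord
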
@@ -1299,7 +1251,10 @@ def get_dataset(
         name: str,
         namespace_name: str | None = None,
         project_name: str | None = None,
+        *,
+        versions: Sequence[str] | None = (),
         include_incomplete: bool = True,
+        include_preview: bool = False,
     ) -> DatasetRecord:
         from datachain.lib.listing import is_listing_dataset
 
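
Note: `versions` and `include_preview` are keyword-only. Reading the call sites in this diff, the `()` default looks like a sentinel distinct from `None`: `versions=None` loads every version, `versions=[v]` narrows the underlying query to a single version, and `include_preview` controls whether preview data is loaded at all. A hedged sketch of the three call shapes (this interpretation is inferred, not confirmed by the diff):

    ds = catalog.get_dataset("animals", versions=None)           # all versions
    ds = catalog.get_dataset("animals", versions=["1.0.1"])      # single version
    ds = catalog.get_dataset("animals", include_preview=True)    # load previews too
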
@@ -1314,7 +1269,9 @@ def get_dataset(
             name,
             namespace_name=namespace_name,
             project_name=project_name,
+            versions=versions,
             include_incomplete=include_incomplete,
+            include_preview=include_preview,
         )
 
     def get_dataset_with_remote_fallback(
@@ -1348,6 +1305,7 @@ def get_dataset_with_remote_fallback(
                 name,
                 namespace_name=namespace_name,
                 project_name=project_name,
+                versions=None,
                 include_incomplete=include_incomplete,
             )
             if not version or ds.has_version(version):
@@ -1377,6 +1335,7 @@ def get_dataset_with_remote_fallback(
             name,
             namespace_name=namespace_name,
             project_name=project_name,
+            versions=None,
             include_incomplete=include_incomplete,
         )
 
@@ -1445,6 +1404,7 @@ def get_dataset_dependencies(
             name,
             namespace_name=namespace_name,
             project_name=project_name,
+            versions=[version],
             include_incomplete=False,
         )
         dataset_version = dataset.get_version(version)
@@ -1618,6 +1578,7 @@ def export_dataset_table(
             name,
             namespace_name=project.namespace.name if project else None,
             project_name=project.name if project else None,
+            versions=[version],
         )
 
         self.warehouse.export_dataset_table(
@@ -1640,6 +1601,7 @@ def remove_dataset(
             name,
             namespace_name=project.namespace.name if project else None,
             project_name=project.name if project else None,
+            versions=None,
         )
         if not version and not force:
             raise ValueError(f"Missing dataset version from input for dataset {name}")
@@ -1680,6 +1642,7 @@ def edit_dataset(
             name,
             namespace_name=project.namespace.name if project else None,
             project_name=project.name if project else None,
+            versions=None,
         )
         return self.update_dataset(dataset, **update_data)
 
@@ -1857,6 +1820,7 @@ def pull_dataset( # noqa: C901, PLR0915, PLR0912
             local_ds_name,
             namespace_name=namespace.name,
             project_name=project.name,
+            versions=None,
             include_incomplete=True,
         )
         if local_dataset.has_version(local_ds_version):