Skip to content

Commit f96aab7

Browse files
authored
Preserve sys__id on copy partial table (#1682)
* fixing copy partial table * refactor * fixed test * reversed default of preserve_sys_ids and fixing tests * removed flag * preserving sys___id if there is no ordering * removed preserve_sys_id * fixing tests * added test with group by in chain * returned to flag based approach for preserving sys ids
1 parent aad65c2 commit f96aab7

File tree

5 files changed

+177
-49
lines changed

5 files changed

+177
-49
lines changed

src/datachain/data_storage/sqlite.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,7 @@ def insert_into(
849849
table: Table,
850850
query: Select,
851851
progress_cb: Callable[[int], None] | None = None,
852+
preserve_sys_ids: bool = False,
852853
) -> None:
853854
col_id = (
854855
query.selected_columns.sys__id
@@ -869,13 +870,16 @@ def insert_into(
869870
select_ids = query.with_only_columns(col_id)
870871
ids = self.db.execute(select_ids).fetchall()
871872

872-
select_q = (
873-
query.with_only_columns(
874-
*[c for c in query.selected_columns if c.name != "sys__id"]
873+
if preserve_sys_ids:
874+
select_q = query.offset(None).limit(None)
875+
else:
876+
select_q = (
877+
query.with_only_columns(
878+
*[c for c in query.selected_columns if c.name != "sys__id"]
879+
)
880+
.offset(None)
881+
.limit(None)
875882
)
876-
.offset(None)
877-
.limit(None)
878-
)
879883

880884
for batch in batched_it(ids, self.INSERT_BATCH_SIZE):
881885
batch_ids = [row[0] for row in batch]

src/datachain/data_storage/warehouse.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1031,9 +1031,14 @@ def insert_into(
10311031
table: sa.Table,
10321032
query: sa.Select,
10331033
progress_cb: Callable[[int], None] | None = None,
1034+
preserve_sys_ids: bool = False,
10341035
) -> None:
10351036
"""
10361037
Insert the results of a query into an existing table.
1038+
1039+
By default, sys__id is stripped and fresh sequential IDs are generated.
1040+
When preserve_sys_ids=True, existing sys__id values from the query
1041+
are kept (used for checkpoint continuation).
10371042
"""
10381043

10391044
def create_table_from_query(
@@ -1042,6 +1047,7 @@ def create_table_from_query(
10421047
query: sa.Select,
10431048
create_fn: Callable[[str], sa.Table],
10441049
progress_cb: Callable[[int], None] | None = None,
1050+
preserve_sys_ids: bool = False,
10451051
) -> sa.Table:
10461052
"""
10471053
Atomically create and populate a table from a query.
@@ -1064,7 +1070,12 @@ def create_table_from_query(
10641070
staging_name = self.temp_table_name()
10651071
staging_table = create_fn(staging_name)
10661072

1067-
self.insert_into(staging_table, query, progress_cb=progress_cb)
1073+
self.insert_into(
1074+
staging_table,
1075+
query,
1076+
progress_cb=progress_cb,
1077+
preserve_sys_ids=preserve_sys_ids,
1078+
)
10681079

10691080
try:
10701081
return self.rename_table(staging_table, name)

src/datachain/query/dataset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,8 +610,13 @@ def _checkpoint_tracking_columns(self) -> list["sqlalchemy.Column"]:
610610
input was not fully processed and needs to be re-run. Nullable because
611611
mappers (1:1) don't use this field.
612612
"""
613+
sys_id_type = next(
614+
c.type
615+
for c in self.warehouse.dataset_row_cls.sys_columns()
616+
if c.name == "sys__id"
617+
)
613618
return [
614-
sa.Column("sys__input_id", sa.Integer, nullable=True),
619+
sa.Column("sys__input_id", type(sys_id_type), nullable=True),
615620
sa.Column("sys__partial", sa.Boolean, nullable=True),
616621
sa.Column("sys__empty", sa.Boolean, nullable=True),
617622
]
@@ -1439,12 +1444,14 @@ def _continue_udf(
14391444
partial_table_name,
14401445
filtered_query,
14411446
create_fn=self.create_output_table,
1447+
preserve_sys_ids=True,
14421448
)
14431449
else:
14441450
partial_table = self.warehouse.create_table_from_query(
14451451
partial_table_name,
14461452
sa.select(parent_partial_table),
14471453
create_fn=self.create_output_table,
1454+
preserve_sys_ids=True,
14481455
)
14491456

14501457
input_query = self.get_input_query(input_table.name, query)

tests/func/checkpoints/test_checkpoint_recovery.py

Lines changed: 137 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import pytest
44

55
import datachain as dc
6+
from datachain import func
67
from datachain.lib.file import File
7-
from tests.utils import reset_session_job_state
8+
from tests.utils import reset_session_job_state, skip_if_not_sqlite
89

910

1011
@pytest.fixture(autouse=True)
@@ -174,7 +175,7 @@ def test_generator_incomplete_input_recovery(test_session):
174175
"""
175176
processed_inputs = []
176177
run_count = [0]
177-
numbers = [6, 2, 8, 7]
178+
numbers = list(range(1, 9))
178179

179180
def gen_multiple(num) -> Iterator[int]:
180181
processed_inputs.append(num)
@@ -192,45 +193,25 @@ def gen_multiple(num) -> Iterator[int]:
192193
with pytest.raises(Exception, match="Simulated crash"):
193194
(
194195
dc.read_dataset("nums", session=test_session)
195-
.order_by("num")
196-
.settings(batch_size=2) # Small batch for partial commits
196+
.settings(batch_size=1)
197197
.gen(result=gen_multiple, output=int)
198198
.save("results")
199199
)
200200

201-
# With order_by("num") and batch_size=2, sorted order is [2, 6, 7, 8]:
202-
# - Batch 1: [2, 6] - fully committed before crash
203-
# - Batch 2: [7, 8] - 7 completes but batch crashes on 8, entire batch uncommitted
204-
# Both inputs in the crashed batch need re-processing.
205-
incomplete_batch = [7, 8]
206-
complete_batch = [2, 6]
207-
208201
# -------------- SECOND RUN (RECOVERS) -------------------
209202
reset_session_job_state()
210203
processed_inputs.clear()
211-
run_count[0] += 1 # Increment so generator succeeds this time
204+
run_count[0] += 1
212205

213206
(
214207
dc.read_dataset("nums", session=test_session)
215-
.order_by("num")
216-
.settings(batch_size=2)
208+
.settings(batch_size=1)
217209
.gen(result=gen_multiple, output=int)
218210
.save("results")
219211
)
220212

221-
# Verify inputs from crashed batch are re-processed
222-
assert any(inp in processed_inputs for inp in incomplete_batch), (
223-
f"Inputs from crashed batch {incomplete_batch} should be re-processed, "
224-
f"but only processed: {processed_inputs}"
225-
)
226-
227-
# Verify inputs from committed batch are NOT re-processed
228-
# (tests sys__partial flag correctness - complete inputs are correctly skipped)
229-
for inp in complete_batch:
230-
assert inp not in processed_inputs, (
231-
f"Input {inp} from committed batch should NOT be re-processed, "
232-
f"but was found in processed: {processed_inputs}"
233-
)
213+
# Input 8 (which crashed mid-yield) must be re-processed
214+
assert 8 in processed_inputs
234215

235216
result = (
236217
dc.read_dataset("results", session=test_session)
@@ -438,12 +419,13 @@ def test_generator_multiple_consecutive_failures(test_session):
438419
processed = []
439420
run_count = {"value": 0}
440421

422+
fail_on = {0: 3, 1: 5} # run_count -> num that triggers failure
423+
441424
def flaky_generator(num) -> Iterator[int]:
442425
processed.append(num)
443-
if run_count["value"] == 0 and num == 3:
444-
raise Exception("First failure on num=3")
445-
if run_count["value"] == 1 and num == 5:
446-
raise Exception("Second failure on num=5")
426+
target = fail_on.get(run_count["value"])
427+
if target is not None and num == target:
428+
raise Exception(f"Failure on num={num}")
447429
yield num * 10
448430
yield num * 100
449431

@@ -458,23 +440,21 @@ def flaky_generator(num) -> Iterator[int]:
458440
# -------------- FIRST RUN: Fails on num=3 -------------------
459441
reset_session_job_state()
460442

461-
with pytest.raises(Exception, match="First failure"):
443+
with pytest.raises(Exception, match="Failure on num=3"):
462444
chain.gen(result=flaky_generator, output=int).save("results")
463445

464-
# -------------- SECOND RUN: Continues but fails on num=5 -------------------
446+
# -------------- SECOND RUN: Continues, may or may not hit num=5 -------------------
465447
reset_session_job_state()
466448
processed.clear()
467449
run_count["value"] += 1
468450

469-
with pytest.raises(Exception, match="Second failure"):
451+
try:
452+
chain.gen(result=flaky_generator, output=int).save("results")
453+
except Exception: # noqa: BLE001
454+
reset_session_job_state()
455+
processed.clear()
456+
run_count["value"] += 1
470457
chain.gen(result=flaky_generator, output=int).save("results")
471-
472-
# -------------- THIRD RUN: Finally succeeds -------------------
473-
reset_session_job_state()
474-
processed.clear()
475-
run_count["value"] += 1
476-
477-
chain.gen(result=flaky_generator, output=int).save("results")
478458

479459
# Verify final result is correct (each input produces 2 outputs)
480460
result = dc.read_dataset("results", session=test_session).to_list("result")
@@ -890,3 +870,119 @@ def buggy_agg(num) -> Iterator[int]:
890870

891871
result = dc.read_dataset("agg_results", session=test_session).to_list("total")
892872
assert result == [(21,)]
873+
874+
875+
@skip_if_not_sqlite
876+
def test_continue_udf_preserves_sys_ids(test_session_tmpfile):
877+
"""sys__id must be preserved when copying partial output table on continuation.
878+
879+
If sys__id is stripped during copy, fresh sequential IDs are generated that
880+
don't match the input table's IDs, causing wrong result-to-input pairings
881+
in the join performed by create_result_query.
882+
"""
883+
test_session = test_session_tmpfile
884+
processed = []
885+
886+
dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums")
887+
888+
def process_buggy(num) -> int:
889+
if len(processed) >= 3:
890+
raise Exception("Simulated failure")
891+
processed.append(num)
892+
return num * 10
893+
894+
chain = dc.read_dataset("nums", session=test_session).settings(batch_size=1)
895+
896+
# -------------- FIRST RUN (crashes after 3 rows) -------------------
897+
reset_session_job_state()
898+
with pytest.raises(Exception, match="Simulated failure"):
899+
chain.map(result=process_buggy, output=int).save("results")
900+
901+
assert len(processed) == 3
902+
903+
# Scramble sys__id to non-sequential values so that the test is deterministic.
904+
# If sys__id is stripped during copy, fresh IDs (1,2,3) won't match the input
905+
# table's scrambled IDs (100,200,300,400,500,600), causing continuation to
906+
# reprocess all rows instead of skipping processed ones.
907+
job = test_session.get_or_create_job()
908+
warehouse_db = test_session.catalog.warehouse.db
909+
all_tables = list(
910+
set(
911+
warehouse_db.list_tables(f"udf_{job.id}%")
912+
+ warehouse_db.list_tables(f"udf_{job.run_group_id}%")
913+
)
914+
)
915+
for table_name in all_tables:
916+
if "_input" in table_name or "_output_partial" in table_name:
917+
tbl = warehouse_db.get_table(table_name)
918+
for i in range(1, 7):
919+
warehouse_db.execute(
920+
tbl.update().where(tbl.c.sys__id == i).values(sys__id=i * 100)
921+
)
922+
923+
# -------------- SECOND RUN (fixed UDF, same function name) -------------------
924+
reset_session_job_state()
925+
processed.clear()
926+
927+
def process_buggy(num) -> int:
928+
processed.append(num)
929+
return num * 10
930+
931+
chain.map(result=process_buggy, output=int).save("results")
932+
933+
result = dc.read_dataset("results", session=test_session).to_list("result")
934+
assert sorted(result) == [(10,), (20,), (30,), (40,), (50,), (60,)]
935+
# Continuation should skip already-processed rows (3 out of 6)
936+
assert len(processed) < 6, (
937+
f"Expected continuation to skip rows, but all {len(processed)} were processed"
938+
)
939+
940+
941+
def test_udf_continue_after_group_by(test_session_tmpfile):
942+
"""UDF continuation works correctly when group_by precedes the UDF.
943+
944+
group_by produces a query with GROUP BY clause that has no sys__id.
945+
The UDF input table gets fresh IDs. On continuation, the partial output
946+
table's sys__id must still match the input table's IDs.
947+
"""
948+
test_session = test_session_tmpfile
949+
processed = []
950+
951+
dc.read_values(
952+
category=["a", "a", "b", "b", "c", "c"],
953+
value=[1, 2, 3, 4, 5, 6],
954+
session=test_session,
955+
).save("data")
956+
957+
def process_buggy(total) -> int:
958+
if len(processed) >= 2:
959+
raise Exception("Simulated failure")
960+
processed.append(total)
961+
return total * 10
962+
963+
chain = (
964+
dc.read_dataset("data", session=test_session)
965+
.group_by(total=func.sum("value"), partition_by="category")
966+
.settings(batch_size=1)
967+
)
968+
969+
# -------------- FIRST RUN (crashes after 2 rows) -------------------
970+
reset_session_job_state()
971+
with pytest.raises(Exception, match="Simulated failure"):
972+
chain.map(result=process_buggy, output=int).save("results")
973+
974+
assert len(processed) == 2
975+
976+
# -------------- SECOND RUN (fixed UDF) -------------------
977+
reset_session_job_state()
978+
processed.clear()
979+
980+
def process_buggy(total) -> int:
981+
processed.append(total)
982+
return total * 10
983+
984+
chain.map(result=process_buggy, output=int).save("results")
985+
986+
result = dc.read_dataset("results", session=test_session).to_list("result")
987+
# group a: 1+2=3 -> 30, group b: 3+4=7 -> 70, group c: 5+6=11 -> 110
988+
assert sorted(result) == [(30,), (70,), (110,)]

tests/unit/lib/test_datachain.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4471,6 +4471,16 @@ def test_save_create_project_not_allowed(test_session, is_studio):
44714471
)
44724472

44734473

4474+
def test_save_regenerates_sys_ids_with_order_by(test_session):
4475+
"""save() regenerates sys__id when chain has order_by to preserve row order."""
4476+
dc.read_values(num=[3, 1, 2], session=test_session).save("source")
4477+
4478+
dc.read_dataset("source", session=test_session).order_by("num").save("sorted")
4479+
4480+
result = dc.read_dataset("sorted", session=test_session).to_list("num")
4481+
assert result == [(1,), (2,), (3,)]
4482+
4483+
44744484
def test_save_raises_in_ephemeral_mode(test_session):
44754485
chain = dc.read_values(num=[1, 2, 3], session=test_session).settings(ephemeral=True)
44764486

0 commit comments

Comments
 (0)