33import pytest
44
55import datachain as dc
6+ from datachain import func
67from datachain .lib .file import File
7- from tests .utils import reset_session_job_state
8+ from tests .utils import reset_session_job_state , skip_if_not_sqlite
89
910
1011@pytest .fixture (autouse = True )
@@ -174,7 +175,7 @@ def test_generator_incomplete_input_recovery(test_session):
174175 """
175176 processed_inputs = []
176177 run_count = [0 ]
177- numbers = [ 6 , 2 , 8 , 7 ]
178+ numbers = list ( range ( 1 , 9 ))
178179
179180 def gen_multiple (num ) -> Iterator [int ]:
180181 processed_inputs .append (num )
@@ -192,45 +193,25 @@ def gen_multiple(num) -> Iterator[int]:
192193 with pytest .raises (Exception , match = "Simulated crash" ):
193194 (
194195 dc .read_dataset ("nums" , session = test_session )
195- .order_by ("num" )
196- .settings (batch_size = 2 ) # Small batch for partial commits
196+ .settings (batch_size = 1 )
197197 .gen (result = gen_multiple , output = int )
198198 .save ("results" )
199199 )
200200
201- # With order_by("num") and batch_size=2, sorted order is [2, 6, 7, 8]:
202- # - Batch 1: [2, 6] - fully committed before crash
203- # - Batch 2: [7, 8] - 7 completes but batch crashes on 8, entire batch uncommitted
204- # Both inputs in the crashed batch need re-processing.
205- incomplete_batch = [7 , 8 ]
206- complete_batch = [2 , 6 ]
207-
208201 # -------------- SECOND RUN (RECOVERS) -------------------
209202 reset_session_job_state ()
210203 processed_inputs .clear ()
211- run_count [0 ] += 1 # Increment so generator succeeds this time
204+ run_count [0 ] += 1
212205
213206 (
214207 dc .read_dataset ("nums" , session = test_session )
215- .order_by ("num" )
216- .settings (batch_size = 2 )
208+ .settings (batch_size = 1 )
217209 .gen (result = gen_multiple , output = int )
218210 .save ("results" )
219211 )
220212
221- # Verify inputs from crashed batch are re-processed
222- assert any (inp in processed_inputs for inp in incomplete_batch ), (
223- f"Inputs from crashed batch { incomplete_batch } should be re-processed, "
224- f"but only processed: { processed_inputs } "
225- )
226-
227- # Verify inputs from committed batch are NOT re-processed
228- # (tests sys__partial flag correctness - complete inputs are correctly skipped)
229- for inp in complete_batch :
230- assert inp not in processed_inputs , (
231- f"Input { inp } from committed batch should NOT be re-processed, "
232- f"but was found in processed: { processed_inputs } "
233- )
213+ # Input 8 (which crashed mid-yield) must be re-processed
214+ assert 8 in processed_inputs
234215
235216 result = (
236217 dc .read_dataset ("results" , session = test_session )
@@ -438,12 +419,13 @@ def test_generator_multiple_consecutive_failures(test_session):
438419 processed = []
439420 run_count = {"value" : 0 }
440421
422+ fail_on = {0 : 3 , 1 : 5 } # run_count -> num that triggers failure
423+
441424 def flaky_generator (num ) -> Iterator [int ]:
442425 processed .append (num )
443- if run_count ["value" ] == 0 and num == 3 :
444- raise Exception ("First failure on num=3" )
445- if run_count ["value" ] == 1 and num == 5 :
446- raise Exception ("Second failure on num=5" )
426+ target = fail_on .get (run_count ["value" ])
427+ if target is not None and num == target :
428+ raise Exception (f"Failure on num={ num } " )
447429 yield num * 10
448430 yield num * 100
449431
@@ -458,23 +440,21 @@ def flaky_generator(num) -> Iterator[int]:
458440 # -------------- FIRST RUN: Fails on num=3 -------------------
459441 reset_session_job_state ()
460442
461- with pytest .raises (Exception , match = "First failure " ):
443+ with pytest .raises (Exception , match = "Failure on num=3 " ):
462444 chain .gen (result = flaky_generator , output = int ).save ("results" )
463445
464- # -------------- SECOND RUN: Continues but fails on num=5 -------------------
446+ # -------------- SECOND RUN: Continues, may or may not hit num=5 -------------------
465447 reset_session_job_state ()
466448 processed .clear ()
467449 run_count ["value" ] += 1
468450
469- with pytest .raises (Exception , match = "Second failure" ):
451+ try :
452+ chain .gen (result = flaky_generator , output = int ).save ("results" )
453+ except Exception : # noqa: BLE001
454+ reset_session_job_state ()
455+ processed .clear ()
456+ run_count ["value" ] += 1
470457 chain .gen (result = flaky_generator , output = int ).save ("results" )
471-
472- # -------------- THIRD RUN: Finally succeeds -------------------
473- reset_session_job_state ()
474- processed .clear ()
475- run_count ["value" ] += 1
476-
477- chain .gen (result = flaky_generator , output = int ).save ("results" )
478458
479459 # Verify final result is correct (each input produces 2 outputs)
480460 result = dc .read_dataset ("results" , session = test_session ).to_list ("result" )
@@ -890,3 +870,119 @@ def buggy_agg(num) -> Iterator[int]:
890870
891871 result = dc .read_dataset ("agg_results" , session = test_session ).to_list ("total" )
892872 assert result == [(21 ,)]
873+
874+
@skip_if_not_sqlite
def test_continue_udf_preserves_sys_ids(test_session_tmpfile):
    """sys__id must be preserved when copying partial output table on continuation.

    If sys__id is stripped during copy, fresh sequential IDs are generated that
    don't match the input table's IDs, causing wrong result-to-input pairings
    in the join performed by create_result_query.
    """
    test_session = test_session_tmpfile
    processed = []

    dc.read_values(num=[1, 2, 3, 4, 5, 6], session=test_session).save("nums")

    def process_buggy(num) -> int:
        # Crash once three rows have been handled, leaving a partial output table.
        if len(processed) >= 3:
            raise Exception("Simulated failure")
        processed.append(num)
        return num * 10

    chain = dc.read_dataset("nums", session=test_session).settings(batch_size=1)

    # -------------- FIRST RUN (crashes after 3 rows) -------------------
    reset_session_job_state()
    with pytest.raises(Exception, match="Simulated failure"):
        chain.map(result=process_buggy, output=int).save("results")

    assert len(processed) == 3

    # Scramble sys__id to non-sequential values so that the test is deterministic.
    # If sys__id is stripped during copy, fresh IDs (1,2,3) won't match the input
    # table's scrambled IDs (100,200,300,400,500,600), causing continuation to
    # reprocess all rows instead of skipping processed ones.
    job = test_session.get_or_create_job()
    db = test_session.catalog.warehouse.db
    udf_tables = set(db.list_tables(f"udf_{job.id}%")) | set(
        db.list_tables(f"udf_{job.run_group_id}%")
    )
    for name in udf_tables:
        if "_input" not in name and "_output_partial" not in name:
            continue
        tbl = db.get_table(name)
        for old_id in range(1, 7):
            db.execute(
                tbl.update()
                .where(tbl.c.sys__id == old_id)
                .values(sys__id=old_id * 100)
            )

    # -------------- SECOND RUN (fixed UDF, same function name) -------------------
    reset_session_job_state()
    processed.clear()

    def process_buggy(num) -> int:
        # Fixed version: never raises, so continuation can finish the dataset.
        processed.append(num)
        return num * 10

    chain.map(result=process_buggy, output=int).save("results")

    result = dc.read_dataset("results", session=test_session).to_list("result")
    assert sorted(result) == [(10,), (20,), (30,), (40,), (50,), (60,)]
    # Continuation should skip already-processed rows (3 out of 6)
    assert len(processed) < 6, (
        f"Expected continuation to skip rows, but all {len(processed)} were processed"
    )
939+
940+
def test_udf_continue_after_group_by(test_session_tmpfile):
    """UDF continuation works correctly when group_by precedes the UDF.

    group_by produces a query with GROUP BY clause that has no sys__id.
    The UDF input table gets fresh IDs. On continuation, the partial output
    table's sys__id must still match the input table's IDs.
    """
    test_session = test_session_tmpfile
    processed = []

    dc.read_values(
        category=["a", "a", "b", "b", "c", "c"],
        value=[1, 2, 3, 4, 5, 6],
        session=test_session,
    ).save("data")

    def process_buggy(total) -> int:
        # Crash after two groups are processed, leaving a partial output table.
        if len(processed) < 2:
            processed.append(total)
            return total * 10
        raise Exception("Simulated failure")

    chain = (
        dc.read_dataset("data", session=test_session)
        .group_by(total=func.sum("value"), partition_by="category")
        .settings(batch_size=1)
    )

    # -------------- FIRST RUN (crashes after 2 rows) -------------------
    reset_session_job_state()
    with pytest.raises(Exception, match="Simulated failure"):
        chain.map(result=process_buggy, output=int).save("results")

    assert len(processed) == 2

    # -------------- SECOND RUN (fixed UDF) -------------------
    reset_session_job_state()
    processed.clear()

    def process_buggy(total) -> int:
        # Fixed version: always succeeds.
        processed.append(total)
        return total * 10

    chain.map(result=process_buggy, output=int).save("results")

    result = dc.read_dataset("results", session=test_session).to_list("result")
    # group a: 1+2=3 -> 30, group b: 3+4=7 -> 70, group c: 5+6=11 -> 110
    assert sorted(result) == [(30,), (70,), (110,)]
0 commit comments