@@ -38,5 +38,306 @@ TEST(BifurcationDetectorTest, Test2) {
3838 tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
3939}
4040
41+ // Bifurcation at the first predicted token (immediate mismatch).
42+ // pred_tokens[0] != src_tokens[prev_suffix_match_idx] → pred_bifur_idx = 0.
43+ // Output = cur_tokens + pred_tokens[0].
44+ TEST (BifurcationDetectorTest, BifurcationAtFirstToken) {
45+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
46+
47+ // src=[1,2,3], prev_idx=0, pred=[99,2,3,0] (pred[0]=99 != src[0]=1).
48+ tester.AddInput <int64_t >(" src_tokens" , {3 }, {1 , 2 , 3 });
49+ tester.AddInput <int64_t >(" cur_tokens" , {2 }, {10 , 20 });
50+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {0 });
51+ tester.AddInput <int64_t >(" pred_tokens" , {4 }, {99 , 2 , 3 , 0 });
52+ // pred_bifur_idx = 0, output = [10, 20] + [99] = [10, 20, 99]
53+ tester.AddOutput <int64_t >(" tokens" , {3 }, {10 , 20 , 99 });
54+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {-1 });
55+
56+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
57+ execution_providers.push_back (DefaultCpuExecutionProvider ());
58+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
59+ }
60+
61+ // Bifurcation in the middle of the predicted sequence.
62+ TEST (BifurcationDetectorTest, BifurcationMidSequence) {
63+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
64+
65+ // src=[10,20,30,40], prev_idx=0, pred=[10,20,99,40,0].
66+ // Match at pred[0]=10==src[0], pred[1]=20==src[1], pred[2]=99!=src[2]=30.
67+ // pred_bifur_idx = 2. Output = cur + pred[0..2] = [5] + [10,20,99].
68+ tester.AddInput <int64_t >(" src_tokens" , {4 }, {10 , 20 , 30 , 40 });
69+ tester.AddInput <int64_t >(" cur_tokens" , {1 }, {5 });
70+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {0 });
71+ tester.AddInput <int64_t >(" pred_tokens" , {5 }, {10 , 20 , 99 , 40 , 0 });
72+ tester.AddOutput <int64_t >(" tokens" , {4 }, {5 , 10 , 20 , 99 });
73+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {-1 });
74+
75+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
76+ execution_providers.push_back (DefaultCpuExecutionProvider ());
77+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
78+ }
79+
80+ // Non-zero prev_suffix_match_idx with pred_tokens: bifurcation scan starts
81+ // partway through src_tokens.
82+ TEST (BifurcationDetectorTest, NonZeroPrevSuffixMatchIdx) {
83+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
84+
85+ // src=[10,20,30,40,50], prev_idx=2.
86+ // pred_tokens_len must be 5 + 1 - 2 = 4.
87+ // Compare: pred[0] vs src[2]=30, pred[1] vs src[3]=40, pred[2] vs src[4]=50.
88+ // pred=[30,40,99,0] → match at 0,1; mismatch at 2. pred_bifur_idx=2.
89+ // Output = [5] + [30,40,99].
90+ tester.AddInput <int64_t >(" src_tokens" , {5 }, {10 , 20 , 30 , 40 , 50 });
91+ tester.AddInput <int64_t >(" cur_tokens" , {1 }, {5 });
92+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {2 });
93+ tester.AddInput <int64_t >(" pred_tokens" , {4 }, {30 , 40 , 99 , 0 });
94+ tester.AddOutput <int64_t >(" tokens" , {4 }, {5 , 30 , 40 , 99 });
95+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {-1 });
96+
97+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
98+ execution_providers.push_back (DefaultCpuExecutionProvider ());
99+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
100+ }
101+
102+ // Suffix matching: multiple occurrences of the 1-gram cause suffix_idx = -1,
103+ // then the 2-gram is unique → suffix_idx reports the 2-gram match position.
104+ TEST (BifurcationDetectorTest, SuffixMatchMultipleSingleGramUniqueDigram) {
105+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
106+
107+ // src=[1,3,4,2,1,4,0], cur=[5,1,4]. No pred → output = [5,1,4].
108+ // 1-gram [4]: found at src[2] and src[5] → multiple → -1, continue.
109+ // 2-gram [1,4]: found at src[4..5]. suffix_idx=4+2=6. No second match → unique.
110+ tester.AddInput <int64_t >(" src_tokens" , {7 }, {1 , 3 , 4 , 2 , 1 , 4 , 0 });
111+ tester.AddInput <int64_t >(" cur_tokens" , {3 }, {5 , 1 , 4 });
112+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {0 });
113+ // No pred_tokens → output = cur_tokens = [5, 1, 4].
114+ tester.AddOutput <int64_t >(" tokens" , {3 }, {5 , 1 , 4 });
115+ // 1-gram [4]: multiple matches → -1, continue.
116+ // 2-gram [1,4]: unique match at src[4..5], suffix_idx = 4+2 = 6.
117+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {6 });
118+
119+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
120+ execution_providers.push_back (DefaultCpuExecutionProvider ());
121+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
122+ }
123+
124+ // Suffix matching: suffix_idx >= src_tokens_len causes an early break after assignment,
125+ // so this edge case returns the assigned suffix_idx, not -1.
126+ TEST (BifurcationDetectorTest, SuffixMatchAtEndOfSrc) {
127+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
128+
129+ // src=[1,2,3], cur=[5,3].
130+ // 1-gram: [3]. Found at src[2]. suffix_idx = 2+1 = 3 >= 3 → break.
131+ // suffix_idx was already assigned 3 before the break, so the result is 3.
132+ tester.AddInput <int64_t >(" src_tokens" , {3 }, {1 , 2 , 3 });
133+ tester.AddInput <int64_t >(" cur_tokens" , {2 }, {5 , 3 });
134+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {0 });
135+ tester.AddOutput <int64_t >(" tokens" , {2 }, {5 , 3 });
136+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {3 });
137+
138+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
139+ execution_providers.push_back (DefaultCpuExecutionProvider ());
140+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
141+ }
142+
143+ // Suffix matching: n-gram size exceeds output token count → early break.
144+ TEST (BifurcationDetectorTest, SuffixNgramExceedsOutputLen) {
145+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
146+ tester.AddAttribute <int64_t >(" min_ngram_size" , int64_t (5 ));
147+ tester.AddAttribute <int64_t >(" max_ngram_size" , int64_t (7 ));
148+
149+ // Output has only 2 tokens, but min_ngram_size=5. The loop immediately breaks
150+ // because i=5 > tokens_len=2. suffix_idx stays -1.
151+ tester.AddInput <int64_t >(" src_tokens" , {10 }, {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 });
152+ tester.AddInput <int64_t >(" cur_tokens" , {2 }, {5 , 3 });
153+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {0 });
154+ tester.AddOutput <int64_t >(" tokens" , {2 }, {5 , 3 });
155+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {-1 });
156+
157+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
158+ execution_providers.push_back (DefaultCpuExecutionProvider ());
159+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
160+ }
161+
162+ // Custom min/max_ngram_size: min=2, max=2. Only 2-grams are checked.
163+ TEST (BifurcationDetectorTest, CustomNgramSize) {
164+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
165+ tester.AddAttribute <int64_t >(" min_ngram_size" , int64_t (2 ));
166+ tester.AddAttribute <int64_t >(" max_ngram_size" , int64_t (2 ));
167+
168+ // src=[1,2,3,4,5], cur=[7,3,4].
169+ // With default min=1: 1-gram [4] found at src[3], suffix_idx=4, unique → return 4.
170+ // With min=max=2: only 2-gram [3,4] is checked. Found at src[2..3], suffix_idx=2+2=4. unique → return 4.
171+ // Same result here but exercises the attribute path.
172+ tester.AddInput <int64_t >(" src_tokens" , {5 }, {1 , 2 , 3 , 4 , 5 });
173+ tester.AddInput <int64_t >(" cur_tokens" , {3 }, {7 , 3 , 4 });
174+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {0 });
175+ tester.AddOutput <int64_t >(" tokens" , {3 }, {7 , 3 , 4 });
176+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {4 });
177+
178+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
179+ execution_providers.push_back (DefaultCpuExecutionProvider ());
180+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
181+ }
182+
183+ // Combined: non-zero prev_suffix_match_idx with pred_tokens AND suffix match.
184+ // Exercises both major code paths together.
185+ TEST (BifurcationDetectorTest, BifurcationAndSuffixMatchCombined) {
186+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
187+
188+ // src=[10,20,30,40,50,60], prev_idx=3.
189+ // pred_tokens_len = 6 + 1 - 3 = 4.
190+ // Compare pred vs src starting at offset 3: pred[0] vs src[3]=40, pred[1] vs src[4]=50, pred[2] vs src[5]=60.
191+ // pred=[40,50,99,0]. Match at 0,1; mismatch at 2. pred_bifur_idx=2.
192+ // Output = cur + pred[0..2] = [5, 8] + [40, 50, 99] = [5, 8, 40, 50, 99].
193+ //
194+ // Suffix matching on output=[5,8,40,50,99] against src=[10,20,30,40,50,60]:
195+ // 1-gram: [99]. Not in src → break. suffix_idx=-1.
196+ tester.AddInput <int64_t >(" src_tokens" , {6 }, {10 , 20 , 30 , 40 , 50 , 60 });
197+ tester.AddInput <int64_t >(" cur_tokens" , {2 }, {5 , 8 });
198+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {3 });
199+ tester.AddInput <int64_t >(" pred_tokens" , {4 }, {40 , 50 , 99 , 0 });
200+ tester.AddOutput <int64_t >(" tokens" , {5 }, {5 , 8 , 40 , 50 , 99 });
201+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {-1 });
202+
203+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
204+ execution_providers.push_back (DefaultCpuExecutionProvider ());
205+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
206+ }
207+
208+ // Verify that a negative prev_suffix_match_idx is rejected.
209+ TEST (BifurcationDetectorTest, NegativePrevSuffixMatchIdx) {
210+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
211+
212+ // src_tokens has 4 elements. With prev_suffix_match_idx = -1,
213+ // pred_tokens_len must satisfy: pred_tokens_len == src_tokens_len + 1 - (-1) = 6
214+ // The negative index must be caught before it is used as an array offset.
215+ tester.AddInput <int64_t >(" src_tokens" , {4 }, {1 , 5 , 3 , 4 });
216+ tester.AddInput <int64_t >(" cur_tokens" , {1 }, {2 });
217+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {-1 });
218+ tester.AddInput <int64_t >(" pred_tokens" , {6 }, {1 , 5 , 3 , 4 , 2 , 7 });
219+ tester.AddOutput <int64_t >(" tokens" , {1 }, {0 });
220+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {0 });
221+
222+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
223+ execution_providers.push_back (DefaultCpuExecutionProvider ());
224+ tester.Run (OpTester::ExpectResult::kExpectFailure ,
225+ " prev_suffix_match_idx must be non-negative" ,
226+ {}, nullptr , &execution_providers);
227+ }
228+
229+ // Verify that a large negative prev_suffix_match_idx is also rejected.
230+ TEST (BifurcationDetectorTest, LargeNegativePrevSuffixMatchIdx) {
231+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
232+
233+ tester.AddInput <int64_t >(" src_tokens" , {4 }, {1 , 5 , 3 , 4 });
234+ tester.AddInput <int64_t >(" cur_tokens" , {1 }, {2 });
235+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {-100 });
236+ tester.AddInput <int64_t >(" pred_tokens" , {105 }, std::vector<int64_t >(105 , 0 ));
237+ tester.AddOutput <int64_t >(" tokens" , {1 }, {0 });
238+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {0 });
239+
240+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
241+ execution_providers.push_back (DefaultCpuExecutionProvider ());
242+ tester.Run (OpTester::ExpectResult::kExpectFailure ,
243+ " prev_suffix_match_idx must be non-negative" ,
244+ {}, nullptr , &execution_providers);
245+ }
246+
247+ // Verify prev_suffix_match_idx exceeding src_tokens_len is rejected.
248+ TEST (BifurcationDetectorTest, PrevSuffixMatchIdxExceedsSrcLen) {
249+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
250+
251+ // src_tokens_len = 4, prev_suffix_match_idx = 5 should fail the upper-bound check.
252+ tester.AddInput <int64_t >(" src_tokens" , {4 }, {1 , 5 , 3 , 4 });
253+ tester.AddInput <int64_t >(" cur_tokens" , {1 }, {2 });
254+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {5 });
255+ tester.AddInput <int64_t >(" pred_tokens" , {1 }, {7 });
256+ tester.AddOutput <int64_t >(" tokens" , {1 }, {0 });
257+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {0 });
258+
259+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
260+ execution_providers.push_back (DefaultCpuExecutionProvider ());
261+ tester.Run (OpTester::ExpectResult::kExpectFailure ,
262+ " prev_suffix_match_idx must not exceed src_tokens length" ,
263+ {}, nullptr , &execution_providers);
264+ }
265+
266+ // No pred_tokens — output should equal cur_tokens.
267+ TEST (BifurcationDetectorTest, NoPredTokens) {
268+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
269+
270+ tester.AddInput <int64_t >(" src_tokens" , {4 }, {1 , 5 , 3 , 4 });
271+ tester.AddInput <int64_t >(" cur_tokens" , {3 }, {10 , 20 , 30 });
272+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {0 });
273+ tester.AddOutput <int64_t >(" tokens" , {3 }, {10 , 20 , 30 });
274+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {-1 });
275+
276+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
277+ execution_providers.push_back (DefaultCpuExecutionProvider ());
278+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
279+ }
280+
281+ // prev_suffix_match_idx at the boundary (equal to src_tokens_len).
282+ TEST (BifurcationDetectorTest, PrevSuffixMatchIdxAtBoundary) {
283+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
284+
285+ // prev_suffix_match_idx = 4 = src_tokens_len.
286+ // pred_tokens_len must be src_tokens_len + 1 - 4 = 1.
287+ // Loop doesn't execute (bound = 0), pred_bifur_idx = 0.
288+ // Output = cur_tokens + pred_tokens[0..0].
289+ tester.AddInput <int64_t >(" src_tokens" , {4 }, {1 , 5 , 3 , 4 });
290+ tester.AddInput <int64_t >(" cur_tokens" , {2 }, {10 , 20 });
291+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {4 });
292+ tester.AddInput <int64_t >(" pred_tokens" , {1 }, {99 });
293+ tester.AddOutput <int64_t >(" tokens" , {3 }, {10 , 20 , 99 });
294+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {-1 });
295+
296+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
297+ execution_providers.push_back (DefaultCpuExecutionProvider ());
298+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
299+ }
300+
301+ // All predicted tokens match source tokens — no bifurcation occurs.
302+ // pred_bifur_idx reaches the loop bound (src_tokens_len - prev_suffix_match_idx).
303+ // Output = cur_tokens + all pred_tokens.
304+ TEST (BifurcationDetectorTest, FullMatchNoBifurcation) {
305+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
306+
307+ // src=[10,20,30], prev_idx=0, pred must have len = 3+1-0 = 4.
308+ // pred=[10,20,30,99]. Loop: pred[0]==src[0], pred[1]==src[1], pred[2]==src[2].
309+ // pred_bifur_idx = 3 (loop bound). Output = [5] + pred[0..3] = [5,10,20,30,99].
310+ tester.AddInput <int64_t >(" src_tokens" , {3 }, {10 , 20 , 30 });
311+ tester.AddInput <int64_t >(" cur_tokens" , {1 }, {5 });
312+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {0 });
313+ tester.AddInput <int64_t >(" pred_tokens" , {4 }, {10 , 20 , 30 , 99 });
314+ tester.AddOutput <int64_t >(" tokens" , {5 }, {5 , 10 , 20 , 30 , 99 });
315+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {-1 });
316+
317+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
318+ execution_providers.push_back (DefaultCpuExecutionProvider ());
319+ tester.Run (OpTester::ExpectResult::kExpectSuccess , " " , {}, nullptr , &execution_providers);
320+ }
321+
322+ // pred_tokens length does not match the expected (src_tokens_len + 1 - prev_suffix_match_idx).
323+ TEST (BifurcationDetectorTest, PredTokensLengthMismatch) {
324+ OpTester tester (" BifurcationDetector" , 1 , onnxruntime::kMSDomain );
325+
326+ // src_tokens_len=4, prev_suffix_match_idx=0 → expected pred_tokens_len = 5.
327+ // Provide pred_tokens_len = 3 to trigger the mismatch check.
328+ tester.AddInput <int64_t >(" src_tokens" , {4 }, {1 , 5 , 3 , 4 });
329+ tester.AddInput <int64_t >(" cur_tokens" , {1 }, {2 });
330+ tester.AddInput <int64_t >(" prev_suffix_match_idx" , {}, {0 });
331+ tester.AddInput <int64_t >(" pred_tokens" , {3 }, {1 , 5 , 3 });
332+ tester.AddOutput <int64_t >(" tokens" , {1 }, {0 });
333+ tester.AddOutput <int64_t >(" suffix_match_idx" , {}, {0 });
334+
335+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
336+ execution_providers.push_back (DefaultCpuExecutionProvider ());
337+ tester.Run (OpTester::ExpectResult::kExpectFailure ,
338+ " pred_tokens length mismatch" ,
339+ {}, nullptr , &execution_providers);
340+ }
341+
41342} // namespace test
42343} // namespace onnxruntime
0 commit comments