Skip to content

Commit 3b8482c

Browse files
authored
Added a dimension check in MatMulComputeHelper to ensure the inner dimension K_ matches for 1D vector inputs (#28053)
This pull request improves the handling and testing of 1D vector dot products in the MatMul implementation. It adds stricter validation for the K dimension in 1D vector cases and introduces new regression tests to ensure correct behavior for both valid and invalid input shapes. **Validation Enhancement:** * Added a check in `MatMulComputeHelper` to ensure that when both inputs are 1D vectors, their K dimensions match, and to provide a clear error message if they do not. **Testing Improvements:** * Added new tests in `matmul_integer_test.cc` to verify: - Correct computation for 1D vector dot products with matching K dimensions, for both `uint8_t` and `int8_t` types. - Proper failure and error messaging when 1D vectors with mismatched K dimensions are provided, for both `uint8_t` and `int8_t` types.
1 parent b7804b0 commit 3b8482c

File tree

2 files changed

+248
-0
lines changed

2 files changed

+248
-0
lines changed

onnxruntime/core/providers/cpu/math/matmul_helper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ class MatMulComputeHelper {
168168
if (num_output_dims == 0) {
169169
// for left and right being both vector, output is scalar thus no shape
170170
ORT_RETURN_IF_NOT(M_ == 1 && N_ == 1, "M_ == 1 && N_ == 1 was false");
171+
ORT_RETURN_IF_NOT(K_ == right_shape[0],
172+
"MatMul dimension mismatch. Left vector K (",
173+
K_, ") != right vector K (", right_shape[0], ")");
171174
} else {
172175
if (left_num_dims == 1) {
173176
ORT_RETURN_IF_NOT(num_dims_with_pad - 1 == num_output_dims, "num_dims_with_pad - 1 != num_output_dims");

onnxruntime/test/providers/cpu/math/matmul_integer_test.cc

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,5 +504,250 @@ TEST(MatmulIntegerOpTest, SharedPrepackedWeights) {
504504
}
505505
#endif
506506

507+
// Regression test: 1D vector dot product with matching K dimension should succeed.
508+
// A=[K], B=[K] -> scalar output (dot product).
509+
TEST(MatmulIntegerOpTest, MatMulInteger_1D_Vector_DotProduct) {
510+
OpTester test("MatMulInteger", 10);
511+
test.AddInput<uint8_t>("T1", {4}, {1, 2, 3, 4});
512+
test.AddInput<uint8_t>("T2", {4}, {5, 6, 7, 8});
513+
test.AddInput<uint8_t>("a_zero_point", {}, {0});
514+
test.AddInput<uint8_t>("b_zero_point", {}, {0});
515+
// dot product: 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
516+
test.AddOutput<int32_t>("T3", {}, {70});
517+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
518+
}
519+
520+
// Same 1D vector dot product test with int8_t types.
521+
TEST(MatmulIntegerOpTest, MatMulInteger_1D_Vector_DotProduct_int8) {
522+
OpTester test("MatMulInteger", 10);
523+
test.AddInput<int8_t>("T1", {3}, {1, -2, 3});
524+
test.AddInput<int8_t>("T2", {3}, {4, 5, -6});
525+
test.AddInput<int8_t>("a_zero_point", {}, {0});
526+
test.AddInput<int8_t>("b_zero_point", {}, {0});
527+
// dot product: 1*4 + (-2)*5 + 3*(-6) = 4 - 10 - 18 = -24
528+
test.AddOutput<int32_t>("T3", {}, {-24});
529+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
530+
}
531+
532+
// Regression test: 1D vectors with mismatched K dimension must fail safely.
533+
// Covers prior invalid-shape handling for A=[K], B=[1] where K > 1.
534+
TEST(MatmulIntegerOpTest, MatMulInteger_1D_Vector_KDimensionMismatch) {
535+
OpTester test("MatMulInteger", 10);
536+
test.AddShapeToTensorData(false);
537+
test.AddInput<uint8_t>("T1", {4}, {1, 1, 1, 1});
538+
test.AddInput<uint8_t>("T2", {1}, {5});
539+
test.AddInput<uint8_t>("a_zero_point", {}, {0});
540+
test.AddInput<uint8_t>("b_zero_point", {}, {0});
541+
test.AddOutput<int32_t>("T3", {}, {0});
542+
test.Run(OpTester::ExpectResult::kExpectFailure, "MatMul dimension mismatch", {kDmlExecutionProvider});
543+
}
544+
545+
// Same regression test with int8_t types.
546+
TEST(MatmulIntegerOpTest, MatMulInteger_int8_1D_Vector_KDimensionMismatch) {
547+
OpTester test("MatMulInteger", 10);
548+
test.AddShapeToTensorData(false);
549+
test.AddInput<int8_t>("T1", {8}, {1, 1, 1, 1, 1, 1, 1, 1});
550+
test.AddInput<int8_t>("T2", {1}, {5});
551+
test.AddInput<int8_t>("a_zero_point", {}, {0});
552+
test.AddInput<int8_t>("b_zero_point", {}, {0});
553+
test.AddOutput<int32_t>("T3", {}, {0});
554+
test.Run(OpTester::ExpectResult::kExpectFailure, "MatMul dimension mismatch", {kDmlExecutionProvider});
555+
}
556+
557+
// 1D dot product: uint8 A x int8 B (mixed types).
558+
TEST(MatmulIntegerOpTest, MatMulInteger_1D_Vector_DotProduct_uint8_int8) {
559+
OpTester test("MatMulInteger", 10);
560+
test.AddInput<uint8_t>("T1", {3}, {10, 20, 30});
561+
test.AddInput<int8_t>("T2", {3}, {1, -1, 2});
562+
test.AddInput<uint8_t>("a_zero_point", {}, {0});
563+
test.AddInput<int8_t>("b_zero_point", {}, {0});
564+
// dot product: 10*1 + 20*(-1) + 30*2 = 10 - 20 + 60 = 50
565+
test.AddOutput<int32_t>("T3", {}, {50});
566+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
567+
}
568+
569+
// 1D dot product: int8 A x uint8 B (mixed types).
570+
TEST(MatmulIntegerOpTest, MatMulInteger_1D_Vector_DotProduct_int8_uint8) {
571+
OpTester test("MatMulInteger", 10);
572+
test.AddInput<int8_t>("T1", {3}, {-1, 2, -3});
573+
test.AddInput<uint8_t>("T2", {3}, {10, 20, 30});
574+
test.AddInput<int8_t>("a_zero_point", {}, {0});
575+
test.AddInput<uint8_t>("b_zero_point", {}, {0});
576+
// dot product: (-1)*10 + 2*20 + (-3)*30 = -10 + 40 - 90 = -60
577+
test.AddOutput<int32_t>("T3", {}, {-60});
578+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
579+
}
580+
581+
// int8 A x uint8 B, 2D with zero points.
582+
TEST(MatmulIntegerOpTest, MatMulInteger_int8_uint8_2D_WithZeroPoints) {
583+
OpTester test("MatMulInteger", 10);
584+
test.AddInput<int8_t>("T1",
585+
{2, 3},
586+
{-3, 7, 5,
587+
4, -5, 8});
588+
test.AddInput<uint8_t>("T2",
589+
{3, 2},
590+
{15, 22,
591+
17, 12,
592+
19, 16});
593+
test.AddInput<int8_t>("a_zero_point", {}, {2});
594+
test.AddInput<uint8_t>("b_zero_point", {}, {10});
595+
// A_offset = A - 2: {-5, 5, 3, 2, -7, 6}
596+
// B_offset = B - 10: {5, 12, 7, 2, 9, 6}
597+
// C[0,0] = (-5)*5 + 5*7 + 3*9 = -25 + 35 + 27 = 37
598+
// C[0,1] = (-5)*12 + 5*2 + 3*6 = -60 + 10 + 18 = -32
599+
// C[1,0] = 2*5 + (-7)*7 + 6*9 = 10 - 49 + 54 = 15
600+
// C[1,1] = 2*12 + (-7)*2 + 6*6 = 24 - 14 + 36 = 46
601+
test.AddOutput<int32_t>("T3",
602+
{2, 2},
603+
{37, -32,
604+
15, 46});
605+
test.Run();
606+
}
607+
608+
// int8 A x uint8 B, A=ND B=ND with broadcasting.
609+
TEST(MatmulIntegerOpTest, MatMulInteger_int8_uint8_A_ND_B_ND) {
610+
OpTester test("MatMulInteger", 10);
611+
test.AddInput<int8_t>("T1",
612+
{2, 2, 3},
613+
{1, -2, 3,
614+
-1, 4, -2,
615+
616+
1, -2, 3,
617+
-1, 4, -2});
618+
test.AddInput<uint8_t>("T2",
619+
{2, 3, 2},
620+
{10, 20,
621+
30, 40,
622+
50, 60,
623+
624+
10, 20,
625+
30, 40,
626+
50, 60});
627+
test.AddInput<int8_t>("a_zero_point", {}, {0});
628+
test.AddInput<uint8_t>("b_zero_point", {}, {0});
629+
// batch 0: [[1,-2,3],[−1,4,−2]] x [[10,20],[30,40],[50,60]]
630+
// [0,0] = 1*10+(-2)*30+3*50 = 10-60+150 = 100
631+
// [0,1] = 1*20+(-2)*40+3*60 = 20-80+180 = 120
632+
// [1,0] = (-1)*10+4*30+(-2)*50 = -10+120-100 = 10
633+
// [1,1] = (-1)*20+4*40+(-2)*60 = -20+160-120 = 20
634+
test.AddOutput<int32_t>("T3",
635+
{2, 2, 2},
636+
{100, 120,
637+
10, 20,
638+
639+
100, 120,
640+
10, 20});
641+
test.Run();
642+
}
643+
644+
// int8 A x uint8 B, A=ND B=2D (broadcast B across batches).
645+
TEST(MatmulIntegerOpTest, MatMulInteger_int8_uint8_A_ND_B_2D) {
646+
OpTester test("MatMulInteger", 10);
647+
test.AddInput<int8_t>("T1",
648+
{2, 2, 3},
649+
{1, -2, 3,
650+
-1, 4, -2,
651+
652+
2, 0, -1,
653+
0, 3, 1});
654+
test.AddInput<uint8_t>("T2",
655+
{3, 2},
656+
{10, 20,
657+
30, 40,
658+
50, 60});
659+
test.AddInput<int8_t>("a_zero_point", {}, {0});
660+
test.AddInput<uint8_t>("b_zero_point", {}, {0});
661+
// batch 0: same as above = {100,120,10,20}
662+
// batch 1: [[2,0,-1],[0,3,1]] x [[10,20],[30,40],[50,60]]
663+
// [0,0] = 2*10+0*30+(-1)*50 = 20-50 = -30
664+
// [0,1] = 2*20+0*40+(-1)*60 = 40-60 = -20
665+
// [1,0] = 0*10+3*30+1*50 = 90+50 = 140
666+
// [1,1] = 0*20+3*40+1*60 = 120+60 = 180
667+
test.AddOutput<int32_t>("T3",
668+
{2, 2, 2},
669+
{100, 120,
670+
10, 20,
671+
672+
-30, -20,
673+
140, 180});
674+
test.Run();
675+
}
676+
677+
// [M x N] = [M x K] x [K x N] with int8 A, parameterized on B type.
// Generates random int8 activations and WeightType weights, computes the
// reference product with Eigen, and checks MatMulInteger against it.
// B_is_initializer controls whether B is added as a graph initializer
// (which can exercise the pre-packed weight path).
template <typename WeightType>
void RunMatMulIntegerS8X8Test(const int M, const int N, const int K, bool B_is_initializer) {
  OpTester test("MatMulInteger", 10);
  // Static engine/distributions: each template instantiation keeps its own
  // deterministic random sequence across calls (fixed seed 456).
  static std::default_random_engine e(456);
  static std::uniform_int_distribution<int> n_signed(-64, 63);
  static std::uniform_int_distribution<int> n_xint8(std::numeric_limits<WeightType>::min(), std::numeric_limits<WeightType>::max());

  // Eigen matrices are column-major by default, so a (K, M) Eigen matrix read
  // linearly via data() yields row-major (M, K) data for the test tensor.
  // Random() only provides placeholder values; unaryExpr overwrites every
  // element from the distribution (its int argument is deliberately ignored).
  Eigen::MatrixXi matrix_a = Eigen::MatrixXi::Random(K, M)
                                 .unaryExpr([](int) { return n_signed(e); });
  std::vector<int8_t> matrix_a_data = ToVector<int8_t>(matrix_a.data(), M * K);
  int8_t a_zero_point = 0;
  // Zero point is 0, so this subtraction is a no-op kept for symmetry with
  // tests that use non-zero offsets.
  Eigen::MatrixXi matrix_a_offset = matrix_a - a_zero_point * Eigen::MatrixXi::Ones(K, M);

  // (N, K) column-major buffer -> row-major (K, N) data for tensor T2.
  Eigen::MatrixXi matrix_b = Eigen::MatrixXi::Random(N, K)
                                 .unaryExpr([](int) { return n_xint8(e); });
  std::vector<WeightType> matrix_b_data = ToVector<WeightType>(matrix_b.data(), N * K);
  WeightType b_zero_point = 0;
  Eigen::MatrixXi b_zp_matrix = b_zero_point * Eigen::MatrixXi::Ones(N, K);
  // Reference result: (N, K) x (K, M) = (N, M) column-major, which reads out
  // as row-major (M, N) — matching the expected output tensor layout below.
  Eigen::MatrixXi matrix_c = ((matrix_b - b_zp_matrix) * matrix_a_offset).eval();

  test.AddInput<int8_t>("T1", {M, K}, std::move(matrix_a_data));
  test.AddInput<WeightType>("T2", {K, N}, std::move(matrix_b_data), B_is_initializer);

  test.AddOutput<int32_t>("T3", {M, N}, ToVector<int32_t>(matrix_c.data(), M * N));
  test.Run();
}
704+
705+
void RunMatMulIntegerS8X8TestBatch(const int M, const int N, const int K) {
706+
RunMatMulIntegerS8X8Test<int8_t>(M, N, K, false);
707+
RunMatMulIntegerS8X8Test<int8_t>(M, N, K, true);
708+
RunMatMulIntegerS8X8Test<uint8_t>(M, N, K, false);
709+
RunMatMulIntegerS8X8Test<uint8_t>(M, N, K, true);
710+
}
711+
712+
// M = N = 1: each case reduces to a length-K dot product.
TEST(MatmulIntegerOpTest, MatMulInteger_Int8_X8_Scalar) {
  for (const int k : {32, 260, 288}) {
    RunMatMulIntegerS8X8TestBatch(1, 1, k);
  }
}
717+
718+
// M = 1: matrix-vector (GEMV) shapes across assorted N and K sizes.
TEST(MatmulIntegerOpTest, MatMulInteger_Int8_X8_GEMV) {
  constexpr int kCases[][2] = {{2, 16}, {2, 64}, {8, 36}, {8, 68}, {8, 400}};
  for (const auto& nk : kCases) {
    RunMatMulIntegerS8X8TestBatch(1, nk[0], nk[1]);
  }
}
725+
726+
// General M x N x K (GEMM) shapes, including non-power-of-two sizes.
TEST(MatmulIntegerOpTest, MatMulInteger_Int8_X8_GEMM) {
  constexpr int kCases[][3] = {{2, 2, 40}, {2, 48, 33}, {2, 51, 40}, {4, 8, 68}};
  for (const auto& mnk : kCases) {
    RunMatMulIntegerS8X8TestBatch(mnk[0], mnk[1], mnk[2]);
  }
}
732+
733+
// Per-row A zero point is not supported by the CPU implementation.
734+
// Verify it is rejected gracefully.
735+
TEST(MatmulIntegerOpTest, MatMulInteger_PerRow_A_ZeroPoint_Rejected) {
736+
OpTester test("MatMulInteger", 10);
737+
test.AddShapeToTensorData(false);
738+
test.AddInput<uint8_t>("T1",
739+
{2, 3},
740+
{11, 7, 3,
741+
10, 6, 2});
742+
test.AddInput<uint8_t>("T2",
743+
{3, 2},
744+
{1, 4, 2, 5, 3, 6});
745+
// per-row A zero point: shape {2, 1} — not scalar
746+
test.AddInput<uint8_t>("a_zero_point", {2, 1}, {12, 10});
747+
test.AddInput<uint8_t>("b_zero_point", {}, {0});
748+
test.AddOutput<int32_t>("T3", {2, 2}, {0, 0, 0, 0});
749+
test.Run(OpTester::ExpectResult::kExpectFailure, "", {kDmlExecutionProvider});
750+
}
751+
507752
} // namespace test
508753
} // namespace onnxruntime

0 commit comments

Comments
 (0)