Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/cpu/math/matmul_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ class MatMulComputeHelper {
if (num_output_dims == 0) {
// for left and right being both vector, output is scalar thus no shape
ORT_RETURN_IF_NOT(M_ == 1 && N_ == 1, "M_ == 1 && N_ == 1 was false");
ORT_RETURN_IF_NOT(K_ == right_shape[0],
"MatMul dimension mismatch. Left vector K (",
K_, ") != right vector K (", right_shape[0], ")");
} else {
if (left_num_dims == 1) {
ORT_RETURN_IF_NOT(num_dims_with_pad - 1 == num_output_dims, "num_dims_with_pad - 1 != num_output_dims");
Expand Down
50 changes: 50 additions & 0 deletions onnxruntime/test/providers/cpu/math/matmul_integer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -504,5 +504,55 @@ TEST(MatmulIntegerOpTest, SharedPrepackedWeights) {
}
#endif

// Regression test: 1D vector dot product with matching K dimension should succeed.
// A=[K], B=[K] -> scalar output (dot product).
TEST(MatmulIntegerOpTest, MatMulInteger_1D_Vector_DotProduct) {
OpTester test("MatMulInteger", 10);
test.AddInput<uint8_t>("T1", {4}, {1, 2, 3, 4});
test.AddInput<uint8_t>("T2", {4}, {5, 6, 7, 8});
test.AddInput<uint8_t>("a_zero_point", {}, {0});
test.AddInput<uint8_t>("b_zero_point", {}, {0});
// dot product: 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
test.AddOutput<int32_t>("T3", {}, {70});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
}

// Same 1D vector dot product test with int8_t types.
TEST(MatmulIntegerOpTest, MatMulInteger_1D_Vector_DotProduct_int8) {
OpTester test("MatMulInteger", 10);
test.AddInput<int8_t>("T1", {3}, {1, -2, 3});
test.AddInput<int8_t>("T2", {3}, {4, 5, -6});
test.AddInput<int8_t>("a_zero_point", {}, {0});
test.AddInput<int8_t>("b_zero_point", {}, {0});
// dot product: 1*4 + (-2)*5 + 3*(-6) = 4 - 10 - 18 = -24
test.AddOutput<int32_t>("T3", {}, {-24});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
}

Comment thread
yuslepukhin marked this conversation as resolved.
// Regression test: 1D vectors with mismatched K dimension must fail safely.
// Covers prior invalid-shape handling for A=[K], B=[1] where K > 1.
TEST(MatmulIntegerOpTest, MatMulInteger_1D_Vector_KDimensionMismatch) {
OpTester test("MatMulInteger", 10);
test.AddShapeToTensorData(false);
test.AddInput<uint8_t>("T1", {4}, {1, 1, 1, 1});
test.AddInput<uint8_t>("T2", {1}, {5});
test.AddInput<uint8_t>("a_zero_point", {}, {0});
test.AddInput<uint8_t>("b_zero_point", {}, {0});
test.AddOutput<int32_t>("T3", {}, {0});
test.Run(OpTester::ExpectResult::kExpectFailure, "MatMul dimension mismatch");
}
Comment thread
yuslepukhin marked this conversation as resolved.

// Same regression test with int8_t types.
TEST(MatmulIntegerOpTest, MatMulInteger_int8_1D_Vector_KDimensionMismatch) {
OpTester test("MatMulInteger", 10);
test.AddShapeToTensorData(false);
test.AddInput<int8_t>("T1", {8}, {1, 1, 1, 1, 1, 1, 1, 1});
test.AddInput<int8_t>("T2", {1}, {5});
test.AddInput<int8_t>("a_zero_point", {}, {0});
test.AddInput<int8_t>("b_zero_point", {}, {0});
test.AddOutput<int32_t>("T3", {}, {0});
test.Run(OpTester::ExpectResult::kExpectFailure, "MatMul dimension mismatch");
}
Comment thread
yuslepukhin marked this conversation as resolved.

} // namespace test
} // namespace onnxruntime
Loading