Skip to content

Commit f38dec1

Browse files
authored
[Core] Add correctness tests for SpaceToDepth and MobileClip Attention fusion (#28168)
### Description Add correctness tests for fusions introduced in #27883 and #27747. The tests introduced in those PRs only check that the fusion went through, but not that the fused nodes produce semantically correct results compared to the unfused subgraphs. Adding these tests prevents accidental breakage in case something changes in the fused node's backing kernel implementation. ### Motivation and Context Address test coverage gap
1 parent 480dd76 commit f38dec1

1 file changed

Lines changed: 101 additions & 92 deletions

File tree

onnxruntime/test/optimizer/graph_transform_test.cc

Lines changed: 101 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -5099,33 +5099,29 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionTest) {
50995099
builder.AddNode("Identity", {concat_out}, {output});
51005100
};
51015101

5102-
auto pre_graph_checker = [get_op_count](Graph& graph) {
5103-
const auto op_to_count = CountOpsInGraph(graph);
5104-
TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4);
5105-
TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1);
5106-
TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0);
5107-
return Status::OK();
5108-
};
5109-
5110-
auto post_graph_checker = [get_op_count](Graph& graph) {
5102+
auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
5103+
const Graph& graph = session.GetGraph();
51115104
const auto op_to_count = CountOpsInGraph(graph);
5112-
TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0);
5113-
TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0);
5114-
TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1);
5105+
ASSERT_TRUE(op_to_count.count("Slice") == 0 || op_to_count.at("Slice") == 0);
5106+
ASSERT_TRUE(op_to_count.count("Concat") == 0 || op_to_count.at("Concat") == 0);
5107+
ASSERT_EQ(get_op_count(op_to_count, "SpaceToDepth"), 1);
51155108

51165109
for (const auto& node : graph.Nodes()) {
51175110
if (node.OpType() == "SpaceToDepth") {
51185111
const auto* blocksize_attr = graph_utils::GetNodeAttribute(node, "blocksize");
5119-
TEST_RETURN_IF_NOT(blocksize_attr != nullptr && utils::HasInt(*blocksize_attr) && blocksize_attr->i() == 2);
5112+
ASSERT_TRUE(blocksize_attr != nullptr && utils::HasInt(*blocksize_attr) && blocksize_attr->i() == 2);
51205113
}
51215114
}
5122-
5123-
return Status::OK();
51245115
};
51255116

5126-
auto transformer = std::make_unique<SliceConcatToSpaceToDepthFusion>();
5127-
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1,
5128-
1, pre_graph_checker, post_graph_checker));
5117+
TransformerTester(build_test_case,
5118+
check_transformed_graph,
5119+
TransformerLevel::Default,
5120+
TransformerLevel::Level1,
5121+
13,
5122+
0.0,
5123+
0.0,
5124+
std::make_unique<SliceConcatToSpaceToDepthFusion>());
51295125
}
51305126

51315127
TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithConstantNodesTest) {
@@ -5178,26 +5174,22 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithConstantNode
51785174
builder.AddNode("Identity", {concat_out}, {output});
51795175
};
51805176

5181-
auto pre_graph_checker = [get_op_count](Graph& graph) {
5182-
const auto op_to_count = CountOpsInGraph(graph);
5183-
TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4);
5184-
TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1);
5185-
TEST_RETURN_IF_NOT(op_to_count.at("Constant") == 7);
5186-
TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0);
5187-
return Status::OK();
5188-
};
5189-
5190-
auto post_graph_checker = [get_op_count](Graph& graph) {
5177+
auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
5178+
const Graph& graph = session.GetGraph();
51915179
const auto op_to_count = CountOpsInGraph(graph);
5192-
TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0);
5193-
TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0);
5194-
TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1);
5195-
return Status::OK();
5180+
ASSERT_TRUE(op_to_count.count("Slice") == 0 || op_to_count.at("Slice") == 0);
5181+
ASSERT_TRUE(op_to_count.count("Concat") == 0 || op_to_count.at("Concat") == 0);
5182+
ASSERT_EQ(get_op_count(op_to_count, "SpaceToDepth"), 1);
51965183
};
51975184

5198-
auto transformer = std::make_unique<SliceConcatToSpaceToDepthFusion>();
5199-
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1,
5200-
1, pre_graph_checker, post_graph_checker));
5185+
TransformerTester(build_test_case,
5186+
check_transformed_graph,
5187+
TransformerLevel::Default,
5188+
TransformerLevel::Level1,
5189+
13,
5190+
0.0,
5191+
0.0,
5192+
std::make_unique<SliceConcatToSpaceToDepthFusion>());
52015193
}
52025194

52035195
TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithPermutedBlockOrderTest) {
@@ -5234,27 +5226,23 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithPermutedBloc
52345226
builder.AddNode("Identity", {concat_out}, {output});
52355227
};
52365228

5237-
auto pre_graph_checker = [get_op_count](Graph& graph) {
5229+
auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
5230+
const Graph& graph = session.GetGraph();
52385231
const auto op_to_count = CountOpsInGraph(graph);
5239-
TEST_RETURN_IF_NOT(op_to_count.at("Slice") == 4);
5240-
TEST_RETURN_IF_NOT(op_to_count.at("Concat") == 1);
5241-
TEST_RETURN_IF(get_op_count(op_to_count, "SpaceToDepth") != 0);
5242-
TEST_RETURN_IF(get_op_count(op_to_count, "Gather") != 0);
5243-
return Status::OK();
5244-
};
5245-
5246-
auto post_graph_checker = [get_op_count](Graph& graph) {
5247-
const auto op_to_count = CountOpsInGraph(graph);
5248-
TEST_RETURN_IF(op_to_count.count("Slice") != 0 && op_to_count.at("Slice") != 0);
5249-
TEST_RETURN_IF(op_to_count.count("Concat") != 0 && op_to_count.at("Concat") != 0);
5250-
TEST_RETURN_IF_NOT(get_op_count(op_to_count, "SpaceToDepth") == 1);
5251-
TEST_RETURN_IF_NOT(get_op_count(op_to_count, "Gather") == 1);
5252-
return Status::OK();
5232+
ASSERT_TRUE(op_to_count.count("Slice") == 0 || op_to_count.at("Slice") == 0);
5233+
ASSERT_TRUE(op_to_count.count("Concat") == 0 || op_to_count.at("Concat") == 0);
5234+
ASSERT_EQ(get_op_count(op_to_count, "SpaceToDepth"), 1);
5235+
ASSERT_EQ(get_op_count(op_to_count, "Gather"), 1);
52535236
};
52545237

5255-
auto transformer = std::make_unique<SliceConcatToSpaceToDepthFusion>();
5256-
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 13, *logger_, std::move(transformer), TransformerLevel::Level1,
5257-
1, pre_graph_checker, post_graph_checker));
5238+
TransformerTester(build_test_case,
5239+
check_transformed_graph,
5240+
TransformerLevel::Default,
5241+
TransformerLevel::Level1,
5242+
13,
5243+
0.0,
5244+
0.0,
5245+
std::make_unique<SliceConcatToSpaceToDepthFusion>());
52585246
}
52595247

52605248
TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionNotTriggeredForDynamicChannelPermutedBlockOrderTest) {
@@ -6107,7 +6095,7 @@ static void BuildMobileClipAttentionTestCase(ModelTestBuilder& builder,
61076095
builder.AddNode("Add", std::vector<NodeArg*>{input_skip, layer_scale_out}, std::vector<NodeArg*>{output});
61086096
}
61096097

6110-
static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
6098+
static Status CheckMobileClipAttentionFusedGraph(const Graph& graph) {
61116099
auto op_to_count = CountOpsInGraph(graph);
61126100
TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 1);
61136101
TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1);
@@ -6116,14 +6104,13 @@ static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
61166104
TEST_RETURN_IF_NOT(op_to_count["Split"] == 1);
61176105
TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 1);
61186106
TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 2);
6119-
TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 4);
61206107
TEST_RETURN_IF_NOT(op_to_count["Mul"] == 1);
61216108
TEST_RETURN_IF_NOT(op_to_count["Add"] == 1);
61226109

61236110
int mha_nodes = 0;
61246111
int gemm_nodes = 0;
61256112
int split_nodes = 0;
6126-
for (Node& node : graph.Nodes()) {
6113+
for (const Node& node : graph.Nodes()) {
61276114
if (node.OpType() == "MultiHeadAttention" && node.Domain() == kMSDomain) {
61286115
++mha_nodes;
61296116
TEST_RETURN_IF_NOT(node.GetAttributes().at("num_heads").i() == 16);
@@ -6170,16 +6157,24 @@ static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
61706157
return Status::OK();
61716158
}
61726159

6173-
static Status CheckMobileClipAttentionFusedGraphOnProvider(Graph& graph, const char* provider) {
6160+
static Status CheckMobileClipAttentionFusedGraphOnProvider(const Graph& graph, const char* provider) {
61746161
ORT_RETURN_IF_ERROR(CheckMobileClipAttentionFusedGraph(graph));
61756162

6176-
for (Node& node : graph.Nodes()) {
6163+
for (const Node& node : graph.Nodes()) {
61776164
TEST_RETURN_IF_NOT(node.GetExecutionProviderType() == provider);
61786165
}
61796166

61806167
return Status::OK();
61816168
}
61826169

6170+
static void CheckMobileClipAttentionFusedSession(InferenceSessionWrapper& session) {
6171+
ASSERT_STATUS_OK(CheckMobileClipAttentionFusedGraph(session.GetGraph()));
6172+
}
6173+
6174+
static void CheckMobileClipAttentionFusedCudaSession(InferenceSessionWrapper& session) {
6175+
ASSERT_STATUS_OK(CheckMobileClipAttentionFusedGraphOnProvider(session.GetGraph(), kCudaExecutionProvider));
6176+
}
6177+
61836178
static Status CheckMobileClipAttentionUnfusedProjectionGemmGraph(Graph& graph) {
61846179
auto op_to_count = CountOpsInGraph(graph);
61856180
TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 0);
@@ -6230,61 +6225,75 @@ TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaTest) {
62306225
BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd);
62316226
};
62326227

6233-
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique<AttentionFusion>(),
6234-
TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph));
6228+
TransformerTester(build_test_case,
6229+
CheckMobileClipAttentionFusedSession,
6230+
TransformerLevel::Level1,
6231+
TransformerLevel::Level2,
6232+
14,
6233+
1e-3,
6234+
0.0,
6235+
std::make_unique<AttentionFusion>());
62356236
}
62366237

62376238
TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmTest) {
62386239
auto build_test_case = [](ModelTestBuilder& builder) {
62396240
BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes);
62406241
};
62416242

6242-
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique<AttentionFusion>(),
6243-
TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph));
6243+
TransformerTester(build_test_case,
6244+
CheckMobileClipAttentionFusedSession,
6245+
TransformerLevel::Level1,
6246+
TransformerLevel::Level2,
6247+
14,
6248+
1e-3,
6249+
0.0,
6250+
std::make_unique<AttentionFusion>());
62446251
}
62456252

62466253
TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaCudaEpTest) {
6254+
auto cuda_ep = DefaultCudaExecutionProvider();
6255+
if (!cuda_ep) {
6256+
GTEST_SKIP() << "CUDA execution provider is not available";
6257+
}
6258+
62476259
auto build_test_case = [](ModelTestBuilder& builder) {
62486260
BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd);
62496261
};
62506262

6251-
auto pre_graph_checker = [](Graph& graph) {
6252-
for (Node& node : graph.Nodes()) {
6253-
node.SetExecutionProviderType(kCudaExecutionProvider);
6254-
}
6255-
6256-
return Status::OK();
6257-
};
6258-
6259-
auto post_graph_checker = [](Graph& graph) {
6260-
return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider);
6261-
};
6262-
6263-
ASSERT_STATUS_OK(TestGraphTransformer(
6264-
build_test_case, 14, *logger_, std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider}),
6265-
TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker));
6263+
TransformerTester(build_test_case,
6264+
CheckMobileClipAttentionFusedCudaSession,
6265+
TransformerLevel::Level1,
6266+
TransformerLevel::Level2,
6267+
14,
6268+
1e-3,
6269+
0.0,
6270+
std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider}),
6271+
{},
6272+
{},
6273+
std::move(cuda_ep));
62666274
}
62676275

62686276
TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmCudaEpTest) {
6277+
auto cuda_ep = DefaultCudaExecutionProvider();
6278+
if (!cuda_ep) {
6279+
GTEST_SKIP() << "CUDA execution provider is not available";
6280+
}
6281+
62696282
auto build_test_case = [](ModelTestBuilder& builder) {
62706283
BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes);
62716284
};
62726285

6273-
auto pre_graph_checker = [](Graph& graph) {
6274-
for (Node& node : graph.Nodes()) {
6275-
node.SetExecutionProviderType(kCudaExecutionProvider);
6276-
}
6277-
6278-
return Status::OK();
6279-
};
6280-
6281-
auto post_graph_checker = [](Graph& graph) {
6282-
return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider);
6283-
};
6284-
6285-
ASSERT_STATUS_OK(TestGraphTransformer(
6286-
build_test_case, 14, *logger_, std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider}),
6287-
TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker));
6286+
TransformerTester(build_test_case,
6287+
CheckMobileClipAttentionFusedCudaSession,
6288+
TransformerLevel::Level1,
6289+
TransformerLevel::Level2,
6290+
14,
6291+
1e-3,
6292+
0.0,
6293+
std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider}),
6294+
{},
6295+
{},
6296+
std::move(cuda_ep));
62886297
}
62896298

62906299
TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaInvalidQkvWeightShapeTest) {

0 commit comments

Comments
 (0)