@@ -5099,33 +5099,29 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionTest) {
50995099 builder.AddNode (" Identity" , {concat_out}, {output});
51005100 };
51015101
5102- auto pre_graph_checker = [get_op_count](Graph& graph) {
5103- const auto op_to_count = CountOpsInGraph (graph);
5104- TEST_RETURN_IF_NOT (op_to_count.at (" Slice" ) == 4 );
5105- TEST_RETURN_IF_NOT (op_to_count.at (" Concat" ) == 1 );
5106- TEST_RETURN_IF (get_op_count (op_to_count, " SpaceToDepth" ) != 0 );
5107- return Status::OK ();
5108- };
5109-
5110- auto post_graph_checker = [get_op_count](Graph& graph) {
5102+ auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
5103+ const Graph& graph = session.GetGraph ();
51115104 const auto op_to_count = CountOpsInGraph (graph);
5112- TEST_RETURN_IF (op_to_count.count (" Slice" ) != 0 && op_to_count.at (" Slice" ) ! = 0 );
5113- TEST_RETURN_IF (op_to_count.count (" Concat" ) != 0 && op_to_count.at (" Concat" ) ! = 0 );
5114- TEST_RETURN_IF_NOT (get_op_count (op_to_count, " SpaceToDepth" ) == 1 );
5105+ ASSERT_TRUE (op_to_count.count (" Slice" ) == 0 || op_to_count.at (" Slice" ) = = 0 );
5106+ ASSERT_TRUE (op_to_count.count (" Concat" ) == 0 || op_to_count.at (" Concat" ) = = 0 );
5107+ ASSERT_EQ (get_op_count (op_to_count, " SpaceToDepth" ), 1 );
51155108
51165109 for (const auto & node : graph.Nodes ()) {
51175110 if (node.OpType () == " SpaceToDepth" ) {
51185111 const auto * blocksize_attr = graph_utils::GetNodeAttribute (node, " blocksize" );
5119- TEST_RETURN_IF_NOT (blocksize_attr != nullptr && utils::HasInt (*blocksize_attr) && blocksize_attr->i () == 2 );
5112+ ASSERT_TRUE (blocksize_attr != nullptr && utils::HasInt (*blocksize_attr) && blocksize_attr->i () == 2 );
51205113 }
51215114 }
5122-
5123- return Status::OK ();
51245115 };
51255116
5126- auto transformer = std::make_unique<SliceConcatToSpaceToDepthFusion>();
5127- ASSERT_STATUS_OK (TestGraphTransformer (build_test_case, 13 , *logger_, std::move (transformer), TransformerLevel::Level1,
5128- 1 , pre_graph_checker, post_graph_checker));
5117+ TransformerTester (build_test_case,
5118+ check_transformed_graph,
5119+ TransformerLevel::Default,
5120+ TransformerLevel::Level1,
5121+ 13 ,
5122+ 0.0 ,
5123+ 0.0 ,
5124+ std::make_unique<SliceConcatToSpaceToDepthFusion>());
51295125}
51305126
51315127TEST_F (GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithConstantNodesTest) {
@@ -5178,26 +5174,22 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithConstantNode
51785174 builder.AddNode (" Identity" , {concat_out}, {output});
51795175 };
51805176
5181- auto pre_graph_checker = [get_op_count](Graph& graph) {
5182- const auto op_to_count = CountOpsInGraph (graph);
5183- TEST_RETURN_IF_NOT (op_to_count.at (" Slice" ) == 4 );
5184- TEST_RETURN_IF_NOT (op_to_count.at (" Concat" ) == 1 );
5185- TEST_RETURN_IF_NOT (op_to_count.at (" Constant" ) == 7 );
5186- TEST_RETURN_IF (get_op_count (op_to_count, " SpaceToDepth" ) != 0 );
5187- return Status::OK ();
5188- };
5189-
5190- auto post_graph_checker = [get_op_count](Graph& graph) {
5177+ auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
5178+ const Graph& graph = session.GetGraph ();
51915179 const auto op_to_count = CountOpsInGraph (graph);
5192- TEST_RETURN_IF (op_to_count.count (" Slice" ) != 0 && op_to_count.at (" Slice" ) != 0 );
5193- TEST_RETURN_IF (op_to_count.count (" Concat" ) != 0 && op_to_count.at (" Concat" ) != 0 );
5194- TEST_RETURN_IF_NOT (get_op_count (op_to_count, " SpaceToDepth" ) == 1 );
5195- return Status::OK ();
5180+ ASSERT_TRUE (op_to_count.count (" Slice" ) == 0 || op_to_count.at (" Slice" ) == 0 );
5181+ ASSERT_TRUE (op_to_count.count (" Concat" ) == 0 || op_to_count.at (" Concat" ) == 0 );
5182+ ASSERT_EQ (get_op_count (op_to_count, " SpaceToDepth" ), 1 );
51965183 };
51975184
5198- auto transformer = std::make_unique<SliceConcatToSpaceToDepthFusion>();
5199- ASSERT_STATUS_OK (TestGraphTransformer (build_test_case, 13 , *logger_, std::move (transformer), TransformerLevel::Level1,
5200- 1 , pre_graph_checker, post_graph_checker));
5185+ TransformerTester (build_test_case,
5186+ check_transformed_graph,
5187+ TransformerLevel::Default,
5188+ TransformerLevel::Level1,
5189+ 13 ,
5190+ 0.0 ,
5191+ 0.0 ,
5192+ std::make_unique<SliceConcatToSpaceToDepthFusion>());
52015193}
52025194
52035195TEST_F (GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithPermutedBlockOrderTest) {
@@ -5234,27 +5226,23 @@ TEST_F(GraphTransformationTests, SliceConcatToSpaceToDepthFusionWithPermutedBloc
52345226 builder.AddNode (" Identity" , {concat_out}, {output});
52355227 };
52365228
5237- auto pre_graph_checker = [get_op_count](Graph& graph) {
5229+ auto check_transformed_graph = [get_op_count](InferenceSessionWrapper& session) {
5230+ const Graph& graph = session.GetGraph ();
52385231 const auto op_to_count = CountOpsInGraph (graph);
5239- TEST_RETURN_IF_NOT (op_to_count.at (" Slice" ) == 4 );
5240- TEST_RETURN_IF_NOT (op_to_count.at (" Concat" ) == 1 );
5241- TEST_RETURN_IF (get_op_count (op_to_count, " SpaceToDepth" ) != 0 );
5242- TEST_RETURN_IF (get_op_count (op_to_count, " Gather" ) != 0 );
5243- return Status::OK ();
5244- };
5245-
5246- auto post_graph_checker = [get_op_count](Graph& graph) {
5247- const auto op_to_count = CountOpsInGraph (graph);
5248- TEST_RETURN_IF (op_to_count.count (" Slice" ) != 0 && op_to_count.at (" Slice" ) != 0 );
5249- TEST_RETURN_IF (op_to_count.count (" Concat" ) != 0 && op_to_count.at (" Concat" ) != 0 );
5250- TEST_RETURN_IF_NOT (get_op_count (op_to_count, " SpaceToDepth" ) == 1 );
5251- TEST_RETURN_IF_NOT (get_op_count (op_to_count, " Gather" ) == 1 );
5252- return Status::OK ();
5232+ ASSERT_TRUE (op_to_count.count (" Slice" ) == 0 || op_to_count.at (" Slice" ) == 0 );
5233+ ASSERT_TRUE (op_to_count.count (" Concat" ) == 0 || op_to_count.at (" Concat" ) == 0 );
5234+ ASSERT_EQ (get_op_count (op_to_count, " SpaceToDepth" ), 1 );
5235+ ASSERT_EQ (get_op_count (op_to_count, " Gather" ), 1 );
52535236 };
52545237
5255- auto transformer = std::make_unique<SliceConcatToSpaceToDepthFusion>();
5256- ASSERT_STATUS_OK (TestGraphTransformer (build_test_case, 13 , *logger_, std::move (transformer), TransformerLevel::Level1,
5257- 1 , pre_graph_checker, post_graph_checker));
5238+ TransformerTester (build_test_case,
5239+ check_transformed_graph,
5240+ TransformerLevel::Default,
5241+ TransformerLevel::Level1,
5242+ 13 ,
5243+ 0.0 ,
5244+ 0.0 ,
5245+ std::make_unique<SliceConcatToSpaceToDepthFusion>());
52585246}
52595247
52605248TEST_F (GraphTransformationTests, SliceConcatToSpaceToDepthFusionNotTriggeredForDynamicChannelPermutedBlockOrderTest) {
@@ -6107,7 +6095,7 @@ static void BuildMobileClipAttentionTestCase(ModelTestBuilder& builder,
61076095 builder.AddNode (" Add" , std::vector<NodeArg*>{input_skip, layer_scale_out}, std::vector<NodeArg*>{output});
61086096}
61096097
6110- static Status CheckMobileClipAttentionFusedGraph (Graph& graph) {
6098+ static Status CheckMobileClipAttentionFusedGraph (const Graph& graph) {
61116099 auto op_to_count = CountOpsInGraph (graph);
61126100 TEST_RETURN_IF_NOT (op_to_count[" com.microsoft.MultiHeadAttention" ] == 1 );
61136101 TEST_RETURN_IF_NOT (op_to_count[" Gemm" ] == 1 );
@@ -6116,14 +6104,13 @@ static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
61166104 TEST_RETURN_IF_NOT (op_to_count[" Split" ] == 1 );
61176105 TEST_RETURN_IF_NOT (op_to_count[" MatMul" ] == 1 );
61186106 TEST_RETURN_IF_NOT (op_to_count[" Transpose" ] == 2 );
6119- TEST_RETURN_IF_NOT (op_to_count[" Reshape" ] == 4 );
61206107 TEST_RETURN_IF_NOT (op_to_count[" Mul" ] == 1 );
61216108 TEST_RETURN_IF_NOT (op_to_count[" Add" ] == 1 );
61226109
61236110 int mha_nodes = 0 ;
61246111 int gemm_nodes = 0 ;
61256112 int split_nodes = 0 ;
6126- for (Node& node : graph.Nodes ()) {
6113+ for (const Node& node : graph.Nodes ()) {
61276114 if (node.OpType () == " MultiHeadAttention" && node.Domain () == kMSDomain ) {
61286115 ++mha_nodes;
61296116 TEST_RETURN_IF_NOT (node.GetAttributes ().at (" num_heads" ).i () == 16 );
@@ -6170,16 +6157,24 @@ static Status CheckMobileClipAttentionFusedGraph(Graph& graph) {
61706157 return Status::OK ();
61716158}
61726159
6173- static Status CheckMobileClipAttentionFusedGraphOnProvider (Graph& graph, const char * provider) {
6160+ static Status CheckMobileClipAttentionFusedGraphOnProvider (const Graph& graph, const char * provider) {
61746161 ORT_RETURN_IF_ERROR (CheckMobileClipAttentionFusedGraph (graph));
61756162
6176- for (Node& node : graph.Nodes ()) {
6163+ for (const Node& node : graph.Nodes ()) {
61776164 TEST_RETURN_IF_NOT (node.GetExecutionProviderType () == provider);
61786165 }
61796166
61806167 return Status::OK ();
61816168}
61826169
6170+ static void CheckMobileClipAttentionFusedSession (InferenceSessionWrapper& session) {
6171+ ASSERT_STATUS_OK (CheckMobileClipAttentionFusedGraph (session.GetGraph ()));
6172+ }
6173+
6174+ static void CheckMobileClipAttentionFusedCudaSession (InferenceSessionWrapper& session) {
6175+ ASSERT_STATUS_OK (CheckMobileClipAttentionFusedGraphOnProvider (session.GetGraph (), kCudaExecutionProvider ));
6176+ }
6177+
61836178static Status CheckMobileClipAttentionUnfusedProjectionGemmGraph (Graph& graph) {
61846179 auto op_to_count = CountOpsInGraph (graph);
61856180 TEST_RETURN_IF_NOT (op_to_count[" com.microsoft.MultiHeadAttention" ] == 0 );
@@ -6230,61 +6225,75 @@ TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaTest) {
62306225 BuildMobileClipAttentionTestCase (builder, MobileClipProjectionType::MatMulAdd);
62316226 };
62326227
6233- ASSERT_STATUS_OK (TestGraphTransformer (build_test_case, 14 , *logger_, std::make_unique<AttentionFusion>(),
6234- TransformerLevel::Level2, 1 , nullptr , CheckMobileClipAttentionFusedGraph));
6228+ TransformerTester (build_test_case,
6229+ CheckMobileClipAttentionFusedSession,
6230+ TransformerLevel::Level1,
6231+ TransformerLevel::Level2,
6232+ 14 ,
6233+ 1e-3 ,
6234+ 0.0 ,
6235+ std::make_unique<AttentionFusion>());
62356236}
62366237
62376238TEST_F (GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmTest) {
62386239 auto build_test_case = [](ModelTestBuilder& builder) {
62396240 BuildMobileClipAttentionTestCase (builder, MobileClipProjectionType::GemmWithReshapes);
62406241 };
62416242
6242- ASSERT_STATUS_OK (TestGraphTransformer (build_test_case, 14 , *logger_, std::make_unique<AttentionFusion>(),
6243- TransformerLevel::Level2, 1 , nullptr , CheckMobileClipAttentionFusedGraph));
6243+ TransformerTester (build_test_case,
6244+ CheckMobileClipAttentionFusedSession,
6245+ TransformerLevel::Level1,
6246+ TransformerLevel::Level2,
6247+ 14 ,
6248+ 1e-3 ,
6249+ 0.0 ,
6250+ std::make_unique<AttentionFusion>());
62446251}
62456252
62466253TEST_F (GraphTransformationTests, AttentionFusionMobileClipMhaCudaEpTest) {
6254+ auto cuda_ep = DefaultCudaExecutionProvider ();
6255+ if (!cuda_ep) {
6256+ GTEST_SKIP () << " CUDA execution provider is not available" ;
6257+ }
6258+
62476259 auto build_test_case = [](ModelTestBuilder& builder) {
62486260 BuildMobileClipAttentionTestCase (builder, MobileClipProjectionType::MatMulAdd);
62496261 };
62506262
6251- auto pre_graph_checker = [](Graph& graph) {
6252- for (Node& node : graph.Nodes ()) {
6253- node.SetExecutionProviderType (kCudaExecutionProvider );
6254- }
6255-
6256- return Status::OK ();
6257- };
6258-
6259- auto post_graph_checker = [](Graph& graph) {
6260- return CheckMobileClipAttentionFusedGraphOnProvider (graph, kCudaExecutionProvider );
6261- };
6262-
6263- ASSERT_STATUS_OK (TestGraphTransformer (
6264- build_test_case, 14 , *logger_, std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider }),
6265- TransformerLevel::Level2, 1 , pre_graph_checker, post_graph_checker));
6263+ TransformerTester (build_test_case,
6264+ CheckMobileClipAttentionFusedCudaSession,
6265+ TransformerLevel::Level1,
6266+ TransformerLevel::Level2,
6267+ 14 ,
6268+ 1e-3 ,
6269+ 0.0 ,
6270+ std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider }),
6271+ {},
6272+ {},
6273+ std::move (cuda_ep));
62666274}
62676275
62686276TEST_F (GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmCudaEpTest) {
6277+ auto cuda_ep = DefaultCudaExecutionProvider ();
6278+ if (!cuda_ep) {
6279+ GTEST_SKIP () << " CUDA execution provider is not available" ;
6280+ }
6281+
62696282 auto build_test_case = [](ModelTestBuilder& builder) {
62706283 BuildMobileClipAttentionTestCase (builder, MobileClipProjectionType::GemmWithReshapes);
62716284 };
62726285
6273- auto pre_graph_checker = [](Graph& graph) {
6274- for (Node& node : graph.Nodes ()) {
6275- node.SetExecutionProviderType (kCudaExecutionProvider );
6276- }
6277-
6278- return Status::OK ();
6279- };
6280-
6281- auto post_graph_checker = [](Graph& graph) {
6282- return CheckMobileClipAttentionFusedGraphOnProvider (graph, kCudaExecutionProvider );
6283- };
6284-
6285- ASSERT_STATUS_OK (TestGraphTransformer (
6286- build_test_case, 14 , *logger_, std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider }),
6287- TransformerLevel::Level2, 1 , pre_graph_checker, post_graph_checker));
6286+ TransformerTester (build_test_case,
6287+ CheckMobileClipAttentionFusedCudaSession,
6288+ TransformerLevel::Level1,
6289+ TransformerLevel::Level2,
6290+ 14 ,
6291+ 1e-3 ,
6292+ 0.0 ,
6293+ std::make_unique<AttentionFusion>(InlinedHashSet<std::string_view>{kCudaExecutionProvider }),
6294+ {},
6295+ {},
6296+ std::move (cuda_ep));
62886297}
62896298
62906299TEST_F (GraphTransformationTests, AttentionFusionMobileClipMhaInvalidQkvWeightShapeTest) {
0 commit comments