@@ -12,62 +12,38 @@ namespace onnxruntime {
1212namespace contrib {
1313namespace webgpu {
1414
15- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , Attention);
16- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , BiasAdd);
17- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , CausalConvWithState);
18- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , BiasGelu);
19- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , BiasSplitGelu);
20- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , FastGelu);
21- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , FusedConv);
22- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , GatherBlockQuantized);
23- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , Gelu);
24- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , GroupQueryAttention);
25- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , LinearAttention);
26- // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
27- class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kOnnxDomain , 1 , 16 , LayerNormalization);
28- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , MatMulNBits);
29- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , MultiHeadAttention);
30- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , QuickGelu);
31- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , RotaryEmbedding);
32- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kOnnxDomain , 1 , SimplifiedLayerNormalization);
33- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , SkipLayerNormalization);
34- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kOnnxDomain , 1 , SimplifiedLayerNormalization);
35- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , SkipSimplifiedLayerNormalization);
36- // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MoE);
37- class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , QMoE);
38-
3915template <>
4016KernelCreateInfo BuildKernelCreateInfo<void >() {
4117 KernelCreateInfo info;
4218 return info;
4319}
4420
45- Status RegisterWebGpuContribKernels (KernelRegistry& kernel_registry, bool enable_graph_capture) {
46- static const BuildKernelCreateInfoFn function_table[] = {
47- BuildKernelCreateInfo<void >, // default entry to avoid the list become empty after ops-reducing
48- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , Attention)>,
49- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , BiasAdd)>,
50- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , CausalConvWithState)>,
51- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , BiasGelu)>,
52- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , BiasSplitGelu)>,
53- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , FastGelu)>,
54- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , GatherBlockQuantized)>,
55- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , FusedConv)>,
56- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , Gelu)>,
57- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , LinearAttention)>,
58- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , MatMulNBits)>,
59- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , MultiHeadAttention)>,
60- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , QuickGelu)>,
61- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , RotaryEmbedding)>,
62- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , SkipLayerNormalization)>,
63- // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
64- BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kOnnxDomain , 1 , 16 , LayerNormalization)>,
65- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kOnnxDomain , 1 , SimplifiedLayerNormalization)>,
66- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , SkipSimplifiedLayerNormalization)>,
67- // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MoE)>,
68- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , QMoE)>};
21+ static const BuildKernelCreateInfoFn build_kernel_create_info_function_table[] = {
22+ BuildKernelCreateInfo<void >, // default entry to avoid the list become empty after ops-reducing
23+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , Attention)>,
24+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , BiasAdd)>,
25+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , CausalConvWithState)>,
26+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , BiasGelu)>,
27+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , BiasSplitGelu)>,
28+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , FastGelu)>,
29+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , GatherBlockQuantized)>,
30+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , FusedConv)>,
31+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , Gelu)>,
32+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , LinearAttention)>,
33+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , MatMulNBits)>,
34+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , MultiHeadAttention)>,
35+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , QuickGelu)>,
36+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , RotaryEmbedding)>,
37+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , SkipLayerNormalization)>,
38+ // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
39+ BuildKernelCreateInfo<class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kOnnxDomain , 1 , 16 , LayerNormalization)>,
40+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kOnnxDomain , 1 , SimplifiedLayerNormalization)>,
41+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , SkipSimplifiedLayerNormalization)>,
42+ // BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MoE)>,
43+ BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME (kWebGpuExecutionProvider , kMSDomain , 1 , QMoE)>};
6944
70- for (auto & function_table_entry : function_table) {
45+ Status RegisterWebGpuContribKernels (KernelRegistry& kernel_registry, bool enable_graph_capture) {
46+ for (auto & function_table_entry : build_kernel_create_info_function_table) {
7147 KernelCreateInfo info = function_table_entry ();
7248 if (info.kernel_def != nullptr ) { // filter disabled entries where type is void
7349 ORT_RETURN_IF_ERROR (kernel_registry.Register (std::move (info)));
0 commit comments