@@ -19,6 +19,7 @@ namespace instance {
1919
2020using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1 , 2 >;
2121
22+ #if defined(CK_USE_XDL)
2223void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances (
2324 std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
2425void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances (
@@ -27,6 +28,18 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
2728 std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
2829void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances (
2930 std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
31+ #endif // CK_USE_XDL
32+
#if defined(CK_USE_WMMA)
// Forward declarations of the WMMA CShuffle instance factories; they are only
// visible when CK_USE_WMMA is defined. Each one takes the caller's vector of
// DeviceGemmAddAddMeanSquareMeanPtr by non-const reference (presumably to
// append its instances — the definitions are not part of this header).
//
// The name suffix identifies the layout combination, matching the dispatch in
// get_device_gemm_add_add_mean_squaremean_instances():
//   mk_kn_mn : A RowMajor,    B RowMajor,    C RowMajor
//   mk_nk_mn : A RowMajor,    B ColumnMajor, C RowMajor
//   km_kn_mn : A ColumnMajor, B RowMajor,    C RowMajor
//   km_nk_mn : A ColumnMajor, B ColumnMajor, C RowMajor
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
    std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
#endif // CK_USE_WMMA
3043
3144template <typename ADataType,
3245 typename BDataType,
@@ -45,33 +58,61 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
4558 is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
4659 is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
4760 {
61+ #if defined(CK_USE_XDL)
4862 ck::tensor_operation::device::instance::
4963 add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances (
5064 op_ptrs);
65+ #endif
66+ #if defined(CK_USE_WMMA)
67+ ck::tensor_operation::device::instance::
68+ add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances (
69+ op_ptrs);
70+ #endif
5171 }
5272 else if constexpr (is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
5373 is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
5474 is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
5575 {
76+ #if defined(CK_USE_XDL)
5677 ck::tensor_operation::device::instance::
5778 add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances (
5879 op_ptrs);
80+ #endif
81+ #if defined(CK_USE_WMMA)
82+ ck::tensor_operation::device::instance::
83+ add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances (
84+ op_ptrs);
85+ #endif
5986 }
6087 else if constexpr (is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
6188 is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
6289 is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
6390 {
91+ #if defined(CK_USE_XDL)
6492 ck::tensor_operation::device::instance::
6593 add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances (
6694 op_ptrs);
95+ #endif
96+ #if defined(CK_USE_WMMA)
97+ ck::tensor_operation::device::instance::
98+ add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances (
99+ op_ptrs);
100+ #endif
67101 }
68102 else if constexpr (is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
69103 is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
70104 is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
71105 {
106+ #if defined(CK_USE_XDL)
72107 ck::tensor_operation::device::instance::
73108 add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances (
74109 op_ptrs);
110+ #endif
111+ #if defined(CK_USE_WMMA)
112+ ck::tensor_operation::device::instance::
113+ add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances (
114+ op_ptrs);
115+ #endif
75116 }
76117 }
77118
0 commit comments