Skip to content

Commit 365f7c5

Browse files
committed
Add instances
1 parent 554497c commit 365f7c5

7 files changed

+425
-1
lines changed

library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace instance {
1919

2020
using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>;
2121

22+
#if defined(CK_USE_XDL)
2223
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
2324
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
2425
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
@@ -27,6 +28,18 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
2728
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
2829
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
2930
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
31+
#endif // CK_USE_XDL
32+
33+
#if defined(CK_USE_WMMA)
34+
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
35+
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
36+
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
37+
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
38+
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
39+
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
40+
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
41+
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
42+
#endif // CK_USE_WMMA
3043

3144
template <typename ADataType,
3245
typename BDataType,
@@ -45,33 +58,61 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
4558
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
4659
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
4760
{
61+
#if defined(CK_USE_XDL)
4862
ck::tensor_operation::device::instance::
4963
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
5064
op_ptrs);
65+
#endif
66+
#if defined(CK_USE_WMMA)
67+
ck::tensor_operation::device::instance::
68+
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
69+
op_ptrs);
70+
#endif
5171
}
5272
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
5373
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
5474
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
5575
{
76+
#if defined(CK_USE_XDL)
5677
ck::tensor_operation::device::instance::
5778
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
5879
op_ptrs);
80+
#endif
81+
#if defined(CK_USE_WMMA)
82+
ck::tensor_operation::device::instance::
83+
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
84+
op_ptrs);
85+
#endif
5986
}
6087
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
6188
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
6289
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
6390
{
91+
#if defined(CK_USE_XDL)
6492
ck::tensor_operation::device::instance::
6593
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
6694
op_ptrs);
95+
#endif
96+
#if defined(CK_USE_WMMA)
97+
ck::tensor_operation::device::instance::
98+
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
99+
op_ptrs);
100+
#endif
67101
}
68102
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
69103
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
70104
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
71105
{
106+
#if defined(CK_USE_XDL)
72107
ck::tensor_operation::device::instance::
73108
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
74109
op_ptrs);
110+
#endif
111+
#if defined(CK_USE_WMMA)
112+
ck::tensor_operation::device::instance::
113+
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
114+
op_ptrs);
115+
#endif
75116
}
76117
}
77118

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
# ONLY XDL_KERNELS
1+
# ONLY XDL_AND_WMMA_KERNELS
22
add_instance_library(device_gemm_bias_add_reduce_instance
33
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
44
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
55
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
66
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
7+
8+
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
9+
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
10+
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
11+
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
712
)

0 commit comments

Comments
 (0)