1717#include " ck/library/utility/convolution_parameter.hpp"
1818#include " ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
1919#include " ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
20+ #include " ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp"
21+ #include " ck_tile/host/hip_check_error.hpp"
2022
2123using ::ck::DeviceMem;
2224using ::ck::HostTensorDescriptor;
2325using ::ck::Tensor;
2426
27+ template <typename DataType, typename GemmType = DataType>
28+ inline __host__ __device__ constexpr double get_rtol ()
29+ {
30+ if constexpr (std::is_same_v<DataType, float > && std::is_same_v<GemmType, ck::tf32_t >)
31+ return 5e-3 ;
32+ else if constexpr (std::is_same_v<DataType, float >)
33+ return 1e-3 ;
34+ else if constexpr (std::is_same_v<DataType, double >)
35+ return 1e-6 ;
36+ else if constexpr (std::is_same_v<DataType, ck::half_t >)
37+ return 1e-3 ;
38+ else if constexpr (std::is_same_v<DataType, ck::bhalf_t >)
39+ return 5e-2 ;
40+ else if constexpr (std::is_same_v<DataType, ck::f8_t >)
41+ return 1e-1 ;
42+ else if constexpr (std::is_same_v<DataType, ck::bf8_t >)
43+ return 1.5e-1 ;
44+ else
45+ return 1e-3 ;
46+ }
47+
48+ template <typename DataType, typename GemmType = DataType>
49+ inline __host__ __device__ constexpr double get_atol ()
50+ {
51+ if constexpr (std::is_same_v<DataType, float > && std::is_same_v<GemmType, ck::tf32_t >)
52+ return 1e-3 ;
53+ else if constexpr (std::is_same_v<DataType, float >)
54+ return 1e-3 ;
55+ else if constexpr (std::is_same_v<DataType, double >)
56+ return 1e-6 ;
57+ else if constexpr (std::is_same_v<DataType, ck::half_t >)
58+ return 1e-3 ;
59+ else if constexpr (std::is_same_v<DataType, ck::bhalf_t >)
60+ return 5e-2 ;
61+ else if constexpr (std::is_same_v<DataType, ck::f8_t >)
62+ return 16.1 ;
63+ else if constexpr (std::is_same_v<DataType, ck::bf8_t >)
64+ return 16.1 ;
65+ else
66+ return 1e-3 ;
67+ }
68+
2569void print_helper_msg ()
2670{
27- std::cout << " arg1: verification (0=no, 1=yes )\n "
71+ std::cout << " arg1: verification (0=no, 1=CPU, 2=GPU )\n "
2872 << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n "
2973 << " arg3: time kernel (0=no, 1=yes)\n "
3074 << ck::utils::conv::get_conv_param_parser_helper_msg () << std::endl;
@@ -38,7 +82,7 @@ template <ck::index_t NDimSpatial,
3882 typename WeiElementOp,
3983 typename OutElementOp,
4084 typename DeviceConvNdBwdDataInstance>
41- int run_conv_bwd_data (bool do_verification,
85+ int run_conv_bwd_data (int do_verification,
4286 int init_method,
4387 bool time_kernel,
4488 const ck::utils::conv::ConvParam& conv_param,
@@ -128,26 +172,30 @@ int run_conv_bwd_data(bool do_verification,
128172 wei_element_op,
129173 out_element_op);
130174
175+ // Check if optimized kernel supports these parameters
131176 if (!conv.IsSupportedArgument (argument.get ()))
132177 {
133178 std::cout << " Not support,please check parameters or device" ;
134179 return 0 ;
135180 }
136181
182+ // Run optimized kernel
137183 float ave_time = invoker.Run (argument.get (), StreamConfig{nullptr , time_kernel});
138184
139185 std::size_t flop = conv_param.GetFlops ();
140186 std::size_t num_btype = conv_param.GetByte <InDataType, WeiDataType, OutDataType>();
141187
142- float tflops = static_cast <float >(flop) / 1 .E9 / ave_time;
143-
188+ float tflops = static_cast <float >(flop) / 1 .E9 / ave_time;
144189 float gb_per_sec = num_btype / 1 .E6 / ave_time;
145190
146191 std::cout << " Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
147192 << std::endl;
148193
149- if (do_verification)
194+ std::cout << " do_verification = " << do_verification << std::endl;
195+
196+ if (do_verification == 1 )
150197 {
198+ // CPU verification
151199 auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
152200 InDataType,
153201 WeiDataType,
@@ -175,6 +223,56 @@ int run_conv_bwd_data(bool do_verification,
175223
176224 return ck::utils::check_err (in_device, in_host) ? 0 : 1 ;
177225 }
226+ else if (do_verification == 2 )
227+ {
228+ // GPU verification
229+ std::cout << " Running GPU verification..." << std::endl;
230+
231+ DeviceMem in_device_ref_buf (sizeof (InDataType) * in_device.mDesc .GetElementSpaceSize ());
232+ in_device_ref_buf.SetZero ();
233+
234+ // Extract dimensions using helper function
235+ ck::ref::ConvDims dims = ck::utils::conv::extract_conv_dims (conv_param, NDimSpatial);
236+
237+ constexpr ck::index_t block_size = 256 ;
238+ const ck::long_index_t input_length = dims.N * dims.Di * dims.Hi * dims.Wi * dims.C ;
239+ const ck::index_t grid_size = (input_length + block_size - 1 ) / block_size;
240+
241+ auto gpu_ref_kernel = ck::ref::naive_conv_bwd_data_ndhwc_kzyxc_ndhwk<InDataType,
242+ WeiDataType,
243+ OutDataType,
244+ float ,
245+ InElementOp,
246+ WeiElementOp,
247+ OutElementOp>;
248+
249+ gpu_ref_kernel<<<dim3 (grid_size), dim3 (block_size), 0 , nullptr >>>(
250+ reinterpret_cast <InDataType*>(in_device_ref_buf.GetDeviceBuffer ()),
251+ reinterpret_cast <const WeiDataType*>(wei_device_buf.GetDeviceBuffer ()),
252+ reinterpret_cast <const OutDataType*>(out_device_buf.GetDeviceBuffer ()),
253+ dims);
254+
255+ HIP_CHECK_ERROR (hipDeviceSynchronize ());
256+
257+ std::cout << " GPU reference kernel completed, copying results..." << std::endl;
258+
259+ // Copy GPU reference result
260+ Tensor<InDataType> in_gpu_ref (in_host.mDesc );
261+ in_device_ref_buf.FromDevice (in_gpu_ref.mData .data ());
262+
263+ // Copy optimized kernel result
264+ in_device_buf.FromDevice (in_device.mData .data ());
265+
266+ // Compare: Optimized kernel result vs GPU reference result
267+ bool pass = ck::utils::check_err (in_device,
268+ in_gpu_ref,
269+ " Error: Incorrect results!" ,
270+ get_rtol<InDataType, float >(),
271+ get_atol<InDataType, float >());
272+ std::cout << " GPU verification result is:" << (pass ? " correct" : " fail" ) << std::endl;
273+
274+ return pass ? 0 : 1 ;
275+ }
178276
179277 return 0 ;
180278}
0 commit comments