diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp
index 733c7568..f85d5fc7 100755
--- a/driver/conv_driver.cpp
+++ b/driver/conv_driver.cpp
@@ -559,7 +559,8 @@ int main(int argc, char **argv) {
         // gen rand
         gen_rand_vector(host_output, static_cast<size_t>(n) * k * ho * wo, 0.0, 1.0);
         gen_rand_vector(host_weight, static_cast<size_t>(k) * c * y * x, -0.5, 0.5);
-        gen_rand_vector(host_input, static_cast<size_t>(n) * c * hi * wi, 999999., 9999999.); // manually input value to a very large number
+        //gen_rand_vector(host_input, static_cast<size_t>(n) * c * hi * wi, 999999., 9999999.); // manually input value to a very large number
+        memset(host_input, 0xfc, static_cast<size_t>(n) * c * hi * wi * sizeof(float)); // fill input with a large-magnitude bit pattern
         // gen_rand_vector(host_output, static_cast<size_t>(n) * k * ho * wo, 1, 1);
         // gen_rand_vector(host_weight, static_cast<size_t>(k) * c * y * x, 1, 1);
 #ifdef USE_GPU_NAIVE_CONV
diff --git a/driver/igemm_bwd_gtc_driver.h b/driver/igemm_bwd_gtc_driver.h
index 07faefd8..5be1ba83 100755
--- a/driver/igemm_bwd_gtc_driver.h
+++ b/driver/igemm_bwd_gtc_driver.h
@@ -267,6 +267,13 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunable->precision);
+
+        assert(c % group == 0 && k % group == 0);
+
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        assert(splits != 0);
+        n = n/splits; // split batch size here
 
         int gemm_m_per_block = tunable->gemm_m_per_block;
         int gemm_n_per_block = tunable->gemm_n_per_block;
@@ -340,9 +347,17 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunable->precision);
 
         assert(c % group == 0 && k % group == 0);
 
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        if(splits == 0){
+            printf("image size (c*h*w or k*h*w) is bigger than 4g, which is not supported now\n");
+            return false;
+        }
+        n = n/splits; // split batch size here
+
         int gemm_m_per_block = tunable->gemm_m_per_block;
         int gemm_n_per_block = tunable->gemm_n_per_block;
         int gemm_k_per_block = tunable->gemm_k_per_block;
@@ -456,9 +471,14 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunable->precision);
 
         assert(c % group == 0 && k % group == 0);
 
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        assert(splits != 0);
+        n = n/splits; // split batch size here
+
         int gemm_m_per_block = tunable->gemm_m_per_block;
         int gemm_n_per_block = tunable->gemm_n_per_block;
         int gemm_k_per_block = tunable->gemm_k_per_block;
@@ -622,7 +642,7 @@
             hipEvent_t stop;
             hipEventCreate(&start);
             hipEventCreate(&stop);
-            HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, 1, 1,
+            HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, splits, 1,
                                         u_block_size, 1, 1, 0, 0, NULL, (void **)&config, start, stop));
             hipEventSynchronize(stop);
@@ -670,7 +690,7 @@
             hipEventCreate(&start);
             hipEventCreate(&stop);
             // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem
-            HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1,
+            HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1,
                                         block_size, 1, 1, 0, 0, NULL, (void **)&config, start, stop));
             hipEventSynchronize(stop);
@@ -680,7 +700,7 @@
 #else
             gpu_timer_t timer(NULL);
             timer.start();
-            HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1,
+            HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1,
                                         block_size, 1, 1, 0, 0, NULL, (void **)&config));
             timer.stop();
@@ -703,7 +723,7 @@
             hipEvent_t stop;
             hipEventCreate(&start);
             hipEventCreate(&stop);
-            HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, 1, 1,
+            HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, splits, 1,
                                         u_block_size, 1, 1, 0, 0, NULL, (void **)&config, start, stop));
             hipEventSynchronize(stop);
@@ -732,7 +752,7 @@
             hipEventCreate(&start);
             hipEventCreate(&stop);
             // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem
-            HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1,
+            HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1,
                                         block_size, 1, 1, 0, 0, NULL, (void **)&config, start, stop));
             hipEventSynchronize(stop);
@@ -742,7 +762,7 @@
 #else
             gpu_timer_t timer(NULL);
             timer.start();
-            HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1,
+            HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1,
                                         block_size, 1, 1, 0, 0, NULL, (void **)&config));
             timer.stop();
diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h
index 4ce61352..56ff274b 100755
--- a/driver/igemm_fwd_gtc_driver.h
+++ b/driver/igemm_fwd_gtc_driver.h
@@ -141,8 +141,10 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunable->precision);
 
-        int splits = split_batch_size(arg, tunable);
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        assert(splits != 0);
         n = n/splits; // split batch size here
 
         int gemm_m_per_block = tunable->gemm_m_per_block;
@@ -160,54 +162,6 @@
         return grid_size;
     }
 
-    // this is to support big tensor > 4G. need to decide how many splits needed
-    // return the number of splits
-    int split_batch_size(const args_t *arg, const igemm_gtc_tunable_t *tunable)
-    {
-        int hi = arg->get_int("in_h");
-        int wi = arg->get_int("in_w");
-        int n = arg->get_int("batchsize");
-        int k = arg->get_int("out_channels");
-        int c = arg->get_int("in_channels");
-
-        int stride_h = arg->get_int("conv_stride_h");
-        int stride_w = arg->get_int("conv_stride_w");
-        int dilation_h = arg->get_int("dilation_h");
-        int dilation_w = arg->get_int("dilation_w");
-        int pad_h = arg->get_int("pad_h");
-        int pad_w = arg->get_int("pad_w");
-        int y = arg->get_int("fil_h");
-        int x = arg->get_int("fil_w");
-        int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
-        int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
-
-        int data_byte = utility_string_to_data_byte(tunable->precision);
-        size_t image_size_input = static_cast<size_t>(c) * hi * wi * data_byte;
-        size_t image_size_output = static_cast<size_t>(k) * ho * wo * data_byte;
-        size_t size_4g = 0xffffffffUL;
-        if(image_size_input >= size_4g || image_size_output >= size_4g)
-            return 0;
-
-        size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output;
-        size_t splited_n = size_4g / image_size;
-
-        // round up splits, we must match
-        // 1. splited_n * image_size < size_4g
-        // 2. n % splited_n == 0
-        // if(splited_n >= n)
-        //     return 1;
-        assert(splited_n != 0);
-        while(splited_n >= 1){
-            // printf("n:%d, splited_n:%d\n", n, splited_n);
-            if(n % splited_n == 0)
-                break;
-            splited_n--;
-        }
-
-        assert(splited_n * image_size < size_4g && n % splited_n == 0);
-        return n / splited_n;
-    }
-
     bool tunable_is_valid(const args_t *arg,
                           const igemm_gtc_tunable_t *tunable)
     {
@@ -228,15 +182,17 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunable->precision);
 
         assert(c % group == 0 && k % group == 0);
 
-        int splits = split_batch_size(arg, tunable);
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
         if(splits == 0){
-            printf("image size (c*h*w) is bigger than 4g, which is not supported now\n");
+            printf("image size (c*h*w or k*h*w) is bigger than 4g, which is not supported now\n");
             return false;
         }
         n = n/splits; // split batch size here
+
 
         int gemm_m_per_block = tunable->gemm_m_per_block;
         int gemm_n_per_block = tunable->gemm_n_per_block;
@@ -328,10 +284,12 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunable->precision);
 
         assert(c % group == 0 && k % group == 0);
 
-        int splits = split_batch_size(arg, tunable);
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        assert(splits != 0);
         n = n/splits; // split batch size here
 
         int gemm_m_per_block = tunable->gemm_m_per_block;
diff --git a/driver/igemm_gtc_base.h b/driver/igemm_gtc_base.h
index e8215f01..89b336bb 100755
--- a/driver/igemm_gtc_base.h
+++ b/driver/igemm_gtc_base.h
@@ -310,4 +310,35 @@ igemm_gtc_encode_kernel_name(const igemm_gtc_tunable_t *tunable) {
     return kernel_name;
 }
 
+// this is to support big tensor > 4G. need to decide how many splits needed
+// return the number of splits, valid for nchw, nhwc, 2d/3d conv
+static inline int igemm_split_batch_size(int n, int wi, int hi, int di, int c, int k, int wo, int ho, int do_, int data_byte)
+{
+    size_t image_size_input = static_cast<size_t>(c) * di * hi * wi * data_byte;
+    size_t image_size_output = static_cast<size_t>(k) * do_ * ho * wo * data_byte;
+    size_t size_4g = 0xffffffffUL;
+    if(image_size_input >= size_4g || image_size_output >= size_4g)
+        return 0;
+
+    size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output;
+    size_t splited_n = size_4g / image_size;
+
+    // round up splits, we must match
+    // 1. splited_n * image_size < size_4g
+    // 2. n % splited_n == 0
+    assert(splited_n != 0);
+
+    if(splited_n >= n)
+        return 1; // speed up following while loop
+
+    while(splited_n >= 1){
+        if(n % splited_n == 0)
+            break;
+        splited_n--;
+    }
+
+    assert(splited_n * image_size < size_4g && n % splited_n == 0);
+    return n / splited_n;
+}
+
 #endif
\ No newline at end of file
diff --git a/driver/igemm_wrw_gtc_driver.h b/driver/igemm_wrw_gtc_driver.h
index e63efae4..7c9e23b8 100644
--- a/driver/igemm_wrw_gtc_driver.h
+++ b/driver/igemm_wrw_gtc_driver.h
@@ -196,6 +196,11 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunable->precision);
+
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        assert(splits != 0);
+        n = n/splits; // split batch size here
 
         int gemm_m_per_block = tunable->gemm_m_per_block;
         int gemm_n_per_block = tunable->gemm_n_per_block;
@@ -245,8 +250,16 @@
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
         int b = tunable->nxe == 0 ? (ho * wo) : ((ho * wo + tunable->nxb - 1) / tunable->nxb) * tunable->nxb; // pad to nxb modulo when nxe != 0
+        int data_byte = utility_string_to_data_byte(tunable->precision);
 
         assert(c % group == 0 && k % group == 0);
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        if(splits == 0){
+            printf("image size (c*h*w or k*h*w) is bigger than 4g, which is not supported now\n");
+            return false;
+        }
+        n = n/splits; // split batch size here
+
         int gemm_m_per_block = tunable->gemm_m_per_block;
         int gemm_n_per_block = tunable->gemm_n_per_block;
         int gemm_k_per_block = tunable->gemm_k_per_block;
@@ -266,6 +279,12 @@
         int nxe = tunable->nxe == 0 ? 1 : tunable->nxe;
         bool unit_conv = (x==1)&&(y==1)&&(stride_h==1)&&(stride_w==1)&&(dilation_h==1)&&(dilation_w==1)&&(pad_h==0)&&(pad_w==0);
 
+        if(splits > 1 && gemm_k_global_split == 0)
+        {
+            // a large tensor can only be used with the gkgs kernel
+            return false;
+        }
+
         if(((c / group) % (gemm_n_per_block / nxe) != 0) || (((x * y) % nxe) != 0))
         {
             return false;
@@ -339,8 +358,13 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunable->precision);
 
         assert(c % group == 0 && k % group == 0);
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        assert(splits != 0);
+        n = n/splits; // split batch size here
+
         int b = tunable->nxe == 0 ? (ho * wo) : ((ho * wo + tunable->nxb - 1) / tunable->nxb) * tunable->nxb;
 
         int gemm_m_per_block = tunable->gemm_m_per_block;
@@ -392,7 +416,8 @@
     static int if_gemm_k_global_split(const args_t *arg,
                                       const int gemm_m_per_block,
                                       const int gemm_n_per_block,
-                                      const int gemm_k_per_block)
+                                      const int gemm_k_per_block,
+                                      const int data_byte)
     {
         int gemm_k_global_split = 0;
         int hi = arg->get_int("in_h");
@@ -414,6 +439,10 @@
         int group = arg->get_int("group_count");
         assert(c % group == 0 && k % group == 0);
 
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        assert(splits != 0);
+        n = n/splits; // split batch size here
+
         // int b = tunable->nxe == 0 ? (ho * wo) : ((ho * wo + tunable->nxb - 1) / tunable->nxb) * tunable->nxb;
         int gemm_m = k / group;
         int gemm_n = (c / group) * y * x;
@@ -472,8 +501,13 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunables[0].precision);
 
         assert(c % group == 0 && k % group == 0);
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        assert(splits != 0);
+        n = n/splits; // split batch size here
+
         int gemm_m_per_block = 0;
         int gemm_n_per_block = 0;
         int gemm_k_per_block = 0;
@@ -541,7 +575,8 @@
             gemm_k_global_split = if_gemm_k_global_split(arg,
                                                          gemm_m_per_block,
                                                          gemm_n_per_block,
-                                                         gemm_k_per_block);
+                                                         gemm_k_per_block,
+                                                         data_byte);
 
             nxb = 1;
             nxe = 1;
@@ -660,8 +695,13 @@
         int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
         int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
         int group = arg->get_int("group_count");
+        int data_byte = utility_string_to_data_byte(tunable->precision);
 
         assert(c % group == 0 && k % group == 0);
+        int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
+        assert(splits != 0);
+        n = n/splits; // split batch size here
+
         int gemm_m_per_block = tunable->gemm_m_per_block;
         int gemm_n_per_block = tunable->gemm_n_per_block;
         int gemm_k_per_block = tunable->gemm_k_per_block;
@@ -729,7 +769,7 @@
             hipEventCreate(&stop);
             // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem
-            HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1,
+            HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1,
                                         block_size, 1, 1, 0, 0, NULL,
                                         (void **)&config, start, stop));
@@ -741,7 +781,7 @@
 #else
             gpu_timer_t timer(NULL);
             timer.start();
-            HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1,
+            HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1,
                                         block_size, 1, 1, 0, 0, NULL,
                                         (void **)&config));
diff --git a/igemm/algo/igemm_bwd_gtc.py b/igemm/algo/igemm_bwd_gtc.py
index cbdc282e..d8f079a9 100755
--- a/igemm/algo/igemm_bwd_gtc.py
+++ b/igemm/algo/igemm_bwd_gtc.py
@@ -695,6 +695,7 @@ def __init__(self, mc, outer):
 
             self.s_ka                      = sym_t("s_ka"                     ,0)
             self.s_bx                      = sym_t("s_bx"                     ,2)
+            self.s_by                      = sym_t("s_by"                     ,3)
             self.s_p_in                    = sym_t("s_p_in"                   ,4)
             self.s_p_wei                   = sym_t("s_p_wei"                  ,8)
             self.s_p_out                   = sym_t("s_p_out"                  ,12)
@@ -814,7 +815,7 @@ def __init__(self, mc, outer):
             self.s_magic_1                 = sym_t("s_magic_1"                ,m0_num + 1)
             self.s_magic_2                 = sym_t("s_magic_2"                ,self.s_p_out.value + 2)
             self.s_magic_3                 = sym_t("s_magic_3"                ,self.s_p_out.value + 3)
-            self.s_magic_4                 = sym_t("s_magic_4"                ,3)
+            self.s_magic_4                 = sym_t("s_magic_4"                ,self.s_block_gtc_in1b.value)
             self.s_magic_5                 = sym_t("s_magic_5"                ,self.s_p_wei.value + 2)
             self.s_magic_6                 = sym_t("s_magic_6"                ,self.s_p_wei.value + 3)
             self.s_shift_pack_0            = sym_t("s_shift_pack_0"           ,self.s_p_in.value + 2)
@@ -1271,6 +1272,7 @@ def get_kernel_code(self):
         kernel_code = amdgpu_kernel_code_t({
                 'enable_sgpr_kernarg_segment_ptr'   :   1,
                 'enable_sgpr_workgroup_id_x'        :   1,
+                'enable_sgpr_workgroup_id_y'        :   1,
                 'enable_vgpr_workitem_id'           :   0,
                 'workgroup_group_segment_byte_size' :   self.tunable.lds_total,
                 'kernarg_segment_byte_size'         :   self.karg.get_count(),
@@ -1604,6 +1606,22 @@ def emit_kernel_prologue(self):
         else:
             self._emit(f"s_lshr_b32 s[{s.s_wei_stride_c0()}], s[{s.s_c()}], {igemm_log2(n_c0)}")
 
+        # calculate batch split and accumulate the base pointer for input/output
+        self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_n()}], s[{s.s_in_stride_n()}]")
+        self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_n()}], s[{s.s_out_stride_n()}]")
+        self._emit(f"s_lshl_b32 s[{s.s_tmp(4)}], s[{s.s_tmp(0)}], {igemm_log2(data_byte)}")
+        self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_tmp(1)}], {igemm_log2(data_byte)}")
+
+        self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(4)}]")
+        self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(4)}]")
+        self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]")
+        self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]")
+
+        self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(5)}]")
+        self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(5)}]")
+        self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]")
+        self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]")
+
         self._emit(f"; k1e transform")
         if self.tunable.nxe != 0:
             if c_k1e == 1:
diff --git a/igemm/algo/igemm_wrw_gtc.py b/igemm/algo/igemm_wrw_gtc.py
index c64030d4..f19d812c 100644
--- a/igemm/algo/igemm_wrw_gtc.py
+++ b/igemm/algo/igemm_wrw_gtc.py
@@ -682,6 +682,7 @@ def __init__(self, mc, outer):
 
             self.s_ka                      = sym_t("s_ka"                ,0)
             self.s_bx                      = sym_t("s_bx"                ,2)
+            self.s_by                      = sym_t("s_by"                ,3)
             self.s_p_in                    = sym_t("s_p_in"              ,4)
             self.s_p_wei                   = sym_t("s_p_wei"             ,8)
             self.s_p_out                   = sym_t("s_p_out"             ,12)
@@ -1242,6 +1243,7 @@ def get_kernel_code(self):
         kernel_code = amdgpu_kernel_code_t({
                 'enable_sgpr_kernarg_segment_ptr'   :   1,
                 'enable_sgpr_workgroup_id_x'        :   1,
+                'enable_sgpr_workgroup_id_y'        :   1,
                 'enable_vgpr_workitem_id'           :   0,
                 'workgroup_group_segment_byte_size' :   self.tunable.lds_total,
                 'kernarg_segment_byte_size'         :   self.karg.get_count(),
@@ -1474,6 +1476,23 @@ def emit_kernel_prologue(self):
         self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.nxb)}")
         self._emit(f"s_lshl_b32 s[{s.s_dim_b()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.nxb)}")
 
+        if self.tunable.gemm_k_global_split != 0:
+            # calculate batch split and accumulate the base pointer for input/output
+            self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_n()}], s[{s.s_in_stride_n()}]")
+            self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_n()}], s[{s.s_out_stride_n()}]")
+            self._emit(f"s_lshl_b32 s[{s.s_tmp(4)}], s[{s.s_tmp(0)}], {igemm_log2(data_byte)}")
+            self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_tmp(1)}], {igemm_log2(data_byte)}")
+
+            self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(4)}]")
+            self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(4)}]")
+            self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]")
+            self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]")
+
+            self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(5)}]")
+            self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(5)}]")
+            self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]")
+            self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]")
+
         self._emit(f"; n1b transform")
         self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_gtc_in1(), v.v_gtc_in1b(), s.s_dim_b() if self.tunable.nxe != 0 else s.s_out_stride_k(), v.v_tmp(), s.s_tmp()))
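---

Note on the new igemm_split_batch_size() (illustrative sketch, not part of the patch): each gridDim.y slice must keep the larger of its per-slice input/output tensors below 4 GiB, presumably because in-kernel byte offsets are computed in 32 bits. The helper therefore takes the largest slice batch splited_n that fits under the limit, shrinks it until it divides n evenly, and returns splits = n / splited_n. The standalone program below reproduces that arithmetic; main() and the sample sizes are made up for illustration:

    // The split rule from driver/igemm_gtc_base.h, reduced to a standalone program.
    #include <cstddef>
    #include <cstdio>

    static int igemm_split_batch_size(int n, int wi, int hi, int di, int c, int k,
                                      int wo, int ho, int do_, int data_byte)
    {
        size_t image_size_input  = (size_t)c * di * hi * wi * data_byte;
        size_t image_size_output = (size_t)k * do_ * ho * wo * data_byte;
        size_t size_4g = 0xffffffffUL;
        if(image_size_input >= size_4g || image_size_output >= size_4g)
            return 0;                               // a single image already exceeds 4G: unsupported

        size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output;
        size_t splited_n  = size_4g / image_size;   // most images one slice can hold
        if(splited_n >= (size_t)n)
            return 1;                               // whole batch fits in one slice
        while(splited_n >= 1 && (size_t)n % splited_n != 0)
            splited_n--;                            // shrink until it divides n evenly
        return n / (int)splited_n;
    }

    int main()
    {
        // fp32 NCHW, n=128, c=k=64, 1024x1024 images: one image is 256 MiB,
        // 15 images fit under 4 GiB, and the largest divisor of 128 that is
        // <= 15 is 8, so the batch is cut into 128/8 = 16 slices (gridDim.y).
        printf("splits = %d\n", igemm_split_batch_size(128, 1024, 1024, 1, 64, 64, 1024, 1024, 1, 4)); // prints 16
        return 0;
    }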
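Note on the assembly prologue changes (illustrative sketch, not part of the patch): with splits mapped to gridDim.y, each workgroup reads s_by and rebases s_p_in/s_p_out by by * (n * stride_n << log2(data_byte)) as a 64-bit quantity, using s_mul_i32/s_mul_hi_u32 for the low/high halves of the product and s_add_u32/s_addc_u32 for the carry-propagating add. In plain C++ the emitted sequence computes roughly the following; advance_base() and the sample values are illustrative:

    #include <cinttypes>
    #include <cstdint>
    #include <cstdio>

    uint64_t advance_base(uint64_t base, uint32_t by, uint32_t slice_bytes)
    {
        uint32_t lo      = by * slice_bytes;                               // s_mul_i32    : low 32 bits of product
        uint32_t hi      = (uint32_t)(((uint64_t)by * slice_bytes) >> 32); // s_mul_hi_u32 : high 32 bits of product
        uint32_t base_lo = (uint32_t)base + lo;                            // s_add_u32    : sets carry (SCC) on overflow
        uint32_t carry   = base_lo < lo;                                   // SCC
        uint32_t base_hi = (uint32_t)(base >> 32) + hi + carry;            // s_addc_u32   : add with carry-in
        return ((uint64_t)base_hi << 32) | base_lo;
    }

    int main()
    {
        // slice base crosses the 4 GiB boundary, so the carry path matters:
        // 0xF0000000 + 1 * 0x20000000 = 0x1_10000000
        printf("0x%" PRIx64 "\n", advance_base(0xF0000000ULL, 1, 0x20000000U));
        return 0;
    }

The carry add is what lets a slice's base address cross a 4 GiB boundary even though each individual slice stays below 4 GiB.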