diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config new file mode 100644 index 00000000..bfb639e1 --- /dev/null +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -0,0 +1,927 @@ +[codegen] +arch = 'gfx908' +code_object = 'cov3' +mode = 'flat' + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 + + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 2 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 2 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 +gemm_k_global_split = 1 + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 +gemm_k_global_split = 1 + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + + + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 8 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 2, 1, 128] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 2, 1, 128] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' 
+nxb = 0 +nxe = 0 + + +#--------------------------- 128x256 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 256 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 64 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 + +#--------------------------- 128x256 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 256 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 64 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 +gemm_k_global_split = 1 + +#--------------------------- 128x256 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 256 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 64 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 +gemm_k_global_split = 1 + + +#--------------------------- 128x128 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 + +#--------------------------- 128x128 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x128 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 +gemm_k_global_split = 1 + + +#--------------------------- 64x128 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 
+wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 1, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 32 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 32 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_pass_through = 1 +tensor_a_thread_lengths = [1, 8, 1, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 2, 4, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 +gemm_k_global_split = 1 + +# #--------------------------- 128x64 +# [igemm_fwd_gtc] +# gemm_m_per_block = 128 +# gemm_n_per_block = 64 +# gemm_k_per_block = 32 +# wave_tile_m = 32 +# wave_step_m = 1 +# wave_repeat_m = 1 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# tensor_a_pass_through = 1 +# tensor_a_thread_lengths = [1,16, 1, 1] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 2, 4, 32] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 + + +# #--------------------------- 128x64 +# [igemm_fwd_gtc] +# gemm_m_per_block = 128 +# gemm_n_per_block = 64 +# gemm_k_per_block = 16 +# wave_tile_m = 32 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# 
tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 32] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 32] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 32 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 8, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 2 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 8, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 16] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 256x64 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 64 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + + + +#--------------------------- 128x32 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 32 +gemm_k_per_block = 16 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 2, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x32 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 32 +gemm_k_per_block = 32 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x32 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 32 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 32] # ExCxK0xK1 +direction = "fwd" +precision 
= "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x32 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 32 +gemm_k_per_block = 32 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 8, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + + +#--------------------------- 64x64 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 64 +gemm_k_per_block = 32 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 64x16 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 16 +gemm_k_per_block = 16 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 2, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 64x16 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 16 +gemm_k_per_block = 16 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 2, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 + +#--------------------------- 64x16 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 16 +gemm_k_per_block = 8 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 2, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 64x16 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 16 +gemm_k_per_block = 4 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 1, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 1, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 32x16 +[igemm_fwd_gtc] +gemm_m_per_block = 32 +gemm_n_per_block = 16 +gemm_k_per_block = 16 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 16 +wave_step_n = 1 
+wave_repeat_n = 1 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 2, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 32x16 +[igemm_fwd_gtc] +gemm_m_per_block = 32 +gemm_n_per_block = 16 +gemm_k_per_block = 32 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +# #--------------------------- 32x16 +# [igemm_fwd_gtc] +# gemm_m_per_block = 32 +# gemm_n_per_block = 16 +# gemm_k_per_block = 8 +# wave_tile_m = 16 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 16 +# wave_step_n = 1 +# wave_repeat_n = 1 +# wave_tile_k = 1 +# tensor_a_thread_lengths = [1, 1, 2, 1] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 1, 1, 1] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 + + +#--------------------------- 64x4 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 4 +gemm_k_per_block = 16 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 4 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 1,16, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1,16, 1, 4] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 1, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1,16, 1, 4] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +# #---------------------------------------------------------------------------------- +# +# +# #--------------------------- 256x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 256 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 64 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# tensor_a_thread_lengths = [1, 4, 1, 4] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 0 +# +# +# #--------------------------- 256x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 256 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 32 +# wave_step_m = 2 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# tensor_a_thread_lengths = [1, 4, 1, 4] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 0 +# +# #--------------------------- 256x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 256 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 64 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# 
tensor_a_thread_lengths = [1, 4, 1, 4] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0xK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 +# +# +# +# #--------------------------- 256x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 256 +# gemm_n_per_block = 128 +# gemm_k_per_block = 8 +# wave_tile_m = 64 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# tensor_a_thread_lengths = [1, 4, 1, 2] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 2, 1, 128] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 2, 1, 128] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 0 +# +# #--------------------------- 128x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 128 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 32 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# tensor_a_thread_lengths = [1, 4, 1, 2] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 0 +# +# #--------------------------- 128x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 128 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 32 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# tensor_a_thread_lengths = [1, 4, 1, 2] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 +# \ No newline at end of file diff --git a/driver/args.h b/driver/args.h index e873af82..6114fe97 100644 --- a/driver/args.h +++ b/driver/args.h @@ -157,14 +157,37 @@ class args_t { std::unordered_map input_map; }; +static inline std::string create_base_args(int argc, char *argv[]) { + if(argc < 2) + { + printf("Invalid Number of Input Arguments\n"); + exit(0); + } + + std::string arg = argv[1]; + + if(arg == "-h" || arg == "--help" || arg == "-?") + exit(0); + + if(arg != "conv" && arg != "convfp16" && arg != "convint8" && arg != "--version") + { + printf("Invalid Base Input Argument\n"); + exit(0); + } + + return arg; +} + static inline args_t create_conv_args(int argc, char *argv[]) { - const std::string base("conv"); + const std::string base = create_base_args(argc, argv); if (argc >= 2 && argv[1] != base) { printf("not proper base arg name"); exit(1); } args_t args; + args.insert_arg("in_layout", 'I', "NCHW", "Input Layout (Default=NCHW)", "string"); + args.insert_arg("out_layout", 'O', "NCHW", "Output Layout (Default=NCHW)", "string"); + args.insert_arg("fil_layout", 'f', "NCHW", "Filter Layout (Default=NCHW)", "string"); args.insert_arg("spatial_dim", '_', "2", "convolution spatial dimension (Default-2)", "int"); args.insert_arg("forw", 'F', "0", "Flag enables fwd, bwd, wrw convolutions" diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index 733c7568..4f592ff3 100755 --- a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp
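A reference note for the driver changes below: conv_driver.cpp sizes all host/device buffers from the convolution output size, and the conv_out_size helper used throughout follows the standard output-size relation out = (in + 2*pad - dilation*(ksize - 1) - 1) / stride + 1. A minimal standalone sketch, assuming that usual definition; the name conv_out_size_ref and the self-checks are illustrative only, not part of this patch:

#include <cstddef>
#include <cassert>

// Standard convolution output-size relation (reference sketch).
static inline size_t conv_out_size_ref(size_t in_size, size_t pad, size_t dilation,
                                       size_t ksize, size_t stride) {
    return (in_size + 2 * pad - dilation * (ksize - 1) - 1) / stride + 1;
}

int main() {
    assert(conv_out_size_ref(14, 1, 1, 3, 1) == 14); // 3x3, pad=1, stride=1 preserves size
    assert(conv_out_size_ref(14, 1, 1, 3, 2) == 7);  // stride=2: (14 + 2 - 2 - 1)/2 + 1 = 7
    return 0;
}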
@@ -57,18 +57,8 @@ # define IGEMM_GPU_NAIVE_CONV_HSACO "naive_conv.hsaco" # endif #else -# ifdef USE_XDNN -# include "xdnn_conv.h" -# define conv_fwd_nchw xdnn_conv_fwd_nchw -# define conv_bwd_nchw xdnn_conv_bwd_nchw -# define conv_wrw_nchw xdnn_conv_wrw_nchw -# else -# define NAIVE_CONV_THREADED -# include "naive_conv.h" -# define conv_fwd_nchw naive_conv_fwd_nchw -# define conv_bwd_nchw naive_conv_bwd_nchw -# define conv_wrw_nchw naive_conv_wrw_nchw -# endif +# define NAIVE_CONV_THREADED +# include "naive_conv.h" #endif static inline size_t conv_out_size(size_t in_size, size_t pad, size_t dilation, @@ -115,10 +105,11 @@ static int next_pow2(int n) { return n << 1; } typedef struct { - int return_code; - float duration_ms; - float gflops; - float efficiency; + int return_code {-1}; + int gks {0}; // this is to store the gks value after benchmarked + float duration_ms {FLT_MAX}; + float gflops {0}; + float efficiency {0}; std::string kernel_name; } result_t; @@ -132,7 +123,12 @@ typedef struct { } \ } while (0) -static inline double theoritical_fp32_gflops(double sclk_ghz, size_t cu, +#include "igemm_gtc_base.h" +#include "igemm_fwd_gtc_driver.h" +#include "igemm_bwd_gtc_driver.h" +#include "igemm_wrw_gtc_driver.h" + +static inline double theoritical_gflops(double sclk_ghz, size_t cu, size_t simd) { return 2 * sclk_ghz * cu * simd; } @@ -158,10 +154,63 @@ measured_fp32_conv_gflops(double time_ms, size_t n, size_t c, size_t hi, return flop / (time_ms * 1e6); } -#include "igemm_gtc_base.h" -#include "igemm_fwd_gtc_driver.h" -#include "igemm_bwd_gtc_driver.h" -#include "igemm_wrw_gtc_driver.h" + +static inline double get_theoritical_conv_flop(const args_t * conv_args) +{ + int hi = conv_args->get_int("in_h"); + int wi = conv_args->get_int("in_w"); + int n = conv_args->get_int("batchsize"); + int k = conv_args->get_int("out_channels"); + int c = conv_args->get_int("in_channels"); + + int stride_h = conv_args->get_int("conv_stride_h"); + int stride_w = conv_args->get_int("conv_stride_w"); + int dilation_h = conv_args->get_int("dilation_h"); + int dilation_w = conv_args->get_int("dilation_w"); + int pad_h = conv_args->get_int("pad_h"); + int pad_w = conv_args->get_int("pad_w"); + int y = conv_args->get_int("fil_h"); + int x = conv_args->get_int("fil_w"); + int ngroups = conv_args->get_int("group_count"); + int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); + int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); + + return theoritical_fp32_conv_flop(n, c, hi, wi, k, y, x, stride_h, stride_w, + dilation_h, dilation_w, pad_h, pad_w, ngroups); +} + +static inline double get_theoritical_gpu_gflops(int sclk_mhz, driverDataType_t data_type) +{ + int num_cu; + int gcn_arch = 0; + int num_simd = 4 * 16; + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CALL(hipGetDevice(&dev)); + HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + gcn_arch = dev_prop.gcnArch; + if(gcn_arch >= 1000) + num_cu *= 2; + + int fp_factor = 1; + if(data_type == driverHalf){ + if(gcn_arch == 908) + fp_factor = 4; // xdlops + else + fp_factor = 2; // dlops + } + // else if(data_type == driverInt8){ + // if(gcn_arch == 908) + // fp_factor = 4; + // } + + if(gcn_arch == 908){ + num_simd = 4 * 32 ; // 4x miSIMD, 32x mac unit + } + + return theoritical_gflops(((double)sclk_mhz) / 1000.0, num_cu, num_simd * fp_factor); +} #ifndef ABS #define ABS(x) ((x) > 0 ? 
(x) : -1 * (x)) @@ -264,13 +313,15 @@ static inline bool valid_vector(const float *ref, const float *pred, size_t n, double s1 = 0.0; int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0); int igemm_per_pixel_check_print = env_get_int("PER_PIXEL_CHECK_PRINT", 1); + int igemm_valid_float = env_get_int("VALID_FLOAT", 1); size_t pp_err = 0; for (size_t i = 0; i < n; ++i) { - if(!(valid_float(ref[i]) && valid_float(pred[i]))){ - printf(" invalid float at %zu, ref:%f, pred:%f\n", i, ref[i], pred[i]); - return false; - } + if(igemm_valid_float) + if(!(valid_float(ref[i]) && valid_float(pred[i]))){ + printf(" invalid float at %zu, ref:%f, pred:%f\n", i, ref[i], pred[i]); + return false; + } double ri = (double)ref[i]; double pi = (double)pred[i]; double d = ri - pi; @@ -325,6 +376,17 @@ static inline double get_wrw_nrms() #endif } +static inline double get_nrms(std::string direction) +{ + if(direction == "fwd") + return get_fwd_nrms(); + if(direction == "bwd") + return get_bwd_nrms(); + if(direction == "wrw") + return get_wrw_nrms(); + assert(0); +} + void dump_arg(const args_t *arg) { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); @@ -349,16 +411,105 @@ void dump_arg(const args_t *arg) { pad_h, pad_w, ho, wo); } + +template <typename driver_t, typename pre_func_t, typename post_func_t> +void launch_conv_driver(driver_t * driver, const args_t *conv_args, const std::vector<igemm_gtc_tunable_t> & tunables, std::string direction, + void* device_input, void* device_weight, void* device_output, + pre_func_t && pre_func, post_func_t && post_func) +{ + int sclk_mhz = env_get_int("IGEMM_SCLK_MHZ", SCLK_MHZ); + std::string run_only_kernel = env_get_str("IGEMM_RUN_ONLY_KERNEL", IGEMM_RUN_ONLY_KERNEL_DEFAULT); + int log_fastest_config = env_get_int("IGEMM_LOG_FASTEST_CONFIG", 0); + + double theo_conv_flop = get_theoritical_conv_flop(conv_args); + double theo_gpu_gflops = get_theoritical_gpu_gflops(sclk_mhz, driver->data_type); + + auto launch = [&](const igemm_gtc_tunable_t * tunable, int index) -> result_t { + if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT){ + if(run_only_kernel != driver->get_kernel_name(tunable)){ + return result_t{}; + } + } + + printf("[%s:%2d] %s", direction.c_str(), index, driver->get_kernel_name(tunable).c_str()); + fflush(stdout); + + pre_func(); + + result_t result = driver->run(conv_args, tunable, device_input, device_weight, device_output); + + std::string gks_string = ""; + if(tunable->gemm_k_global_split){ + gks_string = "[" + std::to_string(result.gks) + "]"; + } + printf("%s, ", gks_string.c_str()); + + if (result.return_code != 0){ + printf("not applicable\n"); + return result_t{}; + } + + double gflops = theo_conv_flop / (result.duration_ms * 1e6); + printf("cost:%.3fms, tflops:%.3f(%.2f%%)", result.duration_ms, + gflops / 1000 , (gflops / theo_gpu_gflops) * 100); + + post_func(); + + printf("\n"); + result.gflops = gflops; + result.efficiency = (gflops / theo_gpu_gflops) * 100; + return result; + }; + + if(driver->driver_mode == driver_mode_normal){ + result_t fastest_result_fwd; + fastest_result_fwd.duration_ms = FLT_MAX; + int fastest_id = -1; + + for(int i=0; i<tunables.size(); i++){ + result_t result = launch(&tunables[i], i); + if(result.return_code == 0 && result.duration_ms < fastest_result_fwd.duration_ms){ + fastest_result_fwd = result; + fastest_id = i; + } + } + + if(log_fastest_config){ + dump_arg(conv_args); + if(fastest_id == -1) + printf(" fastest: no suitable kernel\n"); + else + printf(" fastest: [%d]%s, cost:%.3fms, tflops:%.3f(%.2f%%)\n", + fastest_id, + fastest_result_fwd.kernel_name.c_str(), + fastest_result_fwd.duration_ms, + fastest_result_fwd.gflops / 1000, + fastest_result_fwd.efficiency); + } + }else if(driver->driver_mode == driver_mode_heuristic){ + igemm_gtc_tunable_t selected_tunable = driver->heuristic_select_kernel(conv_args); + if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT) + if(run_only_kernel != driver->get_kernel_name(&selected_tunable)){ + printf("heuristic selected tunable does not match your request\n"); + return; + } + + result_t result = launch(&selected_tunable, 0); + }else{ + assert(0); + } +} + int main(int argc, char **argv) { char *hsaco =
env_get_str("IGEMM_HSACO", IGEMM_HSACO); char *config_file = env_get_str("IGEMM_CONFIG_FILE", IGEMM_CONFIG_FILE); std::string run_only_kernel = env_get_str("IGEMM_RUN_ONLY_KERNEL", IGEMM_RUN_ONLY_KERNEL_DEFAULT); int warmup = env_get_int("IGEMM_WARMUP", WARMUP); int repeat = env_get_int("IGEMM_REPEAT", REPEAT); - int sclk_mhz = env_get_int("IGEMM_SCLK_MHZ", SCLK_MHZ); - int log_fastest_config = env_get_int("IGEMM_LOG_FASTEST_CONFIG", 0); - int wrw_kernel_selection = env_get_int("IGEMM_LOG_SELECTED_CONFIG", 0); int assert_when_invalid = env_get_int("IGEMM_ASSERT_WHEN_INVALID", 0); + int verbose = env_get_int("IGEMM_VERBOSE", 0); + driver_mode_t driver_mode = static_cast(env_get_int("IGEMM_MODE", 0)); config_parser_t config_parser(config_file); auto content = config_parser.parse(); //content.dump(); @@ -373,13 +524,31 @@ int main(int argc, char **argv) { printf("no tunable specified, may not work\n"); return 0; } - // printf("tunables:%d\n", tunables.size()); + // printf("tunables:%d, hsaco:%s\n", tunables.size(), hsaco); hipModule_t module; HIP_CALL(hipModuleLoad(&module, hsaco)); + std::string base_arg = create_base_args(argc, argv); args_t conv_args = create_conv_args(argc, argv); // dump_arg(&conv_args); + driverDataType_t driver_data_type; + + if(base_arg == "conv"){ + driver_data_type = driverFloat; + } + else if(base_arg == "convfp16"){ + driver_data_type = driverHalf; + } + else if(base_arg == "convbf16") { + driver_data_type = driverBFloat16; + exit(0); + } + else if(base_arg == "convint8") { + driver_data_type = driverInt8; + } + else + exit(0); int hi = conv_args.get_int("in_h"); int wi = conv_args.get_int("in_w"); @@ -399,11 +568,19 @@ int main(int argc, char **argv) { int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); int forw = conv_args.get_int("forw"); + std::string in_layout = conv_args.get_str("in_layout"); + std::string out_layout = conv_args.get_str("out_layout"); + std::string fil_layout = conv_args.get_str("fil_layout"); int need_fwd = (forw == 0 ? 1 : (forw & 1 ? 1 : 0)); int need_bwd = (forw == 0 ? 1 : (forw & 2 ? 1 : 0)); int need_wrw = (forw == 0 ? 1 : (forw & 4 ? 
1 : 0)); + assert(in_layout == out_layout && in_layout == fil_layout); // currently only supported when all layouts are the same + assert(in_layout == "NCHW" || in_layout == "NHWC"); // currently only these layouts are supported + assert((in_layout == "NCHW" && tunables[0].tensor_layout == "nchw") || + (in_layout == "NHWC" && tunables[0].tensor_layout == "nhwc")); // check pairs + // init host side float *host_input = (float *)malloc(static_cast<size_t>(n) * c * hi * wi * sizeof(float)); float *host_weight = (float *)malloc(static_cast<size_t>(k) * c * y * x * sizeof(float)); @@ -419,37 +596,6 @@ int need_verify = conv_args.get_int("verify"); - // printf("fwd:%d, bwd:%d, wrw:%d, verify:%d\n",need_fwd, need_bwd, need_wrw, need_verify); - - int num_cu; - int num_simd = 64; // hard coded - int gcn_arch = 0; - { - hipDeviceProp_t dev_prop; - hipDevice_t dev; - HIP_CALL(hipGetDevice(&dev)); - HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); - num_cu = dev_prop.multiProcessorCount; - gcn_arch = dev_prop.gcnArch; -#if 0 -#define P_DEVICE_PROP_INT(prop) \ - printf(#prop":%d\n", dev_prop.prop) - - - P_DEVICE_PROP_INT(clockRate); - P_DEVICE_PROP_INT(memoryClockRate); - P_DEVICE_PROP_INT(memoryBusWidth); - P_DEVICE_PROP_INT(major); - P_DEVICE_PROP_INT(minor); - P_DEVICE_PROP_INT(gcnArch); -#endif - } - if(gcn_arch == 908){ - num_simd = 4 * 32 ; // 4x miSIMD, 32x mac unit - } - double fp32_gflops = - theoritical_fp32_gflops(((double)sclk_mhz) / 1000.0, num_cu, num_simd); - if (need_fwd){ result_t fastest_result_fwd; fastest_result_fwd.duration_ms = FLT_MAX; @@ -468,19 +614,34 @@ static_cast<size_t>(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice)); HIP_CALL(hipMemcpy(device_weight, host_weight, static_cast<size_t>(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); - - gpu_naive_conv_fwd_nchw_fp32(device_input, device_weight, device_output, + + if(in_layout == "NCHW") + gpu_naive_conv_fwd_nchw_fp32(device_input, device_weight, device_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + gpu_naive_conv_fwd_nhwc_fp32(device_input, device_weight, device_output, + n, wi, hi, c, + k, x, y, pad_w, pad_h, stride_w, stride_h, + dilation_w, dilation_h, ngroups); + else + assert(0); HIP_CALL(hipDeviceSynchronize()); HIP_CALL(hipMemcpy(host_output, device_output, static_cast<size_t>(n) * k * ho * wo * sizeof(float), hipMemcpyDeviceToHost)); #else - conv_fwd_nchw(host_input, host_weight, host_output, n, wi, hi, c, + if(in_layout == "NCHW") + naive_conv_fwd_nchw(host_input, host_weight, host_output, n, wi, hi, c, + k, x, y, pad_w, pad_h, stride_w, stride_h, + dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + naive_conv_fwd_nhwc(host_input, host_weight, host_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else + assert(0); #endif device_output_to_host = (float *)malloc(static_cast<size_t>(n) * k * ho * wo * sizeof(float)); } @@ -490,34 +651,16 @@ HIP_CALL(hipMemcpy(device_weight, host_weight, static_cast<size_t>(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); - igemm_fwd_gtc_t conv_fwd_driver; - double nrms = get_fwd_nrms(); - for (int i = 0; i < tunables.size(); i++) { - igemm_gtc_tunable_t *tunable = &tunables[i]; - if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT) - if(run_only_kernel != conv_fwd_driver.get_kernel_name(tunable)) - continue; - - printf("[fwd:%2d] %s, ", i,
conv_fwd_driver.get_kernel_name(tunable).c_str()); - fflush(stdout); + igemm_fwd_gtc_t conv_fwd_driver(module, driver_mode, driver_data_type, warmup, repeat, verbose); - //if (need_verify) - // HIP_CALL(hipMemset(device_output, 0, - // n * c * ho * wo * sizeof(float))); - result_t result = - conv_fwd_driver.run(&conv_args, tunable, module, device_input, - device_weight, device_output, warmup, repeat); - if (result.return_code != 0){ - printf("not applicatble\n"); - continue; - } + auto fwd_pre = [&](){ + if (need_verify) + HIP_CALL(hipMemset(device_output, 0, static_cast<size_t>(n) * k * ho * wo * sizeof(float))); + }; - double gflops = measured_fp32_conv_gflops( - result.duration_ms, n, c, hi, wi, k, y, x, stride_h, stride_w, - dilation_h, dilation_w, pad_h, pad_w, ngroups); - printf("cost:%.3fms, tflops:%.3f(%.2f%%)", result.duration_ms, - gflops / 1000 , (gflops / fp32_gflops) * 100); + auto fwd_post = [&](){ if (need_verify) { + double nrms = get_fwd_nrms(); HIP_CALL(hipMemcpy(device_output_to_host, device_output, static_cast<size_t>(n) * k * ho * wo * sizeof(float), hipMemcpyDeviceToHost)); @@ -526,26 +669,10 @@ int main(int argc, char **argv) { printf(", valid:%s", is_valid ? "y" : "n"); if(assert_when_invalid) assert(is_valid); } - printf("\n"); - if(result.duration_ms < fastest_result_fwd.duration_ms){ - fastest_result_fwd = result; - fastest_result_fwd.gflops = (float)gflops; - fastest_result_fwd.efficiency = (gflops / fp32_gflops) * 100; - fastest_id = i; - } - } - if(log_fastest_config){ - dump_arg(&conv_args); - if(fastest_id == -1) - printf(" fastest: no suitable kernel\n"); - else - printf(" fastest: [%d]%s, cost:%.3fms, tflops:%.3f(%.2f%%)\n", - fastest_id, - fastest_result_fwd.kernel_name.c_str(), - fastest_result_fwd.duration_ms, - fastest_result_fwd.gflops / 1000, - fastest_result_fwd.efficiency); - } + }; + + launch_conv_driver(&conv_fwd_driver, &conv_args, tunables, "fwd", device_input, device_weight, device_output, fwd_pre, fwd_post); + if (need_verify) free(device_output_to_host); } @@ -567,18 +694,33 @@ static_cast<size_t>(n) * k * ho * wo * sizeof(float), hipMemcpyHostToDevice)); HIP_CALL(hipMemcpy(device_weight, host_weight, static_cast<size_t>(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); - gpu_naive_conv_bwd_nchw_fp32(device_input, device_weight, device_output, + if(in_layout == "NCHW") + gpu_naive_conv_bwd_nchw_fp32(device_input, device_weight, device_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + gpu_naive_conv_bwd_nhwc_fp32(device_input, device_weight, device_output, + n, wi, hi, c, + k, x, y, pad_w, pad_h, stride_w, stride_h, + dilation_w, dilation_h, ngroups); + else + assert(0); HIP_CALL(hipDeviceSynchronize()); HIP_CALL(hipMemcpy(host_input, device_input, static_cast<size_t>(n) * c * hi * wi * sizeof(float), hipMemcpyDeviceToHost)); #else - conv_bwd_nchw(host_input, host_weight, host_output, n, + if(in_layout == "NCHW") + naive_conv_bwd_nchw(host_input, host_weight, host_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + naive_conv_bwd_nhwc(host_input, host_weight, host_output, n, + wi, hi, c, k, x, y, pad_w, + pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else + assert(0); #endif device_input_to_host = (float *)malloc(static_cast<size_t>(n) * c * hi * wi * sizeof(float)); // printf("len:%d\n", n * c * hi * wi * sizeof(float) ); @@ -589,35 +731,16 @@ int main(int argc, char
**argv) { HIP_CALL(hipMemcpy(device_weight, host_weight, static_cast<size_t>(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); + igemm_bwd_gtc_t conv_bwd_driver(module, driver_mode, driver_data_type, warmup, repeat, verbose); - igemm_bwd_gtc_t conv_bwd_driver; - double nrms = get_bwd_nrms(); - for (int i = 0; i < tunables.size(); i++) { - igemm_gtc_tunable_t *tunable = &tunables[i]; - if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT) - if(run_only_kernel != conv_bwd_driver.get_kernel_name(tunable)) - continue; - - printf("[bwd:%2d] %s, ", i, conv_bwd_driver.get_kernel_name(tunable).c_str()); - fflush(stdout); - + auto bwd_pre = [&](){ if (need_verify) - HIP_CALL(hipMemset(device_input, 0x7f, - static_cast<size_t>(n) * c * hi * wi * sizeof(float))); // 0x7f7f7f7f ~= 7.41e+28, a very large number - result_t result = - conv_bwd_driver.run(&conv_args, tunable, module, device_input, - device_weight, device_output, warmup, repeat); - if (result.return_code != 0){ - printf("not applicatble\n"); - continue; - } + HIP_CALL(hipMemset(device_input, 0x7f, static_cast<size_t>(n) * c * hi * wi * sizeof(float))); // 0x7f7f7f7f ~= 7.41e+28, a very large number + }; - double gflops = measured_fp32_conv_gflops( - result.duration_ms, n, c, hi, wi, k, y, x, stride_h, stride_w, - dilation_h, dilation_w, pad_h, pad_w, ngroups); - printf("cost:%.3fms, tflops:%.3f(%.2f%%)", result.duration_ms, - gflops / 1000 , (gflops / fp32_gflops) * 100); + auto bwd_post = [&](){ if (need_verify) { + double nrms = get_bwd_nrms(); HIP_CALL(hipMemcpy(device_input_to_host, device_input, static_cast<size_t>(n) * c * hi * wi * sizeof(float), hipMemcpyDeviceToHost)); @@ -625,34 +748,15 @@ int main(int argc, char **argv) { static_cast<size_t>(n) * c * hi * wi, nrms); printf(", valid:%s", is_valid ? "y" : "n"); if(assert_when_invalid) assert(is_valid); - // if (!is_valid) { - // printf("\n"); - // break; - // } } - printf("\n"); - if(result.duration_ms < fastest_result_bwd.duration_ms){ - fastest_result_bwd = result; - fastest_result_bwd.gflops = (float)gflops; - fastest_result_bwd.efficiency = (gflops / fp32_gflops) * 100; - fastest_id = i; - } - } - if(log_fastest_config){ - dump_arg(&conv_args); - if(fastest_id == -1) - printf(" fastest: no suitable kernel\n"); - else - printf(" fastest: [%d]%s, cost:%.3fms, tflops:%.3f(%.2f%%)\n", - fastest_id, - fastest_result_bwd.kernel_name.c_str(), - fastest_result_bwd.duration_ms, - fastest_result_bwd.gflops / 1000, - fastest_result_bwd.efficiency); - } + }; + + launch_conv_driver(&conv_bwd_driver, &conv_args, tunables, "bwd", device_input, device_weight, device_output, bwd_pre, bwd_post); + if (need_verify) free(device_input_to_host); } + if (need_wrw){ float *device_weight_to_host = NULL; if (need_verify) { @@ -666,18 +770,33 @@ static_cast<size_t>(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice)); HIP_CALL(hipMemcpy(device_output, host_output, static_cast<size_t>(n) * k * ho * wo * sizeof(float), hipMemcpyHostToDevice)); - gpu_naive_conv_wrw_nchw_fp32(device_input, device_weight, device_output, + if(in_layout == "NCHW") + gpu_naive_conv_wrw_nchw_fp32(device_input, device_weight, device_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + gpu_naive_conv_wrw_nhwc_fp32(device_input, device_weight, device_output, + n, wi, hi, c, + k, x, y, pad_w, pad_h, stride_w, stride_h, + dilation_w, dilation_h, ngroups); + else + assert(0); HIP_CALL(hipDeviceSynchronize()); HIP_CALL(hipMemcpy(host_weight, device_weight,
static_cast<size_t>(ngroups) * (k / ngroups) * (c / ngroups) * y * x * sizeof(float), hipMemcpyDeviceToHost)); #else - conv_wrw_nchw(host_input, host_weight, host_output, n, + if(in_layout == "NCHW") + naive_conv_wrw_nchw(host_input, host_weight, host_output, n, + wi, hi, c, k, x, y, pad_w, + pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + naive_conv_wrw_nhwc(host_input, host_weight, host_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else + assert(0); #endif device_weight_to_host = (float *)malloc(static_cast<size_t>(k) * c * y * x * sizeof(float)); // printf("len:%d\n", k * c * y * x * sizeof(float)); @@ -688,38 +807,34 @@ HIP_CALL(hipMemcpy(device_output, host_output, static_cast<size_t>(n) * k * ho * wo * sizeof(float), hipMemcpyHostToDevice)); -#if 0 - printf("input\r\n"); - for (int i_check = 0; i_check < (0+32); i_check++) - { - printf("[%d]th var to monitor:[%f, %d]\r\n", i_check*hi*wi, host_input[i_check*hi*wi], ((int *)host_input)[i_check*hi*wi]); - } - printf("output\r\n"); - for (int i_check = 0; i_check < (0+32); i_check++) - { - printf("[%d]th var to monitor:[%f, %d]\r\n", i_check*ho*wo, host_output[i_check*ho*wo], ((int *)host_output)[i_check*ho*wo]); - } - printf("input\r\n"); - for (int i_check = 0; i_check < (0+32); i_check++) - { - printf("[%d]th var to monitor:[%f, %d]\r\n", i_check, host_input[i_check], ((int *)host_input)[i_check]); - } - printf("output\r\n"); - for (int i_check = 0; i_check < (0+32); i_check++) - { - printf("[%d]th var to monitor:[%f, %d]\r\n", i_check, host_output[i_check], ((int *)host_output)[i_check]); - } - printf("workspace debug end \r\n"); -#endif - igemm_wrw_gtc_t conv_wrw_driver; - float min_duration = 10000000.0f; - float selected_duration = 10000000.0f; - double nrms = get_wrw_nrms(); - std::string kernel_name; - std::string selected_kernel; + igemm_wrw_gtc_t conv_wrw_driver(module, driver_mode, driver_data_type, warmup, repeat, verbose); + + auto wrw_pre = [&](){ + if (need_verify) + HIP_CALL(hipMemset(device_weight, 0, static_cast<size_t>(k) * c * y * x * sizeof(float))); + }; + auto wrw_post = [&](){ + if (need_verify) { + double nrms = get_wrw_nrms(); + HIP_CALL(hipMemcpy(device_weight_to_host, device_weight, + static_cast<size_t>(ngroups) * (k / ngroups) * (c / ngroups) * y * x * sizeof(float), + hipMemcpyDeviceToHost)); + bool is_valid = valid_vector(host_weight, device_weight_to_host, + static_cast<size_t>(ngroups) * (k / ngroups) * (c / ngroups) * y * x, nrms); + printf(", valid:%s", is_valid ?
"y" : "n"); + if(assert_when_invalid) assert(is_valid); + } + }; + + launch_conv_driver(&conv_wrw_driver, &conv_args, tunables, "wrw", device_input, device_weight, device_output, wrw_pre, wrw_post); + + if (need_verify) + free(device_weight_to_host); + +#if 0 selected_kernel = conv_wrw_driver.select_kernel(&conv_args, tunables); int min_grid = 0; @@ -801,6 +916,7 @@ int main(int argc, char **argv) { } if (need_verify) free(device_weight_to_host); +#endif } free(host_input); diff --git a/driver/igemm_bwd_gtc_driver.h b/driver/igemm_bwd_gtc_driver.h index e685482d..52bca0d4 100755 --- a/driver/igemm_bwd_gtc_driver.h +++ b/driver/igemm_bwd_gtc_driver.h @@ -40,9 +40,9 @@ // #define IGEMM_BWD_UPSAMPLING_USE_CUSTOM_KERNEL 1 typedef struct { - float *p_in; - float *p_wei; - float *p_out; + void *p_in; + void *p_wei; + void *p_out; int hi; int wi; int n; @@ -163,14 +163,13 @@ static void dump_bwd_karg(igemm_bwd_gtc_karg_t * karg){ std::cout<fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_MAC || tunable->fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_DLOPS){ return tunable->gemm_m_level0_cluster * tunable->gemm_n_level0_cluster * tunable->gemm_m_level1_cluster * tunable->gemm_n_level1_cluster; @@ -179,10 +178,9 @@ class igemm_bwd_gtc_t { int waves_per_n = tunable->gemm_n_per_block / (tunable->wave_tile_n * tunable->wave_step_n * tunable->wave_repeat_n); return waves_per_m * waves_per_n * AMDGPU_WAVE_SIZE; } - } - int get_grid_size(const args_t *arg, - const igemm_gtc_tunable_t *tunable) { + size_t get_grid_size(const args_t *arg, + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); int n = arg->get_int("batchsize"); @@ -254,7 +252,7 @@ class igemm_bwd_gtc_t { } bool tunable_is_valid(const args_t *arg, - const igemm_gtc_tunable_t *tunable) + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); @@ -363,8 +361,7 @@ class igemm_bwd_gtc_t { } result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, - hipModule_t module, float *p_in, float *p_wei, float *p_out, - int warmup, int repeat) { + void *p_in, void *p_wei, void *p_out) override { if (!tunable_is_valid(arg, tunable)) { result_t result; result.return_code = -1; @@ -498,8 +495,8 @@ class igemm_bwd_gtc_t { if(y < stride_h || x < stride_w || dilation_h != 1 || dilation_w != 1) need_set_zero = true; - int block_size = get_block_size(tunable); - int grid_size = get_grid_size(arg, tunable); + size_t block_size = get_block_size(tunable); + size_t grid_size = get_grid_size(arg, tunable); #ifdef IGEMM_BWD_UPSAMPLING_USE_CUSTOM_KERNEL igemm_upsampling_clear_karg_t ukarg; @@ -547,36 +544,15 @@ class igemm_bwd_gtc_t { HIP_CALL( hipModuleGetFunction(&upsampling_clear_kernel_func, module, upsampling_clear_kernel_name.c_str())); #endif + result_t result; - auto launch_bwd = [&]() -> float{ - float ms_total = .0; - if(need_set_zero){ - float ms = .0; - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); -#ifdef IGEMM_BWD_UPSAMPLING_USE_CUSTOM_KERNEL - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &ukarg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &ukarg_size, - HIP_LAUNCH_PARAM_END}; - HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, 1, 1, - u_block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); -#else - HIP_CALL(hipDeviceSynchronize()); - HIP_CALL(hipEventRecord( start, NULL )); - hipMemset(p_in, 0, n*c*hi*wi*sizeof(float)); - HIP_CALL(hipEventRecord( stop, NULL )); 
-#endif - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); - - ms_total += ms; - } - + if(tunable->multihead){ + // TODO: + }else{ + std::vector kernels; + std::vector<igemm_bwd_gtc_karg_t> kargs; + kargs.reserve(num_of_gemm); // CAUTION! we do not want this vector to be reallocated and moved elsewhere + int valid_kernel_index = 0; for(int gemm_id = 0; gemm_id < num_of_gemm; gemm_id++){ int i_y_tilda = gemm_id / x_tilda; int i_x_tilda = gemm_id % x_tilda; @@ -585,154 +561,46 @@ int gemm_k = (k / group) * y_dot_slice * x_dot_slice; bool is_gemm_not_empty = gemm_k > 0 && y_dot_slice > 0 && x_dot_slice > 0; - - karg.dtile_iy = i_y_tilda; - karg.dtile_ix = i_x_tilda; - karg.dslice_y = y_dot_slice; - karg.dslice_x = x_dot_slice; + if(is_gemm_not_empty){ + kargs.push_back(karg); + kargs[valid_kernel_index].dtile_iy = i_y_tilda; + kargs[valid_kernel_index].dtile_ix = i_x_tilda; + kargs[valid_kernel_index].dslice_y = y_dot_slice; + kargs[valid_kernel_index].dslice_x = x_dot_slice; #if USE_MAGIC_DIV - magic_div_u32_t mdiv_0 = is_gemm_not_empty ? magic_div_u32_gen(y_dot_slice * x_dot_slice) : magic_div_u32_t({0, 0}); - magic_div_u32_t mdiv_1 = is_gemm_not_empty ? magic_div_u32_gen(x_dot_slice) : magic_div_u32_t({0, 0}); - karg.magic_0 = mdiv_0.magic; - karg.magic_1 = mdiv_1.magic; + magic_div_u32_t mdiv_0 = is_gemm_not_empty ? magic_div_u32_gen(y_dot_slice * x_dot_slice) : magic_div_u32_t({0, 0}); + magic_div_u32_t mdiv_1 = is_gemm_not_empty ? magic_div_u32_gen(x_dot_slice) : magic_div_u32_t({0, 0}); + kargs[valid_kernel_index].magic_0 = mdiv_0.magic; + kargs[valid_kernel_index].magic_1 = mdiv_1.magic; - karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, mdiv_3.shift); - karg.shift_pack_1 = magic_div_u32_pack_shift(mdiv_4.shift, mdiv_5.shift, mdiv_6.shift, 0); + kargs[valid_kernel_index].shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, mdiv_3.shift); + kargs[valid_kernel_index].shift_pack_1 = magic_div_u32_pack_shift(mdiv_4.shift, mdiv_5.shift, mdiv_6.shift, 0); #endif - // printf("start launch id:%d(%d), block:%d, grid:%d\n", gemm_id, is_gemm_not_empty?1:0, block_size, grid_size); - // dump_bwd_karg(&karg); + // printf("start launch id:%d(%d), block:%d, grid:%d\n", gemm_id, is_gemm_not_empty?1:0, block_size, grid_size); + // dump_bwd_karg(&kargs[valid_kernel_index]); - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, - HIP_LAUNCH_PARAM_END}; - float ms = .0; + kernels.push_back({kernel_func, (void*)&kargs[valid_kernel_index], karg_size, std::vector<size_t>{grid_size * block_size, 1, 1}, std::vector<size_t>{block_size, 1, 1}}); - if(is_gemm_not_empty){ -#if USE_EXT_MODULE_LAUNCH - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); - // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem - HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); -#else - gpu_timer_t timer(NULL); - timer.start(); - HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config)); - timer.stop(); - ms = timer.duration(); -#endif + valid_kernel_index++; } - ms_total += ms; } - return ms_total; - }; - - auto
launch_bwd_multihead = [&]() -> float{ - float ms_total = .0; - if(need_set_zero){ - float ms = .0; - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); -#ifdef IGEMM_BWD_UPSAMPLING_USE_CUSTOM_KERNEL - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &ukarg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &ukarg_size, - HIP_LAUNCH_PARAM_END}; - HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, 1, 1, - u_block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); -#else - HIP_CALL(hipDeviceSynchronize()); - HIP_CALL(hipEventRecord( start, NULL )); - hipMemset(p_in, 0, n*c*hi*wi*sizeof(float)); - HIP_CALL(hipEventRecord( stop, NULL )); -#endif - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); - ms_total += ms; - } - // if 1x1 and stride/dilation > 1, will have empty gemms which will waste launch grid. better ignore that case at runtime - int origin_grid_size = grid_size/num_of_gemm; - karg.dtile_iy = origin_grid_size; - karg.dtile_ix = x_dot | (y_dot<<16); - karg.dslice_y = y % y_dot; - karg.dslice_x = x % x_dot; - // printf("start launch id:%d(%d), block:%d, grid:%d\n", gemm_id, is_gemm_not_empty?1:0, block_size, grid_size); - // dump_bwd_karg(&karg); - - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, - HIP_LAUNCH_PARAM_END}; - float ms = .0; -#if USE_EXT_MODULE_LAUNCH - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); - // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem - HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); -#else - gpu_timer_t timer(NULL); - timer.start(); - HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config)); - timer.stop(); - ms = timer.duration(); -#endif - ms_total += ms; - return ms_total; - }; - - auto launch_bwd_driver = [&](){ - if(tunable->multihead) - return launch_bwd_multihead(); - else - return launch_bwd(); - }; - - for (int i = 0; i < warmup; i++) { - launch_bwd_driver(); - } - std::vector<float> duration_list; - for (int i = 0; i < repeat; i++) { - float d = launch_bwd_driver(); - duration_list.push_back(d); + assert(kernels.size() == valid_kernel_index && kargs.size() == valid_kernel_index); auto bwd_epilog = need_set_zero ?
+ std::function{[&]() -> float{ + hipMemset(p_in, 0, n*c*hi*wi*sizeof(float)); + return .0; + }} : + std::function{[&]() -> float{ + return .0; + }}; + float ms = igemm_launch_kernels_with_epilog(kernels, bwd_epilog, warmup, repeat); + + result.return_code = 0; + result.duration_ms = ms; } - // remove min and max from list, then do average - auto imin = std::min_element(begin(duration_list), end(duration_list)); - duration_list.erase(imin); - auto imax = std::max_element(begin(duration_list), end(duration_list)); - duration_list.erase(imax); - assert(duration_list.size() == (repeat - 2)); - float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); - usleep(1000 * 5); - - result_t result; - result.return_code = 0; - result.duration_ms = avg_duration; - result.kernel_name = kernel_name; return result; } }; diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index 4ce61352..514a0c8c 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -36,11 +36,15 @@ #include #include #include +#include + +//#define GEMM_K_GLOBAL_SPLIT 3 +#define MAX_GEMM_K_SPLITS 8 typedef struct { - float *p_in; - float *p_wei; - float *p_out; + void *p_in; + void *p_wei; + void *p_out; int hi; int wi; int n; @@ -67,10 +71,42 @@ typedef struct { uint32_t magic_6; // denom: n*b*k / (m_per_block*n_per_block) uint32_t shift_pack_0; uint32_t shift_pack_1; - uint32_t __pack_0; + uint32_t ks; #endif } __attribute__((packed)) igemm_fwd_gtc_karg_t; +typedef struct { + void *p_in; + void *p_wei; + void *p_out; + int hi; + int wi; + int n; + int k; // this is indeed k_per_group + int c; // this is indeed c_per_group + int ho; + int wo; + int stride_h; + int stride_w; + int dilation_h; + int dilation_w; + int pad_h; + int pad_w; + int y; + int x; + int group; +#if USE_MAGIC_DIV + uint32_t magic_0; // denom: (gemm_n + n_per_block - 1) / n_per_block + uint32_t magic_1; // denom: ho*wo + uint32_t magic_2; // denom: wo + uint32_t magic_3; // denom: (gemm_m/m_per_block) * (gemm_n/n_per_block) + uint32_t shift_pack_0; + uint32_t ks; +#endif +} __attribute__((packed)) igemm_fwd_gtc_nhwc_karg_t; + +#define IGEMM_FWD_GTC_MAX_KARG_SIZE 160 + static void dump_fwd_karg(igemm_fwd_gtc_karg_t * karg){ std::cout<<"p_in:" <p_in<<","; std::cout<<"p_wei:" <p_wei<<","; @@ -105,14 +141,44 @@ static void dump_fwd_karg(igemm_fwd_gtc_karg_t * karg){ std::cout<p_in<<","; + std::cout<<"p_wei:" <p_wei<<","; + std::cout<<"p_out:" <p_out<<","; + std::cout<<"hi:" <hi<<","; + std::cout<<"wi:" <wi<<","; + std::cout<<"n:" <n<<","; + std::cout<<"k:" <k<<","; + std::cout<<"c:" <c<<","; + std::cout<<"ho:" <ho<<","; + std::cout<<"wo:" <wo<<","; + std::cout<<"stride_h:" <stride_h<<","; + std::cout<<"stride_w:" <stride_w<<","; + std::cout<<"dilation_h:" <dilation_h<<","; + std::cout<<"dilation_w:" <dilation_w<<","; + std::cout<<"pad_h:" <pad_h<<","; + std::cout<<"pad_w:" <pad_w<<","; + std::cout<<"y:" <y<<","; + std::cout<<"x:" <x<<","; + std::cout<<"group:" <group<<","; +#if USE_MAGIC_DIV + std::cout<<"magic_0:" <magic_0<<","; + std::cout<<"magic_1:" <magic_1<<","; + std::cout<<"magic_2:" <magic_2<<","; + std::cout<<"magic_3:" <magic_3<<","; + std::cout<<"shift_pack_0:" <shift_pack_0<<","; +#endif + std::cout<<"ks:" <ks; + std::cout<fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_MAC || tunable->fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_DLOPS){ return tunable->gemm_m_level0_cluster * tunable->gemm_n_level0_cluster * tunable->gemm_m_level1_cluster * 
tunable->gemm_n_level1_cluster; @@ -122,8 +188,9 @@ class igemm_fwd_gtc_t { return waves_per_m * waves_per_n * AMDGPU_WAVE_SIZE; } } - int get_grid_size(const args_t *arg, - const igemm_gtc_tunable_t *tunable) { + // return grid size without consideration of split k + size_t get_grid_size(const args_t *arg, + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); int n = arg->get_int("batchsize"); @@ -142,74 +209,39 @@ class igemm_fwd_gtc_t { int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); int group = arg->get_int("group_count"); - int splits = split_batch_size(arg, tunable); + size_t splits = igemm_split_batch_size(arg, utility_string_to_data_byte(tunable->precision)); n = n/splits; // split batch size here int gemm_m_per_block = tunable->gemm_m_per_block; int gemm_n_per_block = tunable->gemm_n_per_block; int nxe = tunable->nxe; int nxb = tunable->nxb; - int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 - - int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; - int gemm_n = n * b; - + int b = ho * wo; + + if(tunable->tensor_layout == "nchw") + b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 + + int gemm_m = 0; + int gemm_n = 0; + + if(tunable->tensor_layout == "nchw"){ + gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; + gemm_n = n * b; + }else if (tunable->tensor_layout == "nhwc"){ + gemm_m = n * b; + // gemm_n = ((k/group + gemm_n_per_block -1)/gemm_n_per_block) * gemm_n_per_block; + gemm_n = k / group; + }else{ + assert(false); + } size_t grid_size = static_cast(group) * utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * - utility_integer_divide_ceil(gemm_n, gemm_n_per_block); + utility_integer_divide_ceil(gemm_n, gemm_n_per_block); assert(grid_size <= 0xffffffffUL); return grid_size; } - // this is to support big tensor > 4G. need to decide how many splits needed - // return the number of splits - int split_batch_size(const args_t *arg, const igemm_gtc_tunable_t *tunable) - { - int hi = arg->get_int("in_h"); - int wi = arg->get_int("in_w"); - int n = arg->get_int("batchsize"); - int k = arg->get_int("out_channels"); - int c = arg->get_int("in_channels"); - - int stride_h = arg->get_int("conv_stride_h"); - int stride_w = arg->get_int("conv_stride_w"); - int dilation_h = arg->get_int("dilation_h"); - int dilation_w = arg->get_int("dilation_w"); - int pad_h = arg->get_int("pad_h"); - int pad_w = arg->get_int("pad_w"); - int y = arg->get_int("fil_h"); - int x = arg->get_int("fil_w"); - int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); - int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); - - int data_byte = utility_string_to_data_byte(tunable->precision); - size_t image_size_input = static_cast(c) * hi * wi * data_byte; - size_t image_size_output = static_cast(k) * ho * wo * data_byte; - size_t size_4g = 0xffffffffUL; - if(image_size_input >= size_4g || image_size_output >= size_4g) - return 0; - - size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output; - size_t splited_n = size_4g / image_size; - - // round up splits, we must match - // 1. splited_n * image_size < size_4g - // 2. 
n % splited_n == 0 - // if(splited_n >= n) - // return 1; - assert(splited_n != 0); - while(splited_n >= 1){ - // printf("n:%d, splited_n:%d\n", n, splited_n); - if(n % splited_n == 0) - break; - splited_n--; - } - - assert(splited_n * image_size < size_4g && n % splited_n == 0); - return n / splited_n; - } - bool tunable_is_valid(const args_t *arg, - const igemm_gtc_tunable_t *tunable) + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); @@ -231,7 +263,7 @@ class igemm_fwd_gtc_t { assert(c % group == 0 && k % group == 0); - int splits = split_batch_size(arg, tunable); + size_t splits = igemm_split_batch_size(arg, utility_string_to_data_byte(tunable->precision)); if(splits == 0){ printf("image size (c*h*w) is bigger than 4g, which is not supported now\n"); return false; @@ -244,66 +276,125 @@ class igemm_fwd_gtc_t { int nxe = tunable->nxe; int nxb = tunable->nxb; - int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 - - int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; - int gemm_n = n * b; - int gemm_k = (c / group) * y * x; + int b = ho * wo; + if(tunable->tensor_layout == "nchw") + b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 bool unit_conv = (x==1)&&(y==1)&&(stride_h==1)&&(stride_w==1)&&(dilation_h==1)&&(dilation_w==1)&&(pad_h==0)&&(pad_w==0); - // support pad to modulo, hence only check when nxe is 0 - if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0)) - { - return false; - } - - if(gemm_n_per_block % tunable->nxb != 0){ - //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); - return false; - } - - if(n % (gemm_n_per_block / tunable->nxb) != 0){ - //printf("tunable_is_valid false: n%(gemm_n_per_block/tunable->nxb)!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); - return false; - } - - if((nxe == 0) && ((b % tunable->nxb != 0) || (gemm_k % gemm_k_per_block != 0))){ - return false; - } - - if((nxe == 0) && !unit_conv){ - return false; - } - - // input vector load limitation, n1b - if(tunable->tensor_b_thread_lengths[3] > 1 && ( - !unit_conv || - unit_conv && (hi * wi) % tunable->tensor_b_thread_lengths[3] != 0)) { - return false; - } - - // weight vector load limitation, c1e - if(tunable->tensor_a_thread_lengths[1] > 1 && - gemm_k % tunable->tensor_a_thread_lengths[1] != 0){ - return false; - } - - // if tb_c1e > 1, only 1x1 case is runable, it can not check gemm_k_padding either. 
- if(tunable->tensor_b_thread_lengths[1] > 1 && (( x !=1 || y != 1)||(gemm_k % gemm_k_per_block != 0))){ - return false; - } - - // if t_c0 > 1, need to check gemmk per block - if(tunable->tensor_b_thread_lengths[0] > 1 && (gemm_k % gemm_k_per_block != 0)){ - return false; + if(tunable->tensor_layout == "nchw"){ + int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; + int gemm_n = n * b; + int gemm_k = (c / group) * y * x; + + // support pad to modulo, hence only check when nxe is 0 + if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0)) + { + return false; + } + + if(gemm_n_per_block % tunable->nxb != 0){ + //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + return false; + } + + if(n % (gemm_n_per_block / tunable->nxb) != 0){ + //printf("tunable_is_valid false: n%(gemm_n_per_block/tunable->nxb)!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + return false; + } + + if((nxe == 0) && ((b % tunable->nxb != 0) || (gemm_k % gemm_k_per_block != 0))){ + return false; + } + + if((nxe == 0) && !unit_conv){ + return false; + } + + // input vector load limitation, n1b + if(tunable->tensor_b_thread_lengths[3] > 1 && ( + !unit_conv || + unit_conv && (hi * wi) % tunable->tensor_b_thread_lengths[3] != 0)) { + return false; + } + + // weight vector load limitation, c1e + if(tunable->tensor_a_thread_lengths[1] > 1 && + gemm_k % tunable->tensor_a_thread_lengths[1] != 0){ + return false; + } + + // if tb_c1e > 1, only 1x1 case is runable, it can not check gemm_k_padding either. + if(tunable->tensor_b_thread_lengths[1] > 1 && (( x !=1 || y != 1)||(gemm_k % gemm_k_per_block != 0))){ + return false; + } + + // if t_c0 > 1, need to check gemmk per block + if(tunable->tensor_b_thread_lengths[0] > 1 && (gemm_k % gemm_k_per_block != 0)){ + return false; + } + }else if(tunable->tensor_layout == "nhwc"){ + //int gemm_m = n * b; + // int gemm_n = ((k/group + gemm_n_per_block -1)/gemm_n_per_block) * gemm_n_per_block; + //int gemm_n = k / group; + //int gemm_k = (c / group) * y * x; + + // support pad to modulo, hence only check when nxe is 0 + //if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0)) + //{ + // return false; + //} + + if(((c >> tunable->gemm_k_global_split) / group) % gemm_k_per_block != 0) + return false; + + // if(gemm_m_per_block % tunable->nxb != 0){ + // //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + // return false; + // } + + // if(n % (gemm_m_per_block / tunable->nxb) != 0){ + // //printf("tunable_is_valid false: n%(gemm_n_per_block/tunable->nxb)!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + // return false; + // } + + // if((nxe == 0) && ((b % tunable->nxb != 0) || (gemm_k % gemm_k_per_block != 0))){ + // return false; + // } + + if((nxe == 0) && !unit_conv){ + return false; + } + + // input vector load limitation, n1b + //if(tunable->tensor_a_thread_lengths[3] > 1 && ( + // !unit_conv || + // unit_conv && (hi * wi) % tunable->tensor_a_thread_lengths[3] != 0)) { + // return false; + //} + + // // weight vector load limitation, c1e + // if(tunable->tensor_a_thread_lengths[1] > 1 && + // gemm_k % tunable->tensor_a_thread_lengths[1] != 0){ + // return false; + // } + + // // if tb_c1e > 1, only 1x1 case is runable, it can not check gemm_k_padding either. 
+ // if(tunable->tensor_b_thread_lengths[1] > 1 && (( x !=1 || y != 1)||(gemm_k % gemm_k_per_block != 0))){ + // return false; + // } + + // // if t_c0 > 1, need to check gemmk per block + // if(tunable->tensor_b_thread_lengths[0] > 1 && (gemm_k % gemm_k_per_block != 0)){ + // return false; + // } + }else{ + assert(0); } return true; } - result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, - hipModule_t module, float *p_in, float *p_wei, float *p_out, - int warmup, int repeat) { + result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, void *p_in, void *p_wei, void *p_out) override { if (!tunable_is_valid(arg, tunable)) { result_t result; result.return_code = -1; @@ -331,7 +422,7 @@ class igemm_fwd_gtc_t { assert(c % group == 0 && k % group == 0); - int splits = split_batch_size(arg, tunable); + size_t splits = igemm_split_batch_size(arg, utility_string_to_data_byte(tunable->precision)); n = n/splits; // split batch size here int gemm_m_per_block = tunable->gemm_m_per_block; @@ -339,135 +430,172 @@ class igemm_fwd_gtc_t { int gemm_k_per_block = tunable->gemm_k_per_block; int nxe = tunable->nxe; int nxb = tunable->nxb; - int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 - - igemm_fwd_gtc_karg_t karg; - size_t karg_size = sizeof(karg); - karg.p_in = p_in; - karg.p_wei = p_wei; - karg.p_out = p_out; - karg.hi = hi; - karg.wi = wi; - karg.n = n; - karg.k = k / group; - karg.c = c / group; - karg.ho = ho; - karg.wo = wo; - - karg.stride_h = stride_h; - karg.stride_w = stride_w; - karg.dilation_h = dilation_h; - karg.dilation_w = dilation_w; - karg.pad_h = pad_h; - karg.pad_w = pad_w; - karg.y = y; - karg.x = x; - karg.group = group; - - int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; - int gemm_n = n * b; + int b = ho * wo; + if(tunable->tensor_layout == "nchw") + b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 + + size_t karg_size = 0; + uint8_t karg_buffer[IGEMM_FWD_GTC_MAX_KARG_SIZE]; + + if(tunable->tensor_layout == "nchw"){ + igemm_fwd_gtc_karg_t karg; + karg.p_in = p_in; + karg.p_wei = p_wei; + karg.p_out = p_out; + karg.hi = hi; + karg.wi = wi; + karg.n = n; + karg.k = k / group; + karg.c = c / group; + karg.ho = ho; + karg.wo = wo; + karg.stride_h = stride_h; + karg.stride_w = stride_w; + karg.dilation_h = dilation_h; + karg.dilation_w = dilation_w; + karg.pad_h = pad_h; + karg.pad_w = pad_w; + karg.y = y; + karg.x = x; + karg.group = group; #if USE_MAGIC_DIV - { - // init magic division parameters - uint32_t nb_n0 = tunable->tensor_b_cluster_lengths[2] * tunable->tensor_b_thread_lengths[2]; - uint32_t nb_n1b = tunable->tensor_b_cluster_lengths[3] * tunable->tensor_b_thread_lengths[3]; - uint32_t unmerge_sub_n = gemm_n_per_block / nxb; - uint32_t unmerge_sub_n1 = tunable->gemm_n_unmerge_cluster == 0 ? unmerge_sub_n / nb_n0 : unmerge_sub_n; - - magic_div_u32_t mdiv_0 = magic_div_u32_gen(tunable->source_access_order == 0 ? ((n * b) / gemm_n_per_block) : ((gemm_m) / gemm_m_per_block)); - magic_div_u32_t mdiv_1 = magic_div_u32_gen(tunable->gemm_n_unmerge_cluster == 0 ? 
- b * unmerge_sub_n1 / nb_n1b : - (n / nb_n0) * b / nb_n1b ); - magic_div_u32_t mdiv_2 = magic_div_u32_gen(y * x); - magic_div_u32_t mdiv_3 = magic_div_u32_gen(x); - magic_div_u32_t mdiv_4 = magic_div_u32_gen(b); - magic_div_u32_t mdiv_5 = magic_div_u32_gen(wo); - magic_div_u32_t mdiv_6 = magic_div_u32_gen(utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * - utility_integer_divide_ceil(gemm_n, gemm_n_per_block)); + int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; + int gemm_n = n * b; + { + // init magic division parameters + uint32_t nb_n0 = tunable->tensor_b_cluster_lengths[2] * tunable->tensor_b_thread_lengths[2]; + uint32_t nb_n1b = tunable->tensor_b_cluster_lengths[3] * tunable->tensor_b_thread_lengths[3]; + uint32_t unmerge_sub_n = gemm_n_per_block / nxb; + uint32_t unmerge_sub_n1 = tunable->gemm_n_unmerge_cluster == 0 ? unmerge_sub_n / nb_n0 : unmerge_sub_n; + + magic_div_u32_t mdiv_0 = magic_div_u32_gen(tunable->source_access_order == 0 ? ((n * b) / gemm_n_per_block) : ((gemm_m) / gemm_m_per_block)); + magic_div_u32_t mdiv_1 = magic_div_u32_gen(tunable->gemm_n_unmerge_cluster == 0 ? + b * unmerge_sub_n1 / nb_n1b : + (n / nb_n0) * b / nb_n1b ); + magic_div_u32_t mdiv_2 = magic_div_u32_gen(y * x); + magic_div_u32_t mdiv_3 = magic_div_u32_gen(x); + magic_div_u32_t mdiv_4 = magic_div_u32_gen(b); + magic_div_u32_t mdiv_5 = magic_div_u32_gen(wo); + magic_div_u32_t mdiv_6 = magic_div_u32_gen(utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * + utility_integer_divide_ceil(gemm_n, gemm_n_per_block)); + + karg.magic_0 = mdiv_0.magic; + karg.magic_1 = mdiv_1.magic; + karg.magic_2 = mdiv_2.magic; + karg.magic_3 = mdiv_3.magic; + karg.magic_4 = mdiv_4.magic; + karg.magic_5 = mdiv_5.magic; + karg.magic_6 = mdiv_6.magic; + karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, mdiv_3.shift); + karg.shift_pack_1 = magic_div_u32_pack_shift(mdiv_4.shift, mdiv_5.shift, mdiv_6.shift, 0); + } +#endif + karg_size = sizeof(karg); + memcpy(static_cast(&karg_buffer[0]), static_cast(&karg), karg_size); + }else if(tunable->tensor_layout == "nhwc"){ + igemm_fwd_gtc_nhwc_karg_t karg; + karg.p_in = p_in; + karg.p_wei = p_wei; + karg.p_out = p_out; + karg.hi = hi; + karg.wi = wi; + karg.n = n; + karg.k = k / group; + karg.c = c / group; + karg.ho = ho; + karg.wo = wo; + karg.stride_h = stride_h; + karg.stride_w = stride_w; + karg.dilation_h = dilation_h; + karg.dilation_w = dilation_w; + karg.pad_h = pad_h; + karg.pad_w = pad_w; + karg.y = y; + karg.x = x; + karg.group = group; +#if USE_MAGIC_DIV + int gemm_m = n * ho * wo; + int gemm_n = k / group; + magic_div_u32_t mdiv_0 = magic_div_u32_gen(utility_integer_divide_ceil(gemm_n, gemm_n_per_block)); + magic_div_u32_t mdiv_1 = magic_div_u32_gen(ho*wo); + magic_div_u32_t mdiv_2 = magic_div_u32_gen(wo); + magic_div_u32_t mdiv_3 = magic_div_u32_gen(utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * utility_integer_divide_ceil(gemm_n, gemm_n_per_block)); karg.magic_0 = mdiv_0.magic; karg.magic_1 = mdiv_1.magic; karg.magic_2 = mdiv_2.magic; karg.magic_3 = mdiv_3.magic; - karg.magic_4 = mdiv_4.magic; - karg.magic_5 = mdiv_5.magic; - karg.magic_6 = mdiv_6.magic; karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, mdiv_3.shift); - karg.shift_pack_1 = magic_div_u32_pack_shift(mdiv_4.shift, mdiv_5.shift, mdiv_6.shift, 0); - } #endif + karg_size = sizeof(karg); + memcpy(static_cast(&karg_buffer[0]), static_cast(&karg), karg_size); + } else { + assert(0); + } - 
int block_size = get_block_size(tunable); - int grid_size = get_grid_size(arg, tunable); + size_t block_size = get_block_size(tunable); hipFunction_t kernel_func; std::string kernel_name = get_kernel_name(tunable); - // printf("kernel:%s\n, block:%d, grid:%d\n", kernel_name.c_str(), block_size, grid_size); - HIP_CALL( - hipModuleGetFunction(&kernel_func, module, kernel_name.c_str())); - - auto launch_fwd = [&]() -> float { - // printf("launch fwd block:%d, grid:%dx%d\n", block_size, grid_size, splits); - // dump_fwd_karg(&karg); - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, - HIP_LAUNCH_PARAM_END}; - float ms = .0; - -#if USE_EXT_MODULE_LAUNCH - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); - - // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem - HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); - - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); -#else - gpu_timer_t timer(NULL); - timer.start(); - - HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config)); - - timer.stop(); - ms = timer.duration(); -#endif - return ms; - }; - - for (int i = 0; i < warmup; i++) { - launch_fwd(); - } + HIP_CALL(hipModuleGetFunction(&kernel_func, module, kernel_name.c_str())); + + // TODO: use kernel to pre-clear when atomic + auto fwd_epilog = tunable->gemm_k_global_split ? + std::function{[&]() -> float{ + hipMemset(p_out, 0, static_cast(n) * k * ho * wo * sizeof(float)); + return .0; + }} : + std::function{[&]() -> float{ + return .0; + }}; - std::vector duration_list; - for (int i = 0; i < repeat; i++) { - float d = launch_fwd(); - duration_list.push_back(d); + result_t result; + result.kernel_name = kernel_name; + if(this->driver_mode == driver_mode_normal){ + float min_duration = FLT_MAX; + int selected_gks = 0; + int max_split_num = tunable->gemm_k_global_split == 0 ? + 0 : igemm_get_max_gks(c / group, tunable->gemm_k_per_block, MAX_GEMM_K_SPLITS); + for(int gks = 0; gks <= max_split_num; gks++){ + size_t grid_size = get_grid_size(arg, tunable) * (1 << gks); + if(tunable->tensor_layout == "nhwc"){ + // This is hacky, but in MIOpen we prefer a heuristic way to set gks, so ok now. 
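+                    // gemm_k_global_split note: each increment of gks doubles the number of
+                    // workgroups along gemm_k (grid_size * (1 << gks)); the split kernels
+                    // accumulate partial results into p_out, which is why fwd_epilog clears
+                    // the output first. As a worked example (hypothetical sizes): with
+                    // c/group = 512 and gemm_k_per_block = 16, igemm_get_max_gks allows up to
+                    // min(log2(512/16), MAX_GEMM_K_SPLITS) = 5, so gks = 0..5 are benchmarked
+                    // below and the fastest split is kept.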
+ igemm_fwd_gtc_nhwc_karg_t *karg_revalue = (igemm_fwd_gtc_nhwc_karg_t *)(karg_buffer); + karg_revalue->ks = gks; + // dump_fwd_karg(karg_revalue); + // printf("block:%d, grid:%d\n", block_size, grid_size); + // fflush(stdout); + } + float duration = igemm_launch_kernels_with_epilog({ + {kernel_func, karg_buffer, karg_size, {grid_size * block_size, splits, 1}, {block_size, 1, 1}} + }, fwd_epilog, this->warmup, this->repeat); + + if(min_duration > duration){ + min_duration = duration; + selected_gks = gks; + } + } + + result.return_code = 0; + result.duration_ms = min_duration; + result.gks = selected_gks; + }else if(this->driver_mode == driver_mode_heuristic){ + int gks = heuristic_select_gks(arg, tunable); + size_t grid_size = get_grid_size(arg, tunable) * (1 << gks); + + float duration = igemm_launch_kernels_with_epilog({ + {kernel_func, karg_buffer, karg_size, {grid_size * block_size, splits, 1}, {block_size, 1, 1}} + }, fwd_epilog, this->warmup, this->repeat); + + result.return_code = 0; + result.duration_ms = duration; + result.gks = gks; + }else{ + assert(0); } - // remove min and max from list, then do average - auto imin = std::min_element(begin(duration_list), end(duration_list)); - duration_list.erase(imin); - auto imax = std::max_element(begin(duration_list), end(duration_list)); - duration_list.erase(imax); - assert(duration_list.size() == (repeat - 2)); - float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); - usleep(1000 * 5); - - result_t result; - result.return_code = 0; - result.duration_ms = avg_duration; - result.kernel_name = kernel_name; return result; } }; diff --git a/driver/igemm_gtc_base.h b/driver/igemm_gtc_base.h index e8215f01..00cdf1ef 100755 --- a/driver/igemm_gtc_base.h +++ b/driver/igemm_gtc_base.h @@ -34,6 +34,8 @@ #include #include #include +#include +#include #define IGEMM_GTC_TUNABLE_FMA_TYPE_MAC "mac" #define IGEMM_GTC_TUNABLE_FMA_TYPE_DLOPS "dlops" @@ -41,6 +43,19 @@ #define IGEMM_GTC_TUNABLE_FMA_TYPE_NA "fma_na" #define AMDGPU_WAVE_SIZE 64 +typedef enum { + driverHalf = 0, /*!< 16-bit floating point (Fully supported) */ + driverFloat = 1, /*!< 32-bit floating point (Fully supported) */ + driverInt8 = 3, + driverBFloat16 = 5, /*!< 16-bit binary floating point (8-bit exponent, 7-bit fraction) + (Partially supported) */ +} driverDataType_t; + +typedef enum { + driver_mode_normal = 0, // bench all solutions + driver_mode_heuristic = 1, // find suitable heuristic +} driver_mode_t; + #if USE_MAGIC_DIV typedef struct { uint32_t magic; @@ -123,6 +138,8 @@ typedef struct { int wave_tile_k; }; }; + int tensor_a_pass_through; + int tensor_b_pass_through; std::vector tensor_a_thread_lengths; std::vector tensor_a_cluster_lengths; std::vector tensor_b_thread_lengths; @@ -187,6 +204,8 @@ igemm_gtc_tunable_from_config(const config_content_t &content) { tunable.wave_repeat_n = sec.at("wave_repeat_n").get_int(); tunable.wave_tile_k = sec.count("wave_tile_k") > 0 ? sec.at("wave_tile_k").get_int() : 1; } + tunable.tensor_a_pass_through = sec.count("tensor_a_pass_through") > 0 ? sec.at("tensor_a_pass_through").get_int() : 0; + tunable.tensor_b_pass_through = sec.count("tensor_b_pass_through") > 0 ? 
sec.at("tensor_b_pass_through").get_int() : 0; tunable.tensor_a_thread_lengths = sec.at("tensor_a_thread_lengths").get_list_int(); tunable.tensor_a_cluster_lengths = sec.at("tensor_a_cluster_lengths").get_list_int(); tunable.tensor_b_thread_lengths = sec.at("tensor_b_thread_lengths").get_list_int(); @@ -222,6 +241,8 @@ igemm_gtc_encode_kernel_name(const igemm_gtc_tunable_t *tunable) { // auto gemm_n_per_thread = tunable->gemm_n_per_thread; // auto gemm_n_level0_cluster = tunable->gemm_n_level0_cluster; // auto gemm_n_level1_cluster = tunable->gemm_n_level1_cluster; + auto tensor_a_pass_through = tunable->tensor_a_pass_through; + auto tensor_b_pass_through = tunable->tensor_b_pass_through; auto tensor_a_thread_lengths = tunable->tensor_a_thread_lengths; auto tensor_a_cluster_lengths = tunable->tensor_a_cluster_lengths; auto tensor_b_thread_lengths = tunable->tensor_b_thread_lengths; @@ -296,6 +317,10 @@ igemm_gtc_encode_kernel_name(const igemm_gtc_tunable_t *tunable) { "tb" + utility_int_list_to_string(tensor_b_thread_lengths) + "_" + utility_int_list_to_string(tensor_b_cluster_lengths); // printf("[%s]\n",kernel_name.c_str()); + if(tensor_a_pass_through) + kernel_name += std::string("_pta"); + if(tensor_b_pass_through) + kernel_name += std::string("_ptb"); if(gemm_m_unmerge_cluster) kernel_name += std::string("_mc"); if(gemm_n_unmerge_cluster) @@ -310,4 +335,220 @@ igemm_gtc_encode_kernel_name(const igemm_gtc_tunable_t *tunable) { return kernel_name; } +static inline float igemm_launch_kernel_single(hipFunction_t kernel_func, void* args, size_t arg_size, std::vector grid_size, std::vector block_size) +{ + void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, args, + HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size, + HIP_LAUNCH_PARAM_END}; + float ms = .0; + + hipEvent_t start; + hipEvent_t stop; + hipEventCreate(&start); + hipEventCreate(&stop); + + // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem + HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size[0], grid_size[1], grid_size[2], + block_size[0], block_size[1], block_size[2], 0, 0, NULL, + (void **)&config, start, stop)); + + + hipEventSynchronize(stop); + hipEventElapsedTime(&ms, start, stop); + hipEventDestroy(start); + hipEventDestroy(stop); + + return ms; +} + +static inline float igemm_launch_kernel(hipFunction_t kernel_func, void* args, size_t arg_size, std::vector grid_size, std::vector block_size, int warmup, int repeat) +{ + assert(repeat > 2); + std::vector duration_list; + for (int i = 0; i < warmup; i++) { + igemm_launch_kernel_single(kernel_func, args, arg_size, grid_size, block_size); + } + + for (int i = 0; i < repeat; i++) { + float d = igemm_launch_kernel_single(kernel_func, args, arg_size, grid_size, block_size); + duration_list.push_back(d); + } + // remove min and max from list, then do average + auto imin = std::min_element(begin(duration_list), end(duration_list)); + duration_list.erase(imin); + auto imax = std::max_element(begin(duration_list), end(duration_list)); + duration_list.erase(imax); + + assert(duration_list.size() == (repeat - 2)); + float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); + return avg_duration; +} + +typedef struct{ + hipFunction_t kernel_func; + void * args; + size_t arg_size; + std::vector grid_size; + std::vector block_size; +}igemm_launch_kernel_t; +static inline float igemm_launch_kernels(const std::vector & kernels, int warmup, int repeat) +{ + auto launch_kernels = [&]() -> float{ + 
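+        // one measurement sample = the sum of the individually timed kernels in the
+        // batch; igemm_launch_kernel_single brackets each hipHccModuleLaunchKernel
+        // call with a hipEvent pair and returns the elapsed milliseconds.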
float ms = .0; + for(const auto ker : kernels) + ms += igemm_launch_kernel_single(ker.kernel_func, ker.args, ker.arg_size, ker.grid_size, ker.block_size); + return ms; + }; + + assert(repeat > 2); + std::vector duration_list; + for (int i = 0; i < warmup; i++) { + launch_kernels(); + } + + for (int i = 0; i < repeat; i++) { + float d = launch_kernels(); + duration_list.push_back(d); + } + // remove min and max from list, then do average + auto imin = std::min_element(begin(duration_list), end(duration_list)); + duration_list.erase(imin); + auto imax = std::max_element(begin(duration_list), end(duration_list)); + duration_list.erase(imax); + + assert(duration_list.size() == (repeat - 2)); + float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); + return avg_duration; +} +template +static inline float igemm_launch_kernels_with_epilog(const std::vector & kernels, epilog_kernel_t epilog_kernel, int warmup, int repeat) +{ + auto launch_kernels = [&]() -> float{ + float ms = .0; + ms += epilog_kernel(); + for(const auto & ker : kernels) + ms += igemm_launch_kernel_single(ker.kernel_func, ker.args, ker.arg_size, ker.grid_size, ker.block_size); + return ms; + }; + + assert(repeat > 2); + std::vector duration_list; + for (int i = 0; i < warmup; i++) { + launch_kernels(); + } + + for (int i = 0; i < repeat; i++) { + float d = launch_kernels(); + duration_list.push_back(d); + } + // remove min and max from list, then do average + auto imin = std::min_element(begin(duration_list), end(duration_list)); + duration_list.erase(imin); + auto imax = std::max_element(begin(duration_list), end(duration_list)); + duration_list.erase(imax); + + assert(duration_list.size() == (repeat - 2)); + float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); + return avg_duration; +} + +static inline int igemm_get_max_gks(int gemm_k, int gemm_k_per_block, int max_log2_splits) +{ + if(gemm_k % gemm_k_per_block != 0) + return 0; + int rem = gemm_k / gemm_k_per_block; + // to find the highest power of 2 value that can divide rem + // https://www.geeksforgeeks.org/highest-power-of-two-that-divides-a-given-number/ + int rem_pow2 = rem & (~(rem - 1)); + int gks = (int)log2(rem_pow2); + if(gks > max_log2_splits) + gks = max_log2_splits; + return gks; +} + +// this is to support big tensor > 4G. 
need to decide how many splits needed +// return the number of splits +static inline size_t igemm_split_batch_size(const args_t *arg, int data_byte) +{ + int hi = arg->get_int("in_h"); + int wi = arg->get_int("in_w"); + int n = arg->get_int("batchsize"); + int k = arg->get_int("out_channels"); + int c = arg->get_int("in_channels"); + + int stride_h = arg->get_int("conv_stride_h"); + int stride_w = arg->get_int("conv_stride_w"); + int dilation_h = arg->get_int("dilation_h"); + int dilation_w = arg->get_int("dilation_w"); + int pad_h = arg->get_int("pad_h"); + int pad_w = arg->get_int("pad_w"); + int y = arg->get_int("fil_h"); + int x = arg->get_int("fil_w"); + int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); + int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); + + // int data_byte = utility_string_to_data_byte(tunable->precision); + size_t image_size_input = static_cast(c) * hi * wi * data_byte; + size_t image_size_output = static_cast(k) * ho * wo * data_byte; + size_t size_4g = 0xffffffffUL; + if(image_size_input >= size_4g || image_size_output >= size_4g) + return 0; + + size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output; + size_t splited_n = size_4g / image_size; + + // round up splits, we must match + // 1. splited_n * image_size < size_4g + // 2. n % splited_n == 0 + // if(splited_n >= n) + // return 1; + assert(splited_n != 0); + while(splited_n >= 1){ + // printf("n:%d, splited_n:%d\n", n, splited_n); + if(n % splited_n == 0) + break; + splited_n--; + } + + assert(splited_n * image_size < size_4g && n % splited_n == 0); + return static_cast(n) / splited_n; +} + +class igemm_driver_base_t{ +public: + igemm_driver_base_t(hipModule_t module_, driver_mode_t driver_mode_, driverDataType_t data_type_, int warmup_, int repeat_, bool verbose_) : + module(module_), driver_mode(driver_mode_), data_type(data_type_), warmup(warmup_), repeat(repeat_), verbose(verbose_) + { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CALL(hipGetDevice(&dev)); + HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); + this->num_cu = dev_prop.multiProcessorCount; + this->gcn_arch = dev_prop.gcnArch; + if(this->gcn_arch >= 1000) + this->num_cu *= 2; + } + std::string get_kernel_name(const igemm_gtc_tunable_t *tunable) { + return igemm_gtc_encode_kernel_name(tunable); + } + + virtual size_t get_block_size(const igemm_gtc_tunable_t *tunable) = 0; + virtual size_t get_grid_size(const args_t *arg, const igemm_gtc_tunable_t *tunable) = 0; + virtual bool tunable_is_valid(const args_t *arg, const igemm_gtc_tunable_t *tunable) = 0; + virtual result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, void *p_in, void *p_wei, void *p_out) = 0; + + virtual igemm_gtc_tunable_t heuristic_select_kernel(const args_t *arg) {return igemm_gtc_tunable_t{}; } + virtual int heuristic_select_gks(const args_t *arg, const igemm_gtc_tunable_t *tunable) {return 0; } + + hipModule_t module; + driver_mode_t driver_mode; + driverDataType_t data_type; + int warmup; + int repeat; + bool verbose; + + int num_cu; + int gcn_arch; +}; + #endif \ No newline at end of file diff --git a/driver/igemm_wrw_gtc_driver.h b/driver/igemm_wrw_gtc_driver.h index e63efae4..22010de1 100644 --- a/driver/igemm_wrw_gtc_driver.h +++ b/driver/igemm_wrw_gtc_driver.h @@ -38,9 +38,9 @@ #include typedef struct { - float *p_in; - float *p_wei; - float *p_out; + void *p_in; + void *p_wei; + void *p_out; int hi; int wi; int n; @@ -85,83 +85,13 @@ static void dump_wrw_karg(igemm_wrw_gtc_karg_t * 
karg){ std::cout<gemm_m_per_block; - auto gemm_n_per_block = tunable->gemm_n_per_block; - auto gemm_k_per_block = tunable->gemm_k_per_block; - auto gemm_m_per_thread = tunable->gemm_m_per_thread; - auto gemm_m_level0_cluster = tunable->gemm_m_level0_cluster; - auto gemm_m_level1_cluster = tunable->gemm_m_level1_cluster; - auto gemm_n_per_thread = tunable->gemm_n_per_thread; - auto gemm_n_level0_cluster = tunable->gemm_n_level0_cluster; - auto gemm_n_level1_cluster = tunable->gemm_n_level1_cluster; - auto tensor_a_thread_lengths = tunable->tensor_a_thread_lengths; - auto tensor_a_cluster_lengths = tunable->tensor_a_cluster_lengths; - auto tensor_b_thread_lengths = tunable->tensor_b_thread_lengths; - auto tensor_b_cluster_lengths = tunable->tensor_b_cluster_lengths; - auto direction = tunable->direction; - auto precision = tunable->precision; - auto nxb = tunable->nxb; - auto nxe = tunable->nxe; - auto gemm_m_unmerge_cluster = tunable->gemm_m_unmerge_cluster; - auto gemm_n_unmerge_cluster = tunable->gemm_n_unmerge_cluster; - auto gemm_k_unmerge_cluster = tunable->gemm_k_unmerge_cluster; - auto multihead = tunable->multihead; - - assert(gemm_m_per_block % (gemm_m_per_thread * gemm_m_level0_cluster * gemm_m_level1_cluster) == 0); - assert(gemm_n_per_block % (gemm_n_per_thread * gemm_n_level0_cluster * gemm_n_level1_cluster) == 0); - int gemm_m_repeat = gemm_m_per_block / (gemm_m_per_thread * gemm_m_level0_cluster * gemm_m_level1_cluster); - int gemm_n_repeat = gemm_n_per_block / (gemm_n_per_thread * gemm_n_level0_cluster * gemm_n_level1_cluster); - - int thread_tile_m = gemm_m_repeat * gemm_m_per_thread; - int thread_tile_n = gemm_n_repeat * gemm_n_per_thread; - - assert(direction == "wrw"); - - std::string kernel_prefix = std::string("igemm_") + direction + std::string("_gtc_") + precision + - std::string("_bx") + std::to_string(nxb) + - std::string("_ex") + std::to_string(nxe) + "_"; - std::string kernel_name = - kernel_prefix + - "bt" + - std::to_string(gemm_m_per_block) + "x" + - std::to_string(gemm_n_per_block) + "x" + - std::to_string(gemm_k_per_block) + "_" + - "tt" + - std::to_string(thread_tile_m) + "x" + - std::to_string(thread_tile_n) + "_" + - "gm" + - std::to_string(gemm_m_repeat) + "x" + - std::to_string(gemm_m_level0_cluster) + "x" + - std::to_string(gemm_m_level1_cluster) + "_" + - "gn" + - std::to_string(gemm_n_repeat) + "x" + - std::to_string(gemm_n_level0_cluster) + "x" + - std::to_string(gemm_n_level1_cluster) + "_" + - "ta" + utility_int_list_to_string(tensor_a_thread_lengths) + "_" + - utility_int_list_to_string(tensor_a_cluster_lengths)+ "_" + - "tb" + utility_int_list_to_string(tensor_b_thread_lengths) + "_" + - utility_int_list_to_string(tensor_b_cluster_lengths); - // printf("[%s]\n",kernel_name.c_str()); - if(gemm_m_unmerge_cluster) - kernel_name += std::string("_mc"); - if(gemm_n_unmerge_cluster) - kernel_name += std::string("_nc"); - if(gemm_k_unmerge_cluster) - kernel_name += std::string("_kc"); - if(multihead) - kernel_name += std::string("_mh"); - return kernel_name; -#else - return igemm_gtc_encode_kernel_name(tunable); -#endif - } - int get_block_size(const igemm_gtc_tunable_t *tunable) { + + size_t get_block_size(const igemm_gtc_tunable_t *tunable) override { if(tunable->fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_MAC || tunable->fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_DLOPS){ return tunable->gemm_m_level0_cluster * tunable->gemm_n_level0_cluster * tunable->gemm_m_level1_cluster * tunable->gemm_n_level1_cluster; @@ -177,8 +107,8 @@ class igemm_wrw_gtc_t { return 0; 
} } - int get_grid_size(const args_t *arg, - const igemm_gtc_tunable_t *tunable) { + size_t get_grid_size(const args_t *arg, + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); int n = arg->get_int("batchsize"); @@ -224,7 +154,7 @@ class igemm_wrw_gtc_t { } bool tunable_is_valid(const args_t *arg, - const igemm_gtc_tunable_t *tunable) + const igemm_gtc_tunable_t *tunable) override { // TODO: int hi = arg->get_int("in_h"); @@ -634,12 +564,11 @@ class igemm_wrw_gtc_t { } result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, - hipModule_t module, float *p_in, float *p_wei, float *p_out, - int warmup, int repeat) { + void *p_in, void *p_wei, void *p_out) override { if (!tunable_is_valid(arg, tunable)) { result_t result; result.return_code = -1; - std::cout << "not valid tunable config." << std::endl; + // std::cout << "not valid tunable config." << std::endl; return result; } @@ -698,8 +627,8 @@ class igemm_wrw_gtc_t { //printf("gemmk split is %d\r\n", 1 << gemm_k_global_split); - int block_size = get_block_size(tunable); - int grid_size = get_grid_size(arg, tunable); + size_t block_size = get_block_size(tunable); + size_t grid_size = get_grid_size(arg, tunable); hipFunction_t kernel_func; std::string kernel_name = get_kernel_name(tunable); @@ -710,89 +639,27 @@ class igemm_wrw_gtc_t { // hipMemset(p_wei, 0x0, group * (k / group) * (c / group) * y * x * sizeof(float)); - auto launch_wrw_driver = [&](){ - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, - HIP_LAUNCH_PARAM_END}; - float ms = .0; - - if(gemm_k_global_split){ - // TODO: current implementation of global split K need pre-clear the wei tensor - // This may not be true in the future! + auto wrw_epilog = gemm_k_global_split ? 
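+        // split-K wrw accumulates partial weight gradients into p_wei, so the weight
+        // tensor must be pre-cleared before every timed run; the epilog returns 0.0 so
+        // the memset does not count towards the measured duration.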
+            std::function<float()>{[&]() -> float{
                 hipMemset(p_wei, 0x0, group * (k / group) * (c / group) * y * x * sizeof(float));
-            }
-
-#if USE_EXT_MODULE_LAUNCH
-            hipEvent_t start;
-            hipEvent_t stop;
-            hipEventCreate(&start);
-            hipEventCreate(&stop);
-
-            // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem
-            HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1,
-                                    block_size, 1, 1, 0, 0, NULL,
-                                    (void **)&config, start, stop));
-
-            hipEventSynchronize(stop);
-            hipEventElapsedTime(&ms, start, stop);
-            hipEventDestroy(start);
-            hipEventDestroy(stop);
-#else
-            gpu_timer_t timer(NULL);
-            timer.start();
-
-            HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1,
-                                    block_size, 1, 1, 0, 0, NULL,
-                                    (void **)&config));
-
-            timer.stop();
-            ms = timer.duration();
-#endif
-            return ms;
-        };
-
-        for (int i = 0; i < warmup; i++) {
-            launch_wrw_driver();
-        }
-        std::vector<float> duration_list;
-        for (int i = 0; i < repeat; i++) {
-            float d = launch_wrw_driver();
-            duration_list.push_back(d);
-        }
+                return .0;
+            }} :
+            std::function<float()>{[&]() -> float{
+                return .0;
+            }};
-        // for (int i = 0; i < warmup; i++) {
-        //     hipMemset(p_wei, 0x0, k * c * y * x * sizeof(float));
-        //     launch_wrw_driver();
-        // }
-
-        // remove min and max from list, then do average
-        auto imin = std::min_element(begin(duration_list), end(duration_list));
-        duration_list.erase(imin);
-        auto imax = std::max_element(begin(duration_list), end(duration_list));
-        duration_list.erase(imax);
-        assert(duration_list.size() == (repeat - 2));
-        float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size();
-
-        usleep(1000 * 1);
-
-        // debug section of code
-#if 0
-        printf("workspace debug \r\n");
-        float* gemmc_host_check = (float* )malloc((1 << gemm_k_global_split) * k * c * y * x * sizeof(float));
-        hipMemcpy(gemmc_host_check, p_wei, k * c * y * x * sizeof(float), hipMemcpyDeviceToHost);
-        for (int i_check = 0; i_check < (0+block_size); i_check++)
-        {
-            printf("[%d]th var to monitor:[%f, %d]\r\n", i_check, gemmc_host_check[i_check], ((int *)gemmc_host_check)[i_check]);
-        }
-        printf("workspace debug end \r\n");
-#endif
         result_t result;
+        float duration = igemm_launch_kernels_with_epilog({
+                {kernel_func, &karg, karg_size, {grid_size * block_size, 1, 1}, {block_size, 1, 1}}
+            }, wrw_epilog, this->warmup, this->repeat);
+
         result.return_code = 0;
-        result.duration_ms = avg_duration;
+        result.duration_ms = duration;
+        result.gks = gemm_k_global_split;
         result.kernel_name = kernel_name;
+
         return result;
     }
 };
-
 #endif
\ No newline at end of file
diff --git a/igemm/algo/__init__.py b/igemm/algo/__init__.py
index e7110ca7..1bedce04 100755
--- a/igemm/algo/__init__.py
+++ b/igemm/algo/__init__.py
@@ -32,6 +32,7 @@
 from .igemm_bwd_gtc import *
 from .igemm_wrw_gtc import *
 from .igemm_fwd_gtc import *
+from .igemm_fwd_gtc_nhwc import *
 from .igemm_upsampling_clear import *
 from .utility import *
 from .thread_mapping import *
diff --git a/igemm/algo/global_memory.py b/igemm/algo/global_memory.py
index 100b98b2..49a9d883 100755
--- a/igemm/algo/global_memory.py
+++ b/igemm/algo/global_memory.py
@@ -27,6 +27,21 @@
 import sys
 from ..codegen import *
 
+class inst_global_load_dword_t(object):
+    def __init__(self, dwords):
+        self.dwords = dwords
+
+    def __call__(self, vdst, vaddr, saddr, offset = 0):
+        if self.dwords == 1:
+            return f"global_load_dword v[{vdst}], v[{vaddr}:{vaddr}+1], s[{saddr}:{saddr}+1], offset:{offset}"
+        if self.dwords == 2:
+            return f"global_load_dwordx2 v[{vdst}:{vdst}+1], v[{vaddr}:{vaddr}+1], s[{saddr}:{saddr}+1], offset:{offset}"
+        if self.dwords == 3:
+            return f"global_load_dwordx3 v[{vdst}:{vdst}+2], v[{vaddr}:{vaddr}+1], s[{saddr}:{saddr}+1], offset:{offset}"
+        if self.dwords == 4:
+            return f"global_load_dwordx4 v[{vdst}:{vdst}+3], v[{vaddr}:{vaddr}+1], s[{saddr}:{saddr}+1], offset:{offset}"
+        assert False
+
 class inst_buffer_load_dword_t(object):
     ''' TODO: this implementation always uses offen '''
     def __init__(self, dwords):
@@ -92,6 +107,15 @@ def __init__(self):
         self.precision = 'fp32'      # 'fp32', 'fp16', ...
         self.src_order = 0  # 0-d0xd1, 1-d1xd0
         self.dst_order = 0  # 0-d0xd1, 1-d1xd0
+        self.use_flag = 0
+        self.bfe_flag = 0
+        self.precache_vs_ptn = 0    # 0: d0 use sgpr precache, d1 use vgpr precache
+                                    # 1: d0 use vgpr precache, d1 use sgpr precache
+                                    # 2: d0 use vgpr precache, d1 use vgpr precache
+                                    # 3: d0 use sgpr precache, d1 use sgpr precache
+                                    # 4: .... maybe consider not using precache?
+        self.flag_merge_v = 0       # when flag on v_offset, flag and multiple load, or flag per load
+
 class macro_igemm_2d_global_load_t(macro_base_t):
     # TODO: if need vectorize further LDS write, need shuffle dst gpr while load
@@ -210,6 +234,8 @@ def __init__(self, mc, ctrl, inline = False):
         self.declare_arg("s_stride_d0")
         self.declare_arg("s_stride_d1")
         self.declare_arg("s_offset")
+        if self.ctrl.use_flag:
+            self.declare_arg("v_flag")
 
     def name(self):
         ctrl = self.ctrl
@@ -315,6 +341,8 @@ def expr(self):
             i_soffset = 0
             for i_d0 in range(ctrl.length_d0):
                 for i_d1 in range(n_d1):
+                    if ctrl.use_flag and self.v_flag != None:
+                        self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_dst)}]")
                     if i_d0 == 0 and i_d1 == 0:
                         self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os()}", f"{self.s_ptr()}", 0, 0))
                     elif i_d0 == 0 and i_d1 == 1:
@@ -324,6 +352,8 @@ def expr(self):
                     else:
                         self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os()}", f"{self.s_ptr()}", f"{self.s_offset()}+{i_soffset}", 0))
                         i_soffset += 1
+                    if ctrl.use_flag and self.v_flag != None:
+                        self._emit(f"s_mov_b64 exec, -1")
                     i_dst = i_dst + 1
 
         elif ctrl.src_order == 1 and ctrl.dst_order == 0:
@@ -353,6 +383,213 @@ def get_issues(self):
         n_d1 = ctrl.length_d1 // ctrl.vector_d1
         return ctrl.length_d0 * n_d1
 
+class macro_igemm_2d_global_load_precache_voffset_t(macro_base_t):
+    '''
+    src/dst order is not supported here
+    '''
+    def __init__(self, mc, ctrl, inline = False):
+        assert type(ctrl) is ctrl_2d_global_load_t
+        macro_base_t.__init__(self, mc, inline)
+        self.ctrl = ctrl
+        self.declare_arg("v_dst")
+        self.declare_arg("s_ptr")
+        self.declare_arg("s_os")
+        self.declare_arg("v_os")
+        if self.ctrl.use_flag:
+            self.declare_arg("v_flag")
+        if self.ctrl.bfe_flag:
+            self.declare_arg("v_tmp")
+
+    def name(self):
+        ctrl = self.ctrl
+        if ctrl.precision == "fp32":
+            bits_str = 'b32'
+        elif ctrl.precision in ("fp16", "bf16"):
+            bits_str = 'b16'
+        else:
+            assert False
+
+        if ctrl.vector_d1 == 4:
+            vec_str = 'v4'
+        elif ctrl.vector_d1 == 2:
+            vec_str = 'v2'
+        elif ctrl.vector_d1 == 1:
+            vec_str = 'v1'
+        else:
+            assert False
+
+        return f".v_gld_{ctrl.length_d0}x{ctrl.length_d1}_{bits_str}_{vec_str}_precache_voffset"
+
+    def expr(self):
+        ctrl = self.ctrl
+        assert ctrl.length_d1 % ctrl.vector_d1 == 0
+        n_d1 = ctrl.length_d1 // ctrl.vector_d1
+        assert ctrl.precision == 'fp32', "TO BE supported"
+        buffer_load_dword = inst_buffer_load_dword_t(ctrl.vector_d1)
+
+        i_cnt = 0
+        for i_d0 in range(ctrl.length_d0):
+            for i_d1 in range(n_d1):
+                if ctrl.use_flag and 
self.v_flag != None: + if ctrl.bfe_flag: + self._emit(f"v_bfe_u32 v[{self.v_tmp()}], v[{self.v_flag()}], {i_cnt}, 1") + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_tmp()}]") + else: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_cnt)}]") + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_cnt*ctrl.vector_d1}", f"{self.v_os(i_cnt)}", f"{self.s_ptr()}", f"{self.s_os()}", 0)) + if ctrl.use_flag and self.v_flag != None: + self._emit(f"s_mov_b64 exec, -1") + i_cnt += 1 + + def get_issues(self): + ctrl = self.ctrl + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + return ctrl.length_d0 * n_d1 + +class macro_igemm_2d_global_load_precache_vs_offset_t(macro_base_t): + # precache voffset for d0 dimension + # precache soffset for d1 dimension + # hence v_flag is along d0 dimension + def __init__(self, mc, ctrl, inline = False): + assert type(ctrl) is ctrl_2d_global_load_t + macro_base_t.__init__(self, mc, inline) + self.ctrl = ctrl + self.declare_arg("v_dst") + self.declare_arg("s_ptr") + self.declare_arg("v_os") + self.declare_arg("s_stride_d0") # can be None + self.declare_arg("s_stride_d1") + self.declare_arg("s_offset") + if self.ctrl.use_flag: + self.declare_arg("v_flag") + if self.ctrl.bfe_flag: + self.declare_arg("v_tmp") + + def name(self): + ctrl = self.ctrl + if ctrl.precision == "fp32": + bits_str = 'b32' + elif ctrl.precision in ("fp16", "bf16"): + bits_str = 'b16' + else: + assert False + + if ctrl.vector_d1 == 4: + vec_str = 'v4' + elif ctrl.vector_d1 == 2: + vec_str = 'v2' + elif ctrl.vector_d1 == 1: + vec_str = 'v1' + else: + assert False + + return f".v_gld_{ctrl.length_d0}x{ctrl.length_d1}_{bits_str}_{vec_str}_precache_vs_offset" + + def expr(self): + ctrl = self.ctrl + assert ctrl.length_d1 % ctrl.vector_d1 == 0 + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + assert ctrl.precision == 'fp32', "TO BE supported" + buffer_load_dword = inst_buffer_load_dword_t(ctrl.vector_d1) + + if ctrl.src_order == 0 and ctrl.dst_order == 0: + i_dst = 0 + for i_d0 in range(ctrl.length_d0): + for i_d1 in range(n_d1): + if ctrl.use_flag and self.v_flag != None: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_d0)}]") + current_s_offset = 0 if i_d1 == 0 else (self.s_stride_d1() if i_d1 == 1 else self.s_offset(i_d1 - 2)) + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os(i_d0)}", f"{self.s_ptr()}", current_s_offset, 0)) + if ctrl.use_flag and self.v_flag != None: + self._emit(f"s_mov_b64 exec, -1") + i_dst = i_dst + 1 + + else: + assert False + + def get_issues(self): + ctrl = self.ctrl + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + return ctrl.length_d0 * n_d1 + +class macro_igemm_2d_global_load_precache_sv_offset_t(macro_base_t): + # precache soffset for d0 dimension + # precache voffset for d1 dimension + # hence v_flag is along d1 dimension + def __init__(self, mc, ctrl, inline = False): + assert type(ctrl) is ctrl_2d_global_load_t + macro_base_t.__init__(self, mc, inline) + self.ctrl = ctrl + self.declare_arg("v_dst") + self.declare_arg("s_ptr") + self.declare_arg("v_os") + self.declare_arg("s_stride_d0") # can be None + self.declare_arg("s_stride_d1") + self.declare_arg("s_offset") + if self.ctrl.use_flag: + self.declare_arg("v_flag") + if self.ctrl.bfe_flag: + self.declare_arg("v_tmp") + + def name(self): + ctrl = self.ctrl + if ctrl.precision == "fp32": + bits_str = 'b32' + elif ctrl.precision in ("fp16", "bf16"): + bits_str = 'b16' + else: + assert False + + if ctrl.vector_d1 == 4: + vec_str = 'v4' + elif ctrl.vector_d1 == 2: + vec_str = 'v2' + elif ctrl.vector_d1 
== 1: + vec_str = 'v1' + else: + assert False + + return f".v_gld_{ctrl.length_d0}x{ctrl.length_d1}_{bits_str}_{vec_str}_precache_sv_offset" + + def expr(self): + ctrl = self.ctrl + assert ctrl.length_d1 % ctrl.vector_d1 == 0 + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + assert ctrl.precision == 'fp32', "TO BE supported" + buffer_load_dword = inst_buffer_load_dword_t(ctrl.vector_d1) + + if ctrl.src_order == 0 and ctrl.dst_order == 0: + i_dst = 0 + if ctrl.flag_merge_v and n_d1 == 1: + # v is along d1 dimension, hence only possible when n_d1 is 1 + if ctrl.use_flag and self.v_flag != None: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag()}]") + for i_d0 in range(ctrl.length_d0): + for i_d1 in range(1): + current_s_offset = 0 if i_d0 == 0 else (self.s_stride_d1() if i_d0 == 1 else self.s_offset(i_d0 - 2)) + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os(i_d1)}", f"{self.s_ptr()}", current_s_offset, 0)) + i_dst = i_dst + 1 + if ctrl.use_flag and self.v_flag != None: + self._emit(f"s_mov_b64 exec, -1") + else: + for i_d0 in range(ctrl.length_d0): + for i_d1 in range(n_d1): + if ctrl.use_flag and self.v_flag != None: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_d1)}]") + current_s_offset = 0 if i_d0 == 0 else (self.s_stride_d1() if i_d0 == 1 else self.s_offset(i_d0 - 2)) + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os(i_d1)}", f"{self.s_ptr()}", current_s_offset, 0)) + if ctrl.use_flag and self.v_flag != None: + self._emit(f"s_mov_b64 exec, -1") + i_dst = i_dst + 1 + + else: + assert False + + def get_issues(self): + ctrl = self.ctrl + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + return ctrl.length_d0 * n_d1 + class macro_igemm_write_4d_strided_t(macro_base_t): ''' TODO: this is always not inline diff --git a/igemm/algo/igemm_base.py b/igemm/algo/igemm_base.py index 3d743e4f..7dbe82a4 100755 --- a/igemm/algo/igemm_base.py +++ b/igemm/algo/igemm_base.py @@ -172,6 +172,8 @@ def __init__(self, tunable_dict): else: assert False + self.tensor_a_pass_through = utility_dict_with_default_t(tunable_dict)('tensor_a_pass_through', 0) + self.tensor_b_pass_through = utility_dict_with_default_t(tunable_dict)('tensor_b_pass_through', 0) self.tensor_a_thread_lengths = tunable_dict['tensor_a_thread_lengths'] # list! self.tensor_a_cluster_lengths = tunable_dict['tensor_a_cluster_lengths'] # list! self.tensor_b_thread_lengths = tunable_dict['tensor_b_thread_lengths'] # list! 
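+        # tensor_{a,b}_pass_through (parsed above) selects a path where that tensor
+        # skips LDS staging and is fed to the MFMAs straight from its global loads:
+        # see the hunks below where lds_a/lds_b are forced to 0 for a pass-through
+        # tensor and global_prefetch_{a,b}_num becomes 2, so that double-buffered
+        # global loads cover the latency LDS buffering would otherwise hide.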
@@ -185,7 +187,7 @@ def __init__(self, tunable_dict): self.allow_lds_reorder = utility_dict_with_default_t(tunable_dict)('allow_lds_reorder', IGEMM_GTC_FEAT_ALLOW_LDS_REORDER) self.precache_soffset = utility_dict_with_default_t(tunable_dict)('precache_soffset', IGEMM_GTC_FEAT_PRECACHE_SOFFSET) - default_source_access_order = IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_N_GEMM_M if self.direction == 'fwd' \ + default_source_access_order = IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_N_GEMM_M if (self.direction == 'fwd' and self.tensor_layout == 'nchw') \ else IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_M_GEMM_N self.source_access_order = utility_dict_with_default_t(tunable_dict)('source_access_order', default_source_access_order) @@ -201,7 +203,12 @@ def __init__(self, tunable_dict): # assert type(self.opt_1x1) is bool assert self.direction in ('fwd', 'bwd', 'wrw') assert self.precision in ('fp32', 'fp16', 'bf16') - assert self.nxb in (1,4,8,16,32,64,128,256) + if self.tensor_layout == "nchw": + assert self.nxb in (1,4,8,16,32,64,128,256) + elif self.tensor_layout == "nhwc": + assert self.nxb == 0, 'nhwc now no need have different nxb value' + else: + assert False assert self.nxe in (0,1) # TODO: better specify @@ -226,10 +233,17 @@ def _unmerge_x1_from_e(unroll_k, nxe): return unroll_k # not used if self.direction == 'fwd': - assert self.gemm_n_per_block % self.nxb == 0 - self.unmerge_sub_n = self.gemm_n_per_block // self.nxb - self.unmerge_sub_k = 1 # not used - self.unmerge_sub_c = _unmerge_x1_from_e(self.gemm_k_per_block, self.nxe) + if self.tensor_layout == 'nchw': + assert self.gemm_n_per_block % self.nxb == 0 + self.unmerge_sub_n = self.gemm_n_per_block // self.nxb + self.unmerge_sub_k = 1 # not used + self.unmerge_sub_c = _unmerge_x1_from_e(self.gemm_k_per_block, self.nxe) + elif self.tensor_layout == 'nhwc': + self.unmerge_sub_n = 1 # not used + self.unmerge_sub_k = 1 # not used + self.unmerge_sub_c = 1 # not used + else: + assert False elif self.direction == 'bwd': assert self.gemm_n_per_block % self.nxb == 0 self.unmerge_sub_n = self.gemm_n_per_block // self.nxb @@ -241,6 +255,9 @@ def _unmerge_x1_from_e(unroll_k, nxe): self.unmerge_sub_k = 1 self.unmerge_sub_c = self.gemm_n_per_block + self.tensor_a_pass_through_interleave_gld = 0 if self.tensor_layout == 'nhwc' else 1 + self.tensor_b_pass_through_interleave_gld = 0 if self.tensor_layout == 'nhwc' else 1 + self.fma_interleave = IGEMM_GTC_FEAT_FMA_INTERLEAVE self.local_prefetch_num = 1 # vector global/lds implicit here @@ -254,6 +271,8 @@ def _unmerge_x1_from_e(unroll_k, nxe): elif self.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: self.local_prefetch_num = 2 if IGEMM_GTC_FEAT_LOCAL_PREFETCH else 1 + if (self.tensor_a_pass_through and self.wave_repeat_n == 2) or (self.tensor_b_pass_through and self.wave_repeat_m == 2): + self.local_prefetch_num = 1 # register for a,b,c buffer xdlops_mapping = get_ctrl_xdlops_mapping_fp32(self.gemm_m_per_block, self.gemm_n_per_block, self.block_size // amdgpu_wave_size(tunable_dict['arch'])) self.num_agpr_accumulate_c = xdlops_mapping.total_acc_c() @@ -261,6 +280,9 @@ def _unmerge_x1_from_e(unroll_k, nxe): self.num_vgpr_accumulate_a = self.wave_step_m * self.wave_repeat_m * xdlops_mapping.inst_mfma.num_v_a * self.local_prefetch_num self.num_vgpr_accumulate_b = self.wave_step_n * self.wave_repeat_n * xdlops_mapping.inst_mfma.num_v_b * self.local_prefetch_num + self.global_prefetch_a_num = 2 if self.tensor_a_pass_through and not self.tensor_a_pass_through_interleave_gld else 1 + 
self.global_prefetch_b_num = 2 if self.tensor_b_pass_through and not self.tensor_b_pass_through_interleave_gld else 1 + self.num_vgpr_global_load_a = igemm_flatten_list_product(self.tensor_a_thread_lengths) self.num_vgpr_global_load_b = igemm_flatten_list_product(self.tensor_b_thread_lengths) @@ -268,13 +290,13 @@ def _unmerge_x1_from_e(unroll_k, nxe): assert self.num_vgpr_global_load_b * self.block_size == self.gemm_n_per_block * self.gemm_k_per_block # LDS size - self.lds_a = amdgpu_precision_data_byte(self.precision) * self.gemm_k_per_block * self.gemm_m_per_block - self.lds_b = amdgpu_precision_data_byte(self.precision) * self.gemm_k_per_block * self.gemm_n_per_block - self.lds_a_np2 = igemm_next_pow2( self.lds_a) - self.lds_b_np2 = igemm_next_pow2( self.lds_b) - self.lds_single = igemm_next_pow2( self.lds_a_np2 + self.lds_b_np2) - self.lds_buffer_num = 1 if self.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS else 2 - self.lds_total = self.lds_buffer_num * self.lds_single + self.lds_a = amdgpu_precision_data_byte(self.precision) * self.gemm_k_per_block * self.gemm_m_per_block if not self.tensor_a_pass_through else 0 + self.lds_b = amdgpu_precision_data_byte(self.precision) * self.gemm_k_per_block * self.gemm_n_per_block if not self.tensor_b_pass_through else 0 + self.lds_a_np2 = igemm_next_pow2( self.lds_a) if self.lds_a != 0 else 0 + self.lds_b_np2 = igemm_next_pow2( self.lds_b) if self.lds_b != 0 else 0 + self.lds_single = igemm_next_pow2( self.lds_a_np2 + self.lds_b_np2) if (self.lds_a_np2 + self.lds_b_np2 != 0) else 0 + self.lds_buffer_num = 1 if self.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS else 2 + self.lds_total = self.lds_buffer_num * self.lds_single # print(f"lds_a:{self.lds_a}, lds_b:{self.lds_b}, lds_a_np2:{self.lds_a_np2}, lds_b_np2:{self.lds_b_np2}, lds_single:{self.lds_single}, lds_total:{self.lds_total}") # TODO: LDS size check @@ -286,12 +308,15 @@ def _unmerge_x1_from_e(unroll_k, nxe): self.thread_sub_tile_n = self.gemm_n_per_thread # number of loops at least needed for final coalescing store, decided by LDS size - self.coalescing_store_groups = (self.gemm_m_per_block * self.gemm_n_per_block) // \ - (self.lds_buffer_num * igemm_next_pow2(igemm_next_pow2(self.gemm_k_per_block * self.gemm_m_per_block) + igemm_next_pow2(self.gemm_k_per_block * self.gemm_n_per_block) )) + # self.coalescing_store_groups = (self.gemm_m_per_block * self.gemm_n_per_block) // \ + # (self.lds_buffer_num * igemm_next_pow2(igemm_next_pow2(self.gemm_k_per_block * self.gemm_m_per_block) + igemm_next_pow2(self.gemm_k_per_block * self.gemm_n_per_block) )) + self.coalescing_store_groups = (self.gemm_m_per_block * self.gemm_n_per_block) // (self.lds_total // amdgpu_precision_data_byte(self.precision)) + if self.coalescing_store_groups == 0: self.coalescing_store_groups = 1 # this means the LDS size is already bigger than the whole c matrix; just using one group is ok #if self.coalescing_store_groups < 2: # self.coalescing_store_groups = 2 + shrinked_lds_buffer_num = self.lds_buffer_num if self.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # check on grouping xdlops_mapping = get_ctrl_xdlops_mapping_fp32(self.gemm_m_per_block, self.gemm_n_per_block, self.block_size // amdgpu_wave_size(tunable_dict['arch'])) @@ -302,8 +327,8 @@ def _unmerge_x1_from_e(unroll_k, nxe): shrink_in_co_group = self.coalescing_store_groups // length_in_m # TODO: this may affect occupancy!
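To see what the new group formula does, here is a worked sketch of the LDS accounting (the tile numbers are illustrative, not taken from this diff; fp32 and neither tensor pass-through are assumed):

def next_pow2(n):
    p = 1
    while p < n:
        p *= 2
    return p

data_byte = 4                                 # fp32
gemm_m, gemm_n, gemm_k = 128, 128, 16         # hypothetical per-block tile
lds_a = data_byte * gemm_k * gemm_m           # 8192 bytes; 0 if tensor a is pass-through
lds_b = data_byte * gemm_k * gemm_n           # 8192 bytes; 0 if tensor b is pass-through
lds_single = next_pow2(next_pow2(lds_a) + next_pow2(lds_b))   # 16384
lds_total = 1 * lds_single                    # lds_buffer_num == 1 on the xdlops path
# the c tile has 128*128 pixels; one LDS buffer stages 16384/4 of them per pass,
# so the coalescing store needs 4 groups
assert (gemm_m * gemm_n) // (lds_total // data_byte) == 4

With a pass-through tensor the corresponding lds_* term drops to 0, the staged buffer shrinks, and the group count grows accordingly.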
- self.lds_buffer_num = self.lds_buffer_num * shrink_in_co_group - self.lds_total = self.lds_buffer_num * self.lds_single + shrinked_lds_buffer_num = shrinked_lds_buffer_num * shrink_in_co_group + self.lds_total = shrinked_lds_buffer_num * self.lds_single self.coalescing_store_groups = self.coalescing_store_groups // shrink_in_co_group def output(self): @@ -344,6 +369,8 @@ def to_dict(self): tunable_dict['wave_tile_k'] = self.wave_tile_k else: assert False + tunable_dict['tensor_a_pass_through'] = self.tensor_a_pass_through + tunable_dict['tensor_b_pass_through'] = self.tensor_b_pass_through tunable_dict['tensor_a_thread_lengths'] = self.tensor_a_thread_lengths tunable_dict['tensor_a_cluster_lengths'] = self.tensor_a_cluster_lengths tunable_dict['tensor_b_thread_lengths'] = self.tensor_b_thread_lengths @@ -359,6 +386,8 @@ def to_dict(self): tunable_dict['precache_soffset'] = self.precache_soffset tunable_dict['local_prefetch_num'] = self.local_prefetch_num + tunable_dict['global_prefetch_a_num'] = self.global_prefetch_a_num + tunable_dict['global_prefetch_b_num'] = self.global_prefetch_b_num tunable_dict['fma_interleave'] = self.fma_interleave tunable_dict['gemm_m_unmerge_cluster'] = self.gemm_m_unmerge_cluster @@ -405,6 +434,12 @@ def get_dict_with_default(some_dict, key, default_value): line_start + 'wave_step_n {} {}'.format(equal, self.wave_step_n) + new_line + \ line_start + 'wave_repeat_n {} {}'.format(equal, self.wave_repeat_n) + new_line + \ line_start + 'wave_tile_k {} {}'.format(equal, self.wave_tile_k) + new_line + if self.tensor_a_pass_through: + sstr += \ + line_start + 'tensor_a_pass_through {} {}'.format(equal, self.tensor_a_pass_through) + new_line + if self.tensor_b_pass_through: + sstr += \ + line_start + 'tensor_b_pass_through {} {}'.format(equal, self.tensor_b_pass_through) + new_line sstr += \ line_start + 'tensor_a_thread_lengths {} {}'.format(equal, self.tensor_a_thread_lengths) + new_line + \ line_start + 'tensor_a_cluster_lengths {} {}'.format(equal, self.tensor_a_cluster_lengths) + new_line + \ @@ -426,6 +461,7 @@ def get_dict_with_default(some_dict, key, default_value): line_start + 'thread_tile {} {}x{}'.format(equal, self.thread_tile_m, self.thread_tile_n) + new_line sstr += \ line_start + 'lds_total {} {}'.format(equal, self.lds_total) + new_line + \ + line_start + 'lds_buffer_num {} {}'.format(equal, self.lds_buffer_num) + new_line + \ line_start return sstr @@ -464,6 +500,12 @@ def lengths_str(lengths): kernel_name += "ta" + lengths_str(tunable.tensor_a_thread_lengths) + "_" + lengths_str(tunable.tensor_a_cluster_lengths) + "_" +\ "tb" + lengths_str(tunable.tensor_b_thread_lengths) + "_" + lengths_str(tunable.tensor_b_cluster_lengths) + if tunable.tensor_a_pass_through: + kernel_name += "_pta" + + if tunable.tensor_b_pass_through: + kernel_name += "_ptb" + if tunable.gemm_m_unmerge_cluster: kernel_name += "_mc" diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py new file mode 100755 index 00000000..fab8d3a0 --- /dev/null +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -0,0 +1,1804 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2020-2021 Advanced Micro Devices, Inc. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +# pylint: disable=maybe-no-member +from ..codegen import * +from .fma_main_loop import * +from .igemm_base import * +from .global_memory import * +from .shared_memory import * +from .utility import * +from .thread_mapping import * +from .xdlops_mapping import * +from .coalescing_store import * +from .mfma_main_loop import * + +IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG = 0 +# IGEMM_FWD_GTC_NHWC_P_INTERLEAVE_GLD = False # p tensor interleave + +def _find_non_1_index_in_list(list_object): + result_list = list() + for idx, item in enumerate(list_object): + assert type(item) is int + if item != 1: + result_list.append(idx) + return result_list + +class igemm_fwd_gtc_nhwc_t(mc_base_t): + ''' + tensor a (input) tensor b (wei) + thread_lengths : ta_e, ta_c, ta_nb0, ta_nb1, tb_e, tb_c, tb_k0, tb_k1 + cluster_lengths : ca_e, ca_c, ca_nb0, ca_nb1, cb_e, cb_c, cb_k0, cb_k1 + + for a/b tensor, always load gemm_k dimension first. 
+ + ''' + def __init__(self, mc, tunable): + assert type(tunable) is igemm_gtc_tunable_parameter_t + mc_base_t.__init__(self, mc) + self.tunable = tunable + self.global_load_in = self.global_load_in_t(mc, self) + self.global_load_wei = self.global_load_wei_t(mc, self) + self.shared_store_in = self.shared_store_in_t(mc, self) + self.shared_store_wei = self.shared_store_wei_t(mc, self) + + in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + self.in_thread_copy_ndim = len(in_thread_copy_index) + self.wei_thread_copy_ndim = len(wei_thread_copy_index) + assert self.in_thread_copy_ndim in (0, 1, 2) + assert self.wei_thread_copy_ndim in (0, 1, 2) + + self.coalescing_store_groups = igemm_next_pow2(self.tunable.coalescing_store_groups) + if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + assert (self.tunable.gemm_m_per_thread * self.tunable.gemm_m_repeat) % self.coalescing_store_groups == 0, \ + f"coalescing store groups should be divided by thread m {self.tunable.gemm_m_per_thread}x{self.tunable.gemm_m_repeat}" + + ctrl_thread_mapping = ctrl_thread_mapping_t() + # -> MR x NR x ML1 x NL1 x ML0 x NL0 + ctrl_thread_mapping.thread_lengths = [self.tunable.gemm_m_repeat, self.tunable.gemm_n_repeat, 1, 1, self.tunable.gemm_m_per_thread, self.tunable.gemm_n_per_thread] + ctrl_thread_mapping.cluster_lengths = [1, 1, self.tunable.gemm_m_level1_cluster, self.tunable.gemm_n_level1_cluster, self.tunable.gemm_m_level0_cluster, self.tunable.gemm_n_level0_cluster] + self.thread_mapping = igemm_thread_mapping_t(self.mc, ctrl_thread_mapping) + + ctrl_coalescing_store = ctrl_coalescing_store_t() + ctrl_coalescing_store.ctm = ctrl_thread_mapping + ctrl_coalescing_store.coalescing_groups = self.coalescing_store_groups + ctrl_coalescing_store.data_byte = amdgpu_precision_data_byte(self.tunable.precision) + + ctrl_coalescing_store.vector_write_out = 1 # TODO: some cases this can be set to other value + ctrl_coalescing_store.block_size = self.tunable.block_size + + gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() + na_c0, na_c1e, na_k0, na_k1, nb_c0, nb_c1e, nb_n0, nb_n1b = self.get_dims_lengths() + ctrl_coalescing_store.gemm_m_m0_m1 = [na_k0, na_k1] + if gemm_m_order == IGEMM_FWD_GTC_LDS_STORE_ORDER_GEMM_M_K1_K0: + ctrl_coalescing_store.gemm_m_order = IGEMM_COALESCING_GEMM_M_ORDER_M1_M0 + + ctrl_coalescing_store.adjust_optimal_coalescing_groups() # in m1_m0 order, must adjust + self.coalescing_store = igemm_coalescing_store_t(mc, ctrl_coalescing_store) + + else: + def flatten(x): + from functools import reduce + return reduce(lambda a, b: a*b, x, 1) + ctrl_xdlops_mapping = get_ctrl_xdlops_mapping_from_wave_tile_fp32(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block, self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, + self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE) + self.xdlops_mapping = igemm_xdlops_mapping_t(self.mc, ctrl_xdlops_mapping) + assert flatten(ctrl_xdlops_mapping.acc_c_per_thread_m()) % self.coalescing_store_groups == 0, \ + f"coalescing store groups should be divided by agpr per thread in m direction {ctrl_xdlops_mapping.acc_c_per_thread_m()}" + + ctrl_coalescing_store_xdlops = ctrl_coalescing_store_xdlops_t() + ctrl_coalescing_store_xdlops.cxm = ctrl_xdlops_mapping + ctrl_coalescing_store_xdlops.gemm_k_global_split = self.tunable.gemm_k_global_split + ctrl_coalescing_store_xdlops.coalescing_groups = 
self.coalescing_store_groups + ctrl_coalescing_store_xdlops.data_byte = amdgpu_precision_data_byte(self.tunable.precision) + + ctrl_coalescing_store_xdlops.vector_write_out = 1 # TODO: in some cases this can be set to another value + ctrl_coalescing_store_xdlops.block_size = self.tunable.block_size + + # gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() + na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 = self.get_dims_lengths() + ctrl_coalescing_store_xdlops.gemm_m_m0_m1 = [na_nb0, na_nb1] + #if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0: + # # we may consider not supporting this mode + # ctrl_coalescing_store_xdlops.gemm_m_order = IGEMM_COALESCING_GEMM_M_ORDER_M1_M0 + ctrl_coalescing_store_xdlops.adjust_optimal_coalescing_groups() # in m1_m0 order, must adjust + self.coalescing_store = igemm_coalescing_store_xdlops_t(mc, ctrl_coalescing_store_xdlops) + + self.label_out = f"L_{self.name()}_out" + self.dict_shifted_stride = dict() + + self.karg = self.kernel_karg_t(mc, self) + self.sgpr = self.kernel_sgpr_t(mc, self) + self.vgpr = self.kernel_vgpr_t(mc, self) + if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + self.agpr = self.kernel_agpr_t(mc, self) + + def name(self): + return igemm_gtc_encode_kernel_name(self.tunable) + + def try_shift_stride(self, gpr, shifter): + assert type(gpr) is sym_t + with self._deferred_context(): + if gpr.label not in self.dict_shifted_stride: + self.dict_shifted_stride[gpr.label] = gpr + self._emit(f"s_lshl_b32 s[{gpr()}], s[{gpr()}], {shifter}") + return self._get_deferred() + + class macro_set_flag_nhw(macro_base_t): + def __init__(self, mc, inline = False): + macro_base_t.__init__(self, mc, inline) + self.declare_arg("v_flag") + self.declare_arg("v_flag_n") + self.declare_arg("v_ih") + self.declare_arg("v_iw") + self.declare_arg("s_h") + self.declare_arg("s_w") + def name(self): + return '.v_fwd_gtc_nhwc_set_flag_nhw' + + def expr(self): + self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_h()}], v[{self.v_ih()}]") + self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, v[{self.v_flag_n()}], vcc") + self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_w()}], v[{self.v_iw()}]") + self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, v[{self.v_flag()}], vcc") + + class macro_move_slice_window_block_wise_1x1_t(macro_base_t): + def __init__(self, mc, tunable, inline, **options): + macro_base_t.__init__(self, mc, True) + self.tunable = tunable + if tunable.tensor_a_pass_through: + self.declare_arg("s_in_base") # 64bit acc + else: + self.declare_arg("s_in_offset") # use this as the c iterator, since the input's other dimensions use voffset + self.declare_arg("v_wei_os") + self.declare_arg("s_move_slice_k_stride_c") # this is indeed gemm_k * data_byte, same for input/weight + self.options = options + + def name(self): + return f'.v_fwd_gtc_nhwc_move_slice_window_block_wise_1x1_{self.tunable.tensor_a_pass_through}_{self.tunable.tensor_b_pass_through}' + + def expr(self): + if self.tunable.tensor_a_pass_through: + self._emit(f"s_add_u32 s[{self.s_in_base()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_base()}]") + self._emit(f"s_addc_u32 s[{self.s_in_base(1)}], 0, s[{self.s_in_base(1)}]") + else: + self._emit(f"s_add_u32 s[{self.s_in_offset()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_offset()}]") + self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") + self._emit_empty_line() + + class macro_move_slice_window_block_wise_t(macro_base_t): + ''' + nhwc gemm_k = e*c, and thread/cluster length for e
is always 1 + hence we always move along c and accumulate into e + + this macro is for input and weight together. + block-wise move slice window means we increase y*x*c using sgpr. + Indeed this is always the case, since gemm_k % k_per_block == 0 always holds. + Besides, we always increase along the c dimension, so keeping y, x, c in sgpr is enough + + ''' + def __init__(self, mc, tunable, inline, **options): + macro_base_t.__init__(self, mc, True) + self.tunable = tunable + if tunable.tensor_a_pass_through: + self.declare_arg("s_in_base") # 64bit acc + self.declare_arg("s_in_c_itr") + else: + self.declare_arg("s_in_offset") # use this as the c iterator, since the input's other dimensions use voffset + self.declare_arg("v_wei_os") + self.declare_arg("s_move_slice_k_stride_c") # this is indeed gemm_k * data_byte, same for input/weight + self.declare_arg("s_gemm_k_num_c") # c * data_byte + self.declare_arg("s_flag_need_acc_yx") + self.options = options + + def name(self): + return f'.v_fwd_gtc_nhwc_move_slice_window_block_wise_{self.tunable.tensor_a_pass_through}_{self.tunable.tensor_b_pass_through}' + + def expr(self): + if self.tunable.tensor_a_pass_through: + self._emit(f"s_add_u32 s[{self.s_in_base()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_base()}]") + self._emit(f"s_addc_u32 s[{self.s_in_base(1)}], 0, s[{self.s_in_base(1)}]") + else: + self._emit(f"s_add_u32 s[{self.s_in_offset()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_offset()}]") + self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") + if self.tunable.tensor_a_pass_through: + self._emit(f"s_add_u32 s[{self.s_in_c_itr()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_c_itr()}]") + self._emit(f"s_cmp_le_u32 s[{self.s_gemm_k_num_c()}], s[{self.s_in_c_itr()}]") + else: + self._emit(f"s_cmp_le_u32 s[{self.s_gemm_k_num_c()}], s[{self.s_in_offset()}]") + if not self.tunable.tensor_a_pass_through and not self.tunable.tensor_b_pass_through: + self._emit(f"s_cselect_b32 s[{self.s_flag_need_acc_yx()}], 1, 0") + self._emit_empty_line() + + class macro_move_slice_window_block_wise_acc_yx_t(macro_base_t): + ''' + cannot inline + prefer to put this before the global load wait. And for simplicity, no auto schedule. + ''' + def __init__(self, mc, tunable, inline, **options): + macro_base_t.__init__(self, mc, True) + self.tunable = tunable + if tunable.tensor_a_pass_through: + self.declare_arg("s_in_base") + self.declare_arg("s_in_c_itr") # + #self.declare_arg("s_gemm_k_num_c") # used to U64 sub s_in_base, can be None + else: + self.declare_arg("s_in_offset") # use this as the c iterator, since the input's other dimensions use voffset + self.declare_arg("v_wei_os") + self.declare_arg("s_c") + self.declare_arg("s_gemm_k_num_c") + self.declare_arg("v_in_os") + self.declare_arg("v_in_ihi_list") + self.declare_arg("v_in_iwi_list") + self.declare_arg("v_in_flag") + if not IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self.declare_arg("v_in_flag_n") + self.declare_arg("s_flag_need_acc_yx") + self.declare_arg("s_move_slice_k_ix") + self.declare_arg("s_x") + self.declare_arg("s_in_diff_hi") # this is s_dilation_h * s_in_stride_hi - (x - 1) * s_dilation_w * s_in_stride_wi, always positive + self.declare_arg("s_in_diff_wi") # this is s_dilation_w * s_in_stride_wi + self.declare_arg("s_dilation_h") + self.declare_arg("s_dilation_w") + self.declare_arg("s_dilation_w_x") # this is -1 * (x - 1) * s_dilation_w + self.declare_arg("s_hi") + self.declare_arg("s_wi") + self.declare_arg("v_tmp") # 2 needed + self.declare_arg("s_tmp") + self.options = options + def name(self): + return '.v_fwd_gtc_nhwc_move_slice_window_block_wise_acc_yx' + + def expr(self): + assert "label_acc_yx" in self.options + label_acc_yx = self.options["label_acc_yx"] + '_{}'.format(self.expr_cnt) + label_acc_yx_end = self.options["label_acc_yx"] + '_end' + '_{}'.format(self.expr_cnt) + label_acc_yx_x_end = self.options["label_acc_yx"] + '_x_end' + '_{}'.format(self.expr_cnt) + + assert "nb_per_thread" in self.options + nb_per_thread = self.options["nb_per_thread"] + + assert 'm_set_flag_nhw' in self.options + m_set_flag_nhw = self.options['m_set_flag_nhw'] + + if not self.tunable.tensor_a_pass_through and not self.tunable.tensor_b_pass_through: + self._emit(f"s_cmp_eq_u32 1, s[{self.s_flag_need_acc_yx()}]") + self._emit(f"s_cbranch_scc0 {label_acc_yx_end} ; no need to accumulate yx") + self._emit_front(f"{label_acc_yx}:") + # wei os needs to add a whole c when yx is changing + self._emit(f"v_add_u32 v[{self.v_wei_os()}], v[{self.v_wei_os()}], s[{self.s_c()}]") + self._emit(f"v_sub_u32 v[{self.v_wei_os()}], v[{self.v_wei_os()}], s[{self.s_gemm_k_num_c()}]") + if self.tunable.tensor_a_pass_through: + self._emit(f"s_sub_u32 s[{self.s_in_base()}], s[{self.s_in_base()}], s[{self.s_gemm_k_num_c()}]") + self._emit(f"s_subb_u32 s[{self.s_in_base(1)}], s[{self.s_in_base(1)}], 0") + self._emit(f"s_mov_b32 s[{self.s_in_c_itr()}], 0") # reset input offset. wei, no care + else: + self._emit(f"s_mov_b32 s[{self.s_in_offset()}], 0") # reset input offset.
wei, no care + ''' + ix accumulate, will only accumulate in width, and will never carry on to height + iy accumulate, will only accumulate in height, and will never carry on to batch + this makes life easier + ''' + # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + self._emit(f"s_add_u32 s[{self.s_move_slice_k_ix()}], 1, s[{self.s_move_slice_k_ix()}]") + self._emit(f"s_cmp_le_u32 s[{self.s_x()}], s[{self.s_move_slice_k_ix()}]") + + # update iwi + self._emit(f"s_cselect_b32 s[{self.s_tmp()}], s[{self.s_dilation_w_x()}], s[{self.s_dilation_w()}]") + for i in range(nb_per_thread): + self._emit(f"v_add_u32 v[{self.v_in_iwi_list(i)}], s[{self.s_tmp()}], v[{self.v_in_iwi_list(i)}]") + + # update in_os + self._emit(f"s_cselect_b32 s[{self.s_tmp()}], s[{self.s_in_diff_hi()}], s[{self.s_in_diff_wi()}]") + for i in range(nb_per_thread): + self._emit(f"v_add_u32 v[{self.v_in_os(i)}], s[{self.s_tmp()}], v[{self.v_in_os(i)}]") + + # update ihi, accumulate + self._emit(f"s_cbranch_scc0 {label_acc_yx_x_end}") + self._emit(f"s_mov_b32 s[{self.s_move_slice_k_ix()}], 0") + for i in range(nb_per_thread): + self._emit(f"v_add_i32 v[{self.v_in_ihi_list(i)}], s[{self.s_dilation_h()}], v[{self.v_in_ihi_list(i)}]") + self._emit_front(f"{label_acc_yx_x_end}:") + + # now set flags + for i in range(nb_per_thread): + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self._emit(f"v_bfe_u32 v[{self.v_tmp(1)}], v[{self.v_in_flag()}], {16 + i}, 1 ; extract flag_n") + self._emit(f"v_and_b32 v[{self.v_in_flag()}], {0xffffffff ^ (1< 0 ) else 0 + + self.s_in_offset = sym_t("s_in_offset" , sseq(in_npc)) + self.s_in_c_itr = sym_t("s_in_c_itr" , 2) + self.s_in_stride_k_pack = sym_t("s_in_stride_k_pack" , sseq(1)) + else: + self.s_in_offset = sym_t("s_in_offset" , sseq(1)) + if outer.tunable.precache_soffset: + m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() + wei_npc = m_wei_2d_global_load.get_num_precache_soffset() + self.s_wei_offset = sym_t("s_wei_offset" ,sseq(wei_npc)) + + # TODO: this sgpr allocation is a mess + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + # allocate several sgpr to hold magic/shift value. 
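These magic/shift sgprs hold host-precomputed constants that replace runtime integer division with a mul_hi/add/shift sequence on the GPU. A sketch of the usual generation scheme (the exact constants MISA's host driver produces are an assumption here, and dividends are assumed below 2^31 so the 32-bit add cannot carry):

def magic_div_u32_gen(d):
    # smallest shift with 2^shift >= d, then a 32-bit magic multiplier
    assert 1 <= d <= 0x7fffffff
    shift = next(s for s in range(32) if (1 << s) >= d)
    magic = ((1 << 32) * ((1 << shift) - d)) // d + 1
    assert magic <= 0xffffffff      # must fit one sgpr / one kernarg dword
    return magic, shift

def mdiv_u32(n, magic, shift):
    # device side is v_mul_hi_u32, v_add_u32, v_lshrrev_b32
    return (((n * magic) >> 32) + n) >> shift

magic, shift = magic_div_u32_gen(7)
assert all(mdiv_u32(n, magic, shift) == n // 7
           for n in (0, 6, 7, 13, 12345, 2**31 - 1))

Four 8-bit shift amounts share one dword, which is why the prologue below unpacks s_shift_pack_0 with s_bfe_u32 at offsets 0/8/16/24.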
+ self.s_magic_0 = sym_t("s_magic_0" ,self.s_p_in.value + 2) + self.s_magic_1 = sym_t("s_magic_1" ,self.s_p_in.value + 3) + self.s_magic_2 = sym_t("s_magic_2" ,self.s_p_out.value + 2) + self.s_magic_3 = sym_t("s_magic_3" ,self.s_p_out.value + 3) + self.s_shift_pack_0 = sym_t("s_shift_pack_0" ,self.s_flag_need_acc_yx.value) + + self.s_gemmk_split = sym_t("s_gemmk_split" ,sseq(1)) + self.s_sub_c = sym_t("s_sub_c" ,sseq(1)) + self.s_tmp = sym_t("s_tmp" ,sseq(6, 2)) + self.s_end = sym_t("s_end" ,sseq()) + + def get_count(self): + return self.s_end.value + + def emit(self): + assert self.s_end.value <= amdgpu_sgpr_limit(self.mc.arch_config.arch), f"s_end:{self.s_end.value}, tunable:{self.outer.tunable.serialize()}" + for k, v in self.__dict__.items(): + if k.startswith('s_'): + self._emit(v.declare()) + + class kernel_vgpr_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + self.outer = outer + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = outer.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 = outer.get_cluster_lengths() + + nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 + nk_per_thread = tb_k0 if tb_k0 != 1 else tb_k1 + assert nb_per_thread <= 16, "we pack flags into a single vgpr" + + k_pack = outer.get_k_pack() + share_load_packed = k_pack if outer.tunable.tensor_a_pass_through or outer.tunable.tensor_b_pass_through else 1 + + is_vgpr_acc_c = outer.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS + vseq = gpr_sequencer_t() + num_vgpr_acc_a = share_load_packed * outer.tunable.num_vgpr_accumulate_a if not outer.tunable.tensor_a_pass_through else 0 + num_vgpr_acc_b = share_load_packed * outer.tunable.num_vgpr_accumulate_b if not outer.tunable.tensor_b_pass_through else 0 + if is_vgpr_acc_c: + self.v_c = sym_t("v_c" ,vseq(outer.tunable.num_vgpr_accumulate_c)) + v_c_num = vseq() + else: + v_c_reusable_num = num_vgpr_acc_a + num_vgpr_acc_b + \ + outer.tunable.num_vgpr_global_load_a + outer.tunable.num_vgpr_global_load_b + \ + 16 # from v_sst_a_os to v_co_sst + v_c_coalescing_num = outer.tunable.num_agpr_accumulate_c // outer.coalescing_store_groups + v_c_needed = (v_c_coalescing_num - v_c_reusable_num) if (v_c_coalescing_num - v_c_reusable_num) > 0 else 0 + + v_c_needed = v_c_needed if v_c_needed > 0 else 0 # keep at least 0 + self.v_c = sym_t("v_c" ,vseq(v_c_needed), f"coalescing:{v_c_coalescing_num}, needed:{v_c_needed}, reusable:{v_c_reusable_num}") + + if not outer.tunable.tensor_a_pass_through: + self.v_a = sym_t("v_a" ,vseq(num_vgpr_acc_a)) + if not outer.tunable.tensor_b_pass_through: + self.v_b = sym_t("v_b" ,vseq(num_vgpr_acc_b)) + self.v_gld_a = sym_t("v_gld_a" ,vseq(outer.tunable.num_vgpr_global_load_a)) + if outer.tunable.global_prefetch_a_num == 2: + self.v_gld_a_gpf = sym_t("v_gld_a_gpf" ,vseq(outer.tunable.num_vgpr_global_load_a)) + self.v_gld_b = sym_t("v_gld_b" ,vseq(outer.tunable.num_vgpr_global_load_b)) + if outer.tunable.global_prefetch_b_num == 2: + self.v_gld_b_gpf = sym_t("v_gld_b_gpf" ,vseq(outer.tunable.num_vgpr_global_load_b)) + if not outer.tunable.tensor_a_pass_through: + self.v_sst_a_os = sym_t("v_sst_a_os" ,vseq(1)) + self.v_sld_a_os = sym_t("v_sld_a_os" ,vseq(1)) + if not outer.tunable.tensor_b_pass_through: + self.v_sst_b_os = sym_t("v_sst_b_os" ,vseq(1)) + self.v_sld_b_os = sym_t("v_sld_b_os" ,vseq(1)) + + self.v_in_os = sym_t("v_in_os" ,vseq(nb_per_thread)) + self.v_in_ihi_list = sym_t("v_in_ihi_list" ,vseq(nb_per_thread)) + self.v_in_iwi_list = sym_t("v_in_iwi_list" ,vseq(nb_per_thread)) + if
IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self.v_in_flag = sym_t("v_in_flag" ,vseq(1)) # bfe this!, hi 16bit is flag_n, lo 16 bit is pure flag + else: + self.v_in_flag = sym_t("v_in_flag" ,vseq(nb_per_thread)) + self.v_in_flag_n = sym_t("v_in_flag_n" ,vseq(1)) # bfe this!, lo 16bit is flag_n + + self.v_wei_os = sym_t("v_wei_os" ,vseq(1)) + self.v_out_os = sym_t("v_out_os" ,vseq(1)) + + if outer.tunable.tensor_a_pass_through: + self.v_gtc_ic_a = sym_t("v_gtc_ic_a" ,self.v_gld_a.value) + if outer.tunable.tensor_b_pass_through: + self.v_gtc_ic_b = sym_t("v_gtc_ic_b" ,self.v_gld_b.value) + if not (outer.tunable.tensor_a_pass_through and outer.tunable.tensor_b_pass_through): + self.v_gtc_ic = sym_t("v_gtc_ic" ,vseq(1)) + self.v_in_inb = sym_t("v_in_inb" ,vseq(1)) + self.v_in_in = sym_t("v_in_in" ,vseq(1)) + self.v_wei_ik = sym_t("v_wei_ik" ,vseq(1)) + + self.v_co_sst = sym_t("v_co_sst" ,self.v_in_in.value) + self.v_co_sld = sym_t("v_co_sld" ,vseq(1)) + + self.v_out_flag = sym_t("v_out_flag" ,self.v_wei_ik.value) + self.v_out_inb = sym_t("v_out_inb" ,self.v_in_inb.value) + + self.v_gemm_in = sym_t("v_gemm_in" ,vseq(1)) + self.v_gemm_im = sym_t("v_gemm_im" ,vseq(1)) + self.v_co_sub_m_index = sym_t("v_co_sub_m_index" ,self.v_gemm_im.value) + self.v_co_sub_n_index = sym_t("v_co_sub_n_index" ,self.v_gemm_in.value) + + self.v_tmp = sym_t("v_tmp" ,vseq(6, 2)) + self.v_wei_tmp_pack = sym_t("v_wei_tmp_pack" ,self.v_gld_a.value - 1 if self.v_gld_a.value > 1 else vseq(1)) + if nk_per_thread <= 4 and IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG == 0: + self.v_wei_flag = sym_t("v_wei_flag" ,self.v_tmp.value) + else: + self.v_wei_flag = sym_t("v_wei_flag" ,vseq(nk_per_thread)) + + total_vgpr = vseq() + if outer.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + # if xdlops agpr is larger than vgpr usage, must change vgpr count to agpr + total_vgpr = max(total_vgpr, outer.tunable.num_agpr_accumulate_c) + self.v_end = sym_t("v_end" ,total_vgpr) + + def get_count(self): + return self.v_end.value + + def emit(self): + for k, v in self.__dict__.items(): + if k.startswith('v_'): + self._emit(v.declare()) + + class kernel_agpr_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + assert outer.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS, 'only xdlops can use agpr' + self.outer = outer + aseq = gpr_sequencer_t() + self.a_c = sym_t("a_c", aseq(outer.tunable.num_agpr_accumulate_c)) + self.a_end = sym_t("a_end", aseq()) + + def get_count(self): + return self.a_end.value + + def emit(self): + for k, v in self.__dict__.items(): + if k.startswith('a_'): + self._emit(v.declare()) + + def get_thread_lengths(self): + t_ta = self.tunable.tensor_a_thread_lengths + t_tb = self.tunable.tensor_b_thread_lengths + + assert len(t_ta) == 4 and len(t_tb) == 4 + + ta_e, ta_c, ta_nb0, ta_nb1 = t_ta[0], t_ta[1], t_ta[2], t_ta[3] + tb_e, tb_c, tb_k0, tb_k1 = t_tb[0], t_tb[1], t_tb[2], t_tb[3] + + if self.tunable.tensor_a_pass_through or self.tunable.tensor_b_pass_through: + pass + else: + assert ta_e == tb_e and ta_c == tb_c + assert ta_c in (1, 2, 4), "currently c will be used as LDS store/load vector size, now only support this" + + assert ta_e == 1, "currently not support >1 in e dimension" + + # it's no point to have both x0, x1 have copy value + if not self.tunable.tensor_a_pass_through: + assert not (ta_nb0 != 1 and ta_nb1 != 1) + if not self.tunable.tensor_b_pass_through: + assert not (tb_k0 != 1 and tb_k1 != 1) + + return ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 # M, K, N + + def 
get_cluster_lengths(self): + c_ta = self.tunable.tensor_a_cluster_lengths + c_tb = self.tunable.tensor_b_cluster_lengths + + assert len(c_ta) == 4 and len(c_tb) == 4 + + ca_e, ca_c, ca_nb0, ca_nb1 = c_ta[0], c_ta[1], c_ta[2], c_ta[3] + cb_e, cb_c, cb_k0, cb_k1 = c_tb[0], c_tb[1], c_tb[2], c_tb[3] + + if not self.tunable.tensor_a_pass_through: + assert ca_nb1 != 1 + assert ca_e == cb_e and ca_c == cb_c + assert ca_nb0 == 1 + if not self.tunable.tensor_b_pass_through: + assert cb_k0 == 1 + + assert ca_e == 1 + + return ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 # M, K, N + + def get_dims_lengths(self): + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 = self.get_cluster_lengths() + + na_nb0, na_nb1, na_e, na_c = ta_nb0 * ca_nb0, ta_nb1 * ca_nb1, ta_e * ca_e, ta_c * ca_c + nb_k0, nb_k1 , nb_e, nb_c = tb_k0 * cb_k0, tb_k1 * cb_k1, tb_e * cb_e, tb_c * cb_c + + return na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 # M, K, N + + def get_thread_copy_dims(self): + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + in_thread_copy_dims = [ta_nb0, ta_nb1, ta_e, ta_c] + wei_thread_copy_dims = [tb_k0, tb_k1, tb_e, tb_c] # always reordered! + return in_thread_copy_dims, wei_thread_copy_dims + + def get_thread_copy_index(self): + in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() + in_thread_copy_index = _find_non_1_index_in_list(in_thread_copy_dims) + wei_thread_copy_index = _find_non_1_index_in_list(wei_thread_copy_dims) + + ''' + if thread lengths both dimension is 1, means every thread only copy one pixel. + we need support this also + ''' + return in_thread_copy_index, wei_thread_copy_index + + def get_k_pack(self): + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + if (not self.tunable.tensor_a_pass_through and not self.tunable.tensor_b_pass_through) or \ + (self.tunable.tensor_a_pass_through and self.tunable.tensor_b_pass_through): + assert ta_c == tb_c + return tb_c + else: + if self.tunable.tensor_a_pass_through: + assert ta_c % tb_c == 0 + return tb_c + else: + assert tb_c % ta_c == 0 + return ta_c + + def get_macro_global_load(self): + ''' + NOTICE: input/wei always load gemm_k (e*c) first. 
indeed always load c, and do vector load if possible + ''' + inline = True if self.tunable.fma_interleave else False + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 = self.get_dims_lengths() + + in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() + in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + ctrl_wei_gld = ctrl_2d_global_load_t() + ctrl_in_gld = ctrl_2d_global_load_t() + + ctrl_wei_gld.vector_d1 = utility_gcd(tb_c, 4) if tb_c != 1 else 1 + ctrl_in_gld.vector_d1 = utility_gcd(ta_c, 4) if ta_c != 1 else 1 + + if self.tunable.tensor_b_pass_through: + ctrl_wei_gld.length_d0 = tb_k0 if tb_k0 != 1 else tb_k1 + ctrl_wei_gld.length_d1 = tb_c + ctrl_wei_gld.vector_d1 = self.get_k_pack() + ctrl_wei_gld.flag_merge_v = 0 if self.tunable.tensor_b_pass_through_interleave_gld else 1 + else: + if self.wei_thread_copy_ndim == 2: + ctrl_wei_gld.length_d0 = wei_thread_copy_dims[wei_thread_copy_index[0]] + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[1]] + elif self.wei_thread_copy_ndim == 1: + ctrl_wei_gld.length_d0 = 1 + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] + else: + ctrl_wei_gld.length_d0 = 1 + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[-1] + + if self.tunable.tensor_a_pass_through: + ctrl_in_gld.length_d0 = ta_c // self.get_k_pack() + ctrl_in_gld.length_d1 = (ta_nb0 if ta_nb0 != 1 else ta_nb1) * self.get_k_pack() + ctrl_in_gld.vector_d1 = self.get_k_pack() + ctrl_in_gld.flag_merge_v = 0 if self.tunable.tensor_a_pass_through_interleave_gld else 1 + else: + if self.in_thread_copy_ndim == 2: + ctrl_in_gld.length_d0 = in_thread_copy_dims[in_thread_copy_index[0]] + ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[1]] + elif self.in_thread_copy_ndim == 1: + ctrl_in_gld.length_d0 = 1 + ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] + else: + ctrl_in_gld.length_d0 = 1 + ctrl_in_gld.length_d1 = in_thread_copy_dims[-1] + + ctrl_in_gld.use_flag = 1 + ctrl_wei_gld.use_flag = 1 + + if self.tunable.nxe != 0: + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + ctrl_wei_gld.bfe_flag = 1 + ctrl_in_gld.bfe_flag = 1 + + if self.tunable.precache_soffset: + return macro_igemm_2d_global_load_precache_sv_offset_t(self.mc, ctrl_wei_gld, inline) if self.tunable.tensor_b_pass_through else \ + macro_igemm_2d_global_load_precache_soffset_t(self.mc, ctrl_wei_gld, inline), \ + macro_igemm_2d_global_load_precache_sv_offset_t(self.mc, ctrl_in_gld, inline) if self.tunable.tensor_a_pass_through else \ + macro_igemm_2d_global_load_precache_voffset_t(self.mc, ctrl_in_gld, inline) + else: + return macro_igemm_2d_global_load_t(self.mc, ctrl_wei_gld, inline), macro_igemm_2d_global_load_precache_voffset_t(self.mc, ctrl_in_gld, inline) + + def get_macro_shared_store(self): + #in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() + #in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 = self.get_dims_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + data_byte = amdgpu_precision_data_byte(self.tunable.precision) + + k_pack = self.get_k_pack() + + if not self.tunable.tensor_a_pass_through: + # input is gemm_k * gemm_m * k_pack + in_sst_ctrl = ctrl_3d_shared_store_t() + in_sst_ctrl.length_d0 = ta_nb0 + in_sst_ctrl.length_d1 = ta_nb1 + in_sst_ctrl.length_dp = k_pack + 
in_sst_ctrl.stride_d0 = na_nb1 * k_pack * data_byte + in_sst_ctrl.stride_d1 = k_pack * data_byte + + if not self.tunable.tensor_b_pass_through: + # wei is gemm_k * gemm_n * k_pack + wei_sst_ctrl = ctrl_3d_shared_store_t() + wei_sst_ctrl.length_d0 = tb_k0 + wei_sst_ctrl.length_d1 = tb_k1 + wei_sst_ctrl.length_dp = k_pack + wei_sst_ctrl.stride_d0 = nb_k1 * k_pack * data_byte + wei_sst_ctrl.stride_d1 = k_pack * data_byte + + inline = True if self.tunable.fma_interleave else False + return macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline) if not self.tunable.tensor_a_pass_through else None, \ + macro_igemm_3d_shared_store_t(self.mc, wei_sst_ctrl, inline) if not self.tunable.tensor_b_pass_through else None + + def get_macro_move_slice_window(self): + inline = True if self.tunable.fma_interleave else False + if self.tunable.nxe != 0: + move_slice_window = self.macro_move_slice_window_block_wise_t(self.mc, self.tunable, inline) + else: + move_slice_window = self.macro_move_slice_window_block_wise_1x1_t(self.mc, self.tunable, inline) + + # return single functor ! + return move_slice_window + + def get_macro_move_slice_window_accumulate(self): + inline = True if self.tunable.fma_interleave else False + if self.tunable.nxe != 0: + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 + return self.macro_move_slice_window_block_wise_acc_yx_t(self.mc, self.tunable, inline, + label_acc_yx = self.name() + "_acc_yx", + nb_per_thread = nb_per_thread, + m_set_flag_nhw = self.get_macro_set_flag_nhw()) + else: + return None + + def get_macro_set_flag_nhw(self): + inline = True if self.tunable.fma_interleave else False + return self.macro_set_flag_nhw(self.mc, inline) + + def get_symbol_global_load_s_stride_d0_d1(self): + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + # get the symbol object that load 2d may use + s = self.sgpr + s_dummy = sym_t("s_dummy") + in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + + # input is ignored + # [ta_nb0, ta_nb1, ta_e, ta_c] + in_stride_gprs = [s_dummy, + s_dummy, + s_dummy, + s_dummy] + + # [tb_k0, tb_k1, ta_e, ta_c] + wei_stride_gprs = [s.s_wei_stride_k0 if tb_k0 != 1 else s_dummy, + s.s_wei_stride_k if tb_k1 != 1 else s_dummy, + s_dummy, + s_dummy] + + if self.in_thread_copy_ndim == 2: + s_in_stride_d0 = in_stride_gprs[in_thread_copy_index[0]] + s_in_stride_d1 = in_stride_gprs[in_thread_copy_index[1]] + elif self.in_thread_copy_ndim == 1: + s_in_stride_d0 = s_dummy + s_in_stride_d1 = in_stride_gprs[in_thread_copy_index[0]] + else: + s_in_stride_d0 = s_dummy + s_in_stride_d1 = in_stride_gprs[-1] + + if self.wei_thread_copy_ndim == 2: + # print(f" ____ wei_thread_copy_index:{len(wei_thread_copy_index)}, {wei_thread_copy_index}") + s_wei_stride_d0 = wei_stride_gprs[wei_thread_copy_index[0]] + s_wei_stride_d1 = wei_stride_gprs[wei_thread_copy_index[1]] + elif self.wei_thread_copy_ndim == 1: + s_wei_stride_d0 = s_dummy + s_wei_stride_d1 = wei_stride_gprs[wei_thread_copy_index[0]] + else: + s_wei_stride_d0 = s_dummy + s_wei_stride_d1 = wei_stride_gprs[-1] + + return s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 + + def get_kernel_code(self): + kernel_code = amdgpu_kernel_code_t({ + 'enable_sgpr_kernarg_segment_ptr' : 1, + 'enable_sgpr_workgroup_id_x' : 1, + 'enable_sgpr_workgroup_id_y' : 1, + 'enable_vgpr_workitem_id' : 0, + 'workgroup_group_segment_byte_size' : self.tunable.lds_total, + 
'kernarg_segment_byte_size' : self.karg.get_count(), + 'wavefront_sgpr_count' : self.sgpr.get_count() + 2*3, + 'workitem_vgpr_count' : self.vgpr.get_count() + }) + return kernel_code + + def get_kernel_args(self): + ''' + float *p_in; + float *p_wei; + float *p_out; + int hi; + int wi; + int n; + int k; + int c; + int ho; + int wo; + int stride_h; + int stride_w; + int dilation_h; + int dilation_w; + int pad_h; + int pad_w; + int y; + int x; + int group; + /* if use magic division */ + uint32_t magic_0; // denom: sa=0: n*b / n_per_block, sa=1: k / m_per_block + uint32_t magic_1; // denom: ((n / nb_n0) * b) / nb_n1b + uint32_t magic_2; // denom: y*x, if nxe==0 not used + uint32_t magic_3; // denom: x, if nxe==0 not used + uint32_t magic_4; // denom: b + uint32_t magic_5; // denom: wo + uint32_t magic_6; // denom: n*b*k / (m_per_block*n_per_block) + uint32_t shift_pack_0; + uint32_t shift_pack_1; + uint32_t __pack_0; + ''' + kas = [] + # name: {}, .size: {}, .offset: {}, .value_kind: {}, .value_type + kas.append(amdgpu_kernel_arg_t('p_in' , 8, 0, 'global_buffer','f32',address_space='global',is_const='true')) + kas.append(amdgpu_kernel_arg_t('p_wei' , 8, 8, 'global_buffer','f32',address_space='global',is_const='true')) + kas.append(amdgpu_kernel_arg_t('p_out' , 8, 16, 'global_buffer','f32',address_space='global',is_const='false')) + kas.append(amdgpu_kernel_arg_t('hi' , 4, 24, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('wi' , 4, 28, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('n' , 4, 32, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('k' , 4, 36, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('c' , 4, 40, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('ho' , 4, 44, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('wo' , 4, 48, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('stride_h' , 4, 52, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('stride_w' , 4, 56, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('dilation_h' , 4, 60, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('dilation_w' , 4, 64, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('pad_h' , 4, 68, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('pad_w' , 4, 72, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('y' , 4, 76, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('x' , 4, 80, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('group' , 4, 84, 'by_value','i32')) + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + kas.append(amdgpu_kernel_arg_t('magic_0' , 4, 88, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('magic_1' , 4, 92, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('magic_2' , 4, 96, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('magic_3' , 4, 100, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('shift_pack_0' , 4, 104, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('__pack_0' , 4, 108, 'by_value','i32')) + else: + pass + return kas + + def get_kernel_info(self): + kernel_code = self.get_kernel_code() + kernel_args = self.get_kernel_args() + kernel_info = amdgpu_kernel_info_t(kernel_code, self.name(), self.tunable.block_size, kernel_args) + return kernel_info + + def get_kernel_macros(self): + kernel_macros = [] + for attrs in dir(self): + if attrs.startswith('get_macro_'): + functor = getattr(self, attrs) + rtn = functor() + if rtn is None: + continue + + # here we follow the convention in code: + # #1. for macro like emit class, use emit() to generate macro definition, use __call__() to call this macro + # #2. 
for non-macro like emit class, which might want to "inline-ed" into normal code, no emit() is defined, just __call__(). + # hence need to check if has attr name "emit". if not have, it is type #2, no need to do emit() before hand. + if type(rtn) is tuple: + for e in rtn: + #if hasattr(e, 'emit'): + if e is not None and not e.is_inline(): + #continue + kernel_macros.extend([m for m in rtn]) + else: + #if hasattr(rtn, 'emit'): + if rtn is not None and rtn.is_inline(): + #continue + kernel_macros.append(rtn) + return kernel_macros + + def emit_kernel_prologue(self): + s = self.sgpr + v = self.vgpr + k = self.karg + + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 = self.get_cluster_lengths() + na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 = self.get_dims_lengths() + + data_byte = amdgpu_precision_data_byte(self.tunable.precision) + + m_set_flag_nhw = self.get_macro_set_flag_nhw() + s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 = self.get_symbol_global_load_s_stride_d0_d1() + + m_wei_2d_global_load, m_in_2d_global_load = self.get_macro_global_load() + + tc_index_dispatcher = igemm_thread_cluster_index_dispatcher_t(self.mc) + tc_index_accumulator = igemm_thread_cluster_index_accumulator_t(self.mc) + + nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 + nk_per_thread = tb_k0 if tb_k0 != 1 else tb_k1 + + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + m_mdiv_u32_vs = macro_mdiv_u32_rem_vs_t(self.mc) + m_mdiv_u32_ss = macro_mdiv_u32_rem_ss_t(self.mc) + else: + m_int_div_rem_vv = macro_int_div_rem_vv_t(self.mc) + m_int_div_rem_vs = macro_int_div_rem_vs_t(self.mc) + m_int_div_rem_ss = macro_int_div_rem_ss_t(self.mc) + + s_dummy = sym_t("s_dummy") + + k_pack = self.get_k_pack() + + # start emit + self._emit(f"s_load_dwordx2 s[{s.s_p_in((0,1))}], s[{s.s_ka((0, 1))}], 0+{k.k_p_in()}") + self._emit(f"s_load_dwordx2 s[{s.s_p_wei((0,1))}], s[{s.s_ka((0, 1))}], 0+{k.k_p_wei()}") + self._emit(f"s_load_dwordx2 s[{s.s_p_out((0,1))}], s[{s.s_ka((0, 1))}], 0+{k.k_p_out()}") + if self.tunable.nxe != 0: + self._emit(f"s_load_dwordx8 s[{s.s_hi((0, 7))}], s[{s.s_ka((0, 1))}], 0+{k.k_hi()}") + self._emit(f"s_load_dwordx8 s[{s.s_stride_w((0, 7))}], s[{s.s_ka((0, 1))}], 0+{k.k_stride_w()}") + else: + self._emit(f"s_load_dwordx4 s[{s.s_hi((0, 3))}], s[{s.s_ka((0, 1))}], 0+{k.k_hi()}") + self._emit(f"s_load_dword s[{s.s_c()}], s[{s.s_ka((0, 1))}], 0+{k.k_c()}") + self._emit(f"s_load_dword s[{s.s_group()}], s[{s.s_ka((0, 1))}], 0+{k.k_group()}") + + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_load_dwordx2 s[{s.s_magic_0((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_0()}") + self._emit(f"s_load_dwordx2 s[{s.s_magic_2((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_2()}") + self._emit(f"s_load_dword s[{s.s_shift_pack_0()}], s[{s.s_ka((0, 1))}], 0+{k.k_shift_pack_0()}") + self._emit(f"s_load_dword s[{s.s_gemmk_split()}], s[{s.s_ka((0, 1))}], 0+{k.k_gemm_k_global_split()}") + + self._emit(f"; in(e, c, nb0, nb1) thread_lengths: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, cluster_length: {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, k_pack:{k_pack}") + self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") + if self.tunable.tensor_a_pass_through: + self._emit(tc_index_dispatcher(v.v_in_inb(), v.v_tmp(), ca_nb1, ta_nb1)) + self._emit(tc_index_dispatcher(v.v_gtc_ic_a(), v.v_tmp(), ca_c, k_pack)) # <= note here, thread length is further reduced! 
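+ # A plain-Python model of what tc_index_dispatcher computes (an assumption
+ # about its v_and/v_lshrrev arithmetic, for reading the emitted code): each
+ # call peels the fastest dimension off the linear thread id,
+ #     idx = (tid % cluster_length) * thread_length; tid //= cluster_length
+ # and "thread length is further reduced" means the pass-through path passes
+ # k_pack instead of ta_c, so the per-thread c offset advances in k_pack units:
+ #     tid = 173                                     # illustrative lane id
+ #     v_in_inb,   tid = (tid % 64) * 1, tid // 64   # e.g. ca_nb1=64, ta_nb1=1
+ #     v_gtc_ic_a, tid = (tid % 4) * 4, tid // 4     # e.g. ca_c=4, k_pack=4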
+ self._emit(tc_index_dispatcher(v.v_tmp(1), v.v_tmp(), ca_nb0, ta_nb0, True)) + self._emit(tc_index_accumulator(v.v_in_inb(), v.v_tmp(1), v.v_in_inb(), ca_nb0, ca_nb1, na_nb0, na_nb1)) + else: + self._emit(tc_index_dispatcher(v.v_gtc_ic(), v.v_tmp(), ca_c, ta_c)) + self._emit(tc_index_dispatcher(v.v_in_inb(), v.v_tmp(), ca_nb1, ta_nb1, True)) + + self._emit(f"; wei(e, c, k0, k1) thread_length: {tb_e}x{tb_c}x{tb_k0}x{tb_k1}, cluster_length: {cb_e}x{cb_c}x{cb_k0}x{cb_k1}, k_pack:{k_pack}") + # weight ic same as input + if (not self.tunable.tensor_a_pass_through) and (not self.tunable.tensor_b_pass_through): + self._emit(f"v_lshrrev_b32 v[{v.v_tmp()}], {igemm_log2(ca_c)}, v0") + self._emit(tc_index_dispatcher(v.v_wei_ik(), v.v_tmp(), cb_k1, tb_k1, True)) + elif self.tunable.tensor_a_pass_through: + self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") + self._emit(tc_index_dispatcher(v.v_gtc_ic(), v.v_tmp(), cb_c, tb_c)) + self._emit(tc_index_dispatcher(v.v_wei_ik(), v.v_tmp(), cb_k1, tb_k1, True)) + else: + assert False, "unimplemented" + self._emit_empty_line() + + + self._emit(f"s_waitcnt lgkmcnt(0)") + self._emit_empty_line() + self._emit(f"; calculate index") + # calculate stride, not shift data byte yet + # input + self._emit(f"s_lshr_b32 s[{s.s_sub_c()}], s[{s.s_c()}], s[{s.s_gemmk_split()}] ;add gkgs for c") + self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_wi()}], s[{s.s_in_stride_wi()}]") + self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_tmp(2)}]") + + # weight + if self.tunable.nxe != 0: + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_x()}], s[{s.s_c()}]") + self._emit(f"s_mul_i32 s[{s.s_wei_stride_k()}], s[{s.s_tmp()}], s[{s.s_y()}]") + else: + self._emit(f"s_mov_b32 s[{s.s_wei_stride_k()}], s[{s.s_c()}]") + + if tb_k0 != 1: + self._emit(f"s_lshl_b32 s[{s.s_wei_stride_k0()}], s[{s.s_wei_stride_k()}], {igemm_log2(nb_k1)}") + + # output + self._emit(f"s_mul_i32 s[{s.s_out_stride_wo()}], s[{s.s_k()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_wo() if self.tunable.nxe != 0 else s.s_wi()}], s[{s.s_out_stride_wo()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_ho() if self.tunable.nxe != 0 else s.s_hi()}], s[{s.s_tmp(1)}]") + + # calculate batch split and accumulate the base pointer for input/output + self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_n()}], s[{s.s_in_stride_n()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_n()}], s[{s.s_out_stride_n()}]") + self._emit(f"s_lshl_b32 s[{s.s_tmp(4)}], s[{s.s_tmp(0)}], {igemm_log2(data_byte)}") + self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_tmp(1)}], {igemm_log2(data_byte)}") + + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(4)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(4)}]") + self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") + + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(5)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(5)}]") + self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") + + # early init s_knum in case shifted + self._emit(f"s_lshr_b32 s[{s.s_knum()}], s[{s.s_wei_stride_k()}], s[{s.s_gemmk_split()}]") + + # pad gemm_m, gemm_n + if self.tunable.nxe != 0: + self._emit(f"s_mul_i32 
s[{s.s_dim_br()}], s[{s.s_ho()}], s[{s.s_wo()}]") + else: + self._emit(f"s_mul_i32 s[{s.s_dim_br()}], s[{s.s_hi()}], s[{s.s_wi()}]") + + self._emit(f"s_mul_i32 s[{s.s_dim_mr()}], s[{s.s_n()}], s[{s.s_dim_br()}]") + self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_m_per_block - 1}, s[{s.s_dim_mr()}]") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_dim_mp()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.gemm_m_per_block)}") + + self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_n_per_block - 1}, s[{s.s_k()}]") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_dim_np()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.gemm_n_per_block)}") + + self._emit_empty_line() + self._emit(f"; gemm_m_per_block:{self.tunable.gemm_m_per_block}, gemm_n_per_block:{self.tunable.gemm_n_per_block}, source_access_order:{self.tunable.source_access_order}") + + # calculate group index + self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_dim_mp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_dim_np()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_mul_i32 s[0], s[{s.s_tmp(1)}], s[{s.s_tmp()}]") + # calculate block ic + self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], 1, s[{s.s_gemmk_split()}]") + self._emit(f"s_sub_u32 s[{s.s_tmp(3)}], s[{s.s_tmp(3)}], 1") + self._emit(f"s_and_b32 s[{s.s_block_gtc_ic()}], s[{s.s_bx()}], s[{s.s_tmp(3)}]") + self._emit(f"s_lshr_b32 s[{s.s_bx()}], s[{s.s_bx()}], s[{s.s_gemmk_split()}]") + self._emit(f"s_mul_i32 s[{s.s_block_gtc_ic()}], s[{s.s_block_gtc_ic()}], s[{s.s_sub_c()}]") + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080018 ; offset:24, width:8") + self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_block_gtc_ig(), s.s_bx(), s.s_magic_3(), s.s_tmp(3), '0', s.s_tmp())) + else: + self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_block_gtc_ig(), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + + # s.s_tmp(4)=> rem, gemm_m, gemm_n, s.s_block_gtc_ig()=> quo, group + self._emit(f"s_mov_b32 s[{s.s_bx()}], s[{s.s_tmp(4)}]") + + if self.tunable.source_access_order == IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_M_GEMM_N: + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_np()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_tmp(5), s.s_bx(), s.s_magic_0(), s.s_tmp(3), '0', s.s_tmp())) + else: + self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_tmp(5), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + + else: + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_mp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_ss(s.s_tmp(5), s.s_tmp(4), s.s_bx(), s.s_magic_0(), s.s_tmp(3), '0', s.s_tmp())) + else: + self._emit(m_int_div_rem_ss(s.s_tmp(5), s.s_tmp(4), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + + self._emit(f"; s_tmp+4:block_gtc_in, s_tmp+5:block_gtc_im") + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ik()}], s[{s.s_tmp(4)}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_inb()}], s[{s.s_tmp(5)}], 
{igemm_log2(self.tunable.gemm_m_per_block)}") + + # transform nb + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_inb()}], v[{v.v_in_inb()}]") + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wo() if self.tunable.nxe != 0 else s.s_wi(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_wo() if self.tunable.nxe != 0 else s.s_wi(), v.v_tmp(), s.s_tmp())) + + if self.tunable.nxe != 0: + # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + self._emit(f"v_mul_lo_u32 v[{v.v_in_ihi_list(0)}], s[{s.s_stride_h()}], v[{v.v_in_ihi_list(0)}]") + self._emit(f"v_sub_i32 v[{v.v_in_ihi_list(0)}], v[{v.v_in_ihi_list(0)}], s[{s.s_pad_h()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_in_iwi_list(0)}], s[{s.s_stride_w()}], v[{v.v_in_iwi_list(0)}]") + self._emit(f"v_sub_i32 v[{v.v_in_iwi_list(0)}], v[{v.v_in_iwi_list(0)}], s[{s.s_pad_w()}]") + self._emit_empty_line() + + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + # update flag for batch size + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_in_in()}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") + self._emit(f"v_lshlrev_b32 v[{v.v_in_flag(0)}], 16, v[{v.v_tmp()}]") + else: + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_in_in()}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") + self._emit(f"v_lshlrev_b32 v[{v.v_in_flag_n()}], 0, v[{v.v_tmp()}]") + + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ig()}], s[{s.s_block_gtc_ig()}], {igemm_log2(data_byte)}") + + def calculate_and_load_input(): + self._emit(f"; calculate in offset") + if self.tunable.tensor_a_pass_through: + self._emit(f"s_mov_b32 s[{s.s_in_c_itr()}], 0") + self._emit(f"s_mov_b32 s[{s.s_in_stride_k_pack()}], {ca_c * k_pack * data_byte}") + for i in range(2, ta_c // k_pack): + self._emit(f"s_mul_i32 s[{s.s_in_offset(i - 2)}], s[{s.s_in_stride_k_pack()}], {i}") + else: + self._emit(f"s_mov_b32 s[{s.s_in_offset()}], 0") + # compute group distance + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") + self._emit(f"s_add_u32 s[{s.s_p_in(0)}], s[{s.s_p_in(0)}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") + self._emit_empty_line() + + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") + # s_in_stride_wi need shift before! 
+ self._emit(self.try_shift_stride(s.s_in_stride_wi, igemm_log2(data_byte))) + + self._emit(f"v_add_u32 v[{v.v_tmp(1)}], v[{v.v_tmp(1)}], s[{s.s_block_gtc_ic()}]") + + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic_a() if self.tunable.tensor_a_pass_through else v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(0)}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(0)}], v[{v.v_tmp()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") + self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") + + if True: + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self._emit(f"v_bfe_u32 v[{v.v_tmp(1)}], v[{v.v_in_flag()}], 16, 1") + self._emit(m_set_flag_nhw(v.v_tmp(), v.v_tmp(1), v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp()}], 0, v[{v.v_in_flag()}]") + else: + self._emit(f"v_bfe_u32 v[{v.v_tmp(1)}], v[{v.v_in_flag_n()}], 0, 1") + self._emit(m_set_flag_nhw(v.v_in_flag(0), v.v_tmp(1), v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) + self._emit_empty_line() + + # voffset, for [1, nb_per_thread) pixels + if self.tunable.tensor_a_pass_through: + thread_stride = ca_nb0 * ca_nb1 + else: + thread_stride = na_nb1 if ta_nb0 != 1 else 1 + + for i in range(1, nb_per_thread): + self._emit(f"s_mov_b32 s1, {thread_stride * i}") + self._emit(f"v_add_u32 v[{v.v_tmp()}], s1, v[{v.v_in_inb()}]") + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_inb()}], v[{v.v_tmp()}]") + if self.tunable.nxe != 0: + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wo(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) + + # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + self._emit(f"v_mul_lo_u32 v[{v.v_in_ihi_list(i)}], s[{s.s_stride_h()}], v[{v.v_in_ihi_list(i)}]") + self._emit(f"v_sub_i32 v[{v.v_in_ihi_list(i)}], v[{v.v_in_ihi_list(i)}], s[{s.s_pad_h()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_in_iwi_list(i)}], s[{s.s_stride_w()}], v[{v.v_in_iwi_list(i)}]") + self._emit(f"v_sub_i32 v[{v.v_in_iwi_list(i)}], v[{v.v_in_iwi_list(i)}], s[{s.s_pad_w()}]") + self._emit_empty_line() + + else: + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wi(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwi_list(i), 
v.v_in_ihi_list(i), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) + + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") + self._emit(f"v_add_u32 v[{v.v_tmp(1)}], v[{v.v_tmp(1)}], s[{s.s_block_gtc_ic()}]") + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic_a() if self.tunable.tensor_a_pass_through else v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(i)}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(i)}], v[{v.v_tmp()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") + self._emit(f"v_add_u32 v[{v.v_in_os(i)}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") + + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + # update flag for batch size + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_in_in()}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp()}], {16 + i}, v[{v.v_in_flag(0)}]") + self._emit(m_set_flag_nhw(v.v_tmp(1), v.v_tmp(), v.v_in_ihi_list(i), v.v_in_iwi_list(i), s.s_hi(), s.s_wi())) + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp(1)}], {i}, v[{v.v_in_flag()}]") + else: + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_in_in()}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag_n()}], v[{v.v_tmp()}], {i}, v[{v.v_in_flag_n()}]") + self._emit(m_set_flag_nhw(v.v_in_flag(i), v.v_tmp(), v.v_in_ihi_list(i), v.v_in_iwi_list(i), s.s_hi(), s.s_wi())) + + # load in + self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") + self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") + if self.tunable.tensor_a_pass_through and self.tunable.tensor_a_pass_through_interleave_gld: + mbb_gld_in = create_machine_basic_block(self.global_load_in()) + gld_per_k = self.tunable.wave_repeat_m * self.tunable.wave_step_m + for i_mbb in mbb_gld_in[0:(-1 * gld_per_k)]: + # TODO: need multiple load of pass through side + self._emit(machine_basic_block_call(self, i_mbb)) + else: + self._emit(self.global_load_in()) + self._emit_empty_line() + + def calculate_and_load_weight(): + self._emit(f"; calculate wei offset") + self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_k()}], s[{s.s_wei_stride_k()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") + self._emit(f"s_add_u32 s[{s.s_p_wei()}], s[{s.s_p_wei()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_wei(1)}], s[{s.s_p_wei(1)}], s[{s.s_tmp(1)}]") + + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_ik()}], v[{v.v_wei_ik()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_tmp(5)}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_tmp()}], s[{s.s_block_gtc_ic()}]") + self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(data_byte)}") + + # wei flag + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp(5)}]") + self._emit(f"v_cndmask_b32 v[{v.v_wei_flag()}], 0, 1, vcc") + self._emit(f"v_mov_b32 v[{v.v_wei_tmp_pack()}], v[{v.v_wei_flag()}]") + + for i in range(1, nk_per_thread): + if i == 1: + k_thread_stride = nb_k1 if tb_k0 != 1 else 1 + self._emit(f"s_mov_b32 s[{s.s_tmp()}], {k_thread_stride}") + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_tmp()}], v[{v.v_tmp(5)}]") + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp(5)}]") + self._emit(f"v_cndmask_b32 
v[{v.v_wei_flag(i)}], 0, 1, vcc") + self._emit(f"v_lshl_or_b32 v[{v.v_wei_tmp_pack()}], v[{v.v_wei_flag(i)}], {i}, v[{v.v_wei_tmp_pack()}]") + + self._emit_empty_line() + if self.wei_thread_copy_ndim != 1: + if s_wei_stride_d0 != s_dummy: + self._emit(self.try_shift_stride(s_wei_stride_d0, igemm_log2(data_byte))) + if s_wei_stride_d1 != s_dummy: + self._emit(self.try_shift_stride(s_wei_stride_d1, igemm_log2(data_byte))) + self._emit_empty_line() + + if self.tunable.precache_soffset: + self._emit(m_wei_2d_global_load.init_precache_soffset(s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), s.s_tmp())) + + self._emit(f".v_clear_nc {v.v_gld_b()}, {m_wei_2d_global_load.ctrl.length_d0 * m_wei_2d_global_load.ctrl.length_d1}") + self._emit(f"s_mov_b32 s[{s.s_p_wei(2)}], 0xffffffff") + self._emit(f"s_mov_b32 s[{s.s_p_wei(3)}], 0x27000") + if self.tunable.tensor_b_pass_through and self.tunable.tensor_b_pass_through_interleave_gld: + mbb_gld_wei = create_machine_basic_block(self.global_load_wei()) + gld_per_k = self.tunable.wave_repeat_n * self.tunable.wave_step_n + for i_mbb in mbb_gld_wei[0:(-1 * gld_per_k)]: + # TODO: need multiple load of pass through side + self._emit(machine_basic_block_call(self, i_mbb)) + else: + self._emit(self.global_load_wei()) + self._emit_empty_line() + + # do load + calculate_and_load_weight() + calculate_and_load_input() + + + + if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") + self._emit(self.thread_mapping(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp())) + else: + v_pack = k_pack if self.tunable.tensor_a_pass_through or self.tunable.tensor_b_pass_through else 1 + self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") + self._emit(self.xdlops_mapping.get_gemm_index_for_src_matrix(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp(), + k_pack=k_pack, v_pack=v_pack)) + self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") + self._emit(self.xdlops_mapping.get_gemm_index_for_dst_matrix(v.v_co_sst(), v.v_co_sld(), v.v_tmp(5), v.v_tmp())) + + ''' + gemm_k * gemm_m * k_pack + ''' + if not self.tunable.tensor_a_pass_through: + self._emit(f"; LDS store, in: e,c,nb0,nb1: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, k_pack:{k_pack}") + if k_pack != 1: + self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {igemm_log2(k_pack)}, v[{v.v_in_inb()}]") + self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(k_pack)}, v[{v.v_gtc_ic()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_tmp(1)}], {igemm_log2(na_nb0*na_nb1 * k_pack)}, v[{v.v_tmp(2)}]") + else: + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(na_nb0*na_nb1 * k_pack)}, v[{v.v_in_inb()}]") + self._emit(f"v_lshlrev_b32 v[{v.v_sst_a_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") + self._emit_empty_line() + self._emit(f"v_lshlrev_b32 v[{v.v_sld_a_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_im()}] ; LDS load in") + + if not self.tunable.tensor_b_pass_through: + self._emit(f"; LDS store, wei: e,c,k: {tb_e}x{tb_c}x{tb_k0}x{tb_k1}, {cb_e}x{cb_c}x{cb_k0}x{cb_k1}, k_pack:{k_pack}") + if k_pack != 1: + self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {igemm_log2(k_pack)}, v[{v.v_wei_ik()}]") + self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(k_pack)}, v[{v.v_gtc_ic()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_tmp(1)}], {igemm_log2(nb_k0*nb_k1 * k_pack)}, v[{v.v_tmp(2)}]") + else: + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(nb_k0*nb_k1 * k_pack)}, v[{v.v_wei_ik()}]") + 
self._emit(f"v_lshlrev_b32 v[{v.v_sst_b_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") + if not self.tunable.tensor_a_pass_through: + self._emit(f"v_add_u32 v[{v.v_sst_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sst_b_os()}]") + self._emit_empty_line() + self._emit(f"v_lshlrev_b32 v[{v.v_sld_b_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_in()}] ; LDS load wei") + if not self.tunable.tensor_a_pass_through: + self._emit(f"v_add_u32 v[{v.v_sld_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sld_b_os()}]") + + # self._emit(f"; LDS load") + #if not self.tunable.tensor_a_pass_through: + #self._emit(f"v_lshlrev_b32 v[{v.v_sld_b_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_in()}]") + #self._emit(f"v_lshlrev_b32 v[{v.v_sld_a_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_im()}]") + #self._emit(f"v_add_u32 v[{v.v_sld_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sld_b_os()}]") + #self._emit_empty_line() + + if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + self._emit(f"v_mov_b32 v[{v.v_gemm_in()}], v[{v.v_co_sst()}]") + self._emit(f"v_mov_b32 v[{v.v_gemm_im()}], v[{v.v_co_sld()}]") + self._emit(self.coalescing_store.init_co_lds_offset(v.v_co_sst(), v.v_co_sld(), v.v_gemm_im(), v.v_gemm_in(), '0', v.v_tmp())) + self._emit(self.coalescing_store.init_co_sub_m_index(v.v_co_sub_m_index(), '0', v.v_tmp())) + self._emit(self.coalescing_store.init_co_sub_n_index(v.v_co_sub_n_index(), '0', v.v_tmp())) + self._emit_empty_line() + + #if self.tunable.nxe != 0: + if True: + self._emit(f"v_add_u32 v[{v.v_tmp()}], s[{s.s_block_gtc_ik()}], v[{v.v_co_sub_n_index()}]") + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp()}]") + self._emit(f"v_cndmask_b32 v[{v.v_out_flag()}], 0, 1, vcc") + + ''' + a good news for nhwc and coalescing output is that, we can treat gemm_m (n*ho*wo) as a single dimension, + and use sgpr to stride along this dimension. 
+        '''
+        self._emit(f"; output offset")
+        self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_k()}]")
+        self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_k()}]")
+        self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]")
+        self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]")
+
+        self._emit_empty_line()
+        self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_ik()}], {igemm_log2(data_byte)}")
+        self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp(3)}]")
+        self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out()}+1], 0")
+        self._emit_empty_line()
+
+        self._emit(self.try_shift_stride(s.s_out_stride_wo, igemm_log2(data_byte)))
+        self._emit(f"v_add_u32 v[{v.v_out_inb()}], s[{s.s_block_gtc_inb()}], v[{v.v_co_sub_m_index()}] ; total n*ho*wo")
+        self._emit(f"v_mul_lo_u32 v[{v.v_out_os()}], s[{s.s_out_stride_wo()}], v[{v.v_out_inb()}]")
+        self._emit(f"v_lshlrev_b32 v[{v.v_tmp()}], {igemm_log2(data_byte)}, v[{v.v_co_sub_n_index()}]")
+        self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_tmp()}]")
+
+        self._emit(f"; move slice stride")
+        self._emit(f"s_lshl_b32 s[{s.s_gemm_k_num_c()}], s[{s.s_sub_c()}], {igemm_log2(data_byte)}")
+        self._emit(f"s_lshl_b32 s[{s.s_c()}], s[{s.s_c()}], {igemm_log2(data_byte)}")
+
+        w_flag_cnt = 0
+        self._emit(f"v_bfe_u32 v[{v.v_wei_flag(0)}], v[{v.v_wei_tmp_pack()}], {0}, 1")
+        w_flag_cnt = w_flag_cnt + 1
+
+        # if self.tunable.nxe != 0:
+        #     self._emit(f"s_mov_b32 s[{s.s_tmp()}], {na_c}")
+        #     self._emit(f"s_mul_i32 s[{s.s_move_slice_k_stride_c()}], s[{s.s_tmp()}], {igemm_log2(data_byte)}")
+        # else:
+        self._emit(f"s_mov_b32 s[{s.s_move_slice_k_stride_c()}], {na_c * data_byte}")
+        if w_flag_cnt < nk_per_thread:
+            self._emit(f"v_bfe_u32 v[{v.v_wei_flag(w_flag_cnt)}], v[{v.v_wei_tmp_pack()}], {w_flag_cnt}, 1")
+            w_flag_cnt = w_flag_cnt + 1
+
+        if self.tunable.nxe != 0:
+            # s_in_diff_wi  : s_dilation_w * s_in_stride_wi
+            # s_in_diff_hi  : s_dilation_h * s_in_stride_hi - (x - 1) * s_dilation_w * s_in_stride_wi, always positive
+            # s_dilation_w_x: -1 * (x - 1) * s_dilation_w
+            self._emit(f"s_mov_b32 s[{s.s_move_slice_k_ix()}], 0")
+            self._emit(f"s_mul_i32 s[{s.s_in_diff_wi()}], s[{s.s_dilation_w()}], s[{s.s_in_stride_wi()}]")   # shifted
+            self._emit(f"s_sub_i32 s[{s.s_tmp(3)}], s[{s.s_x()}], 1")
+            self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_in_diff_wi()}], s[{s.s_tmp(3)}]")
+            self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_in_stride_wi()}], s[{s.s_wi()}]")
+            self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_tmp(1)}], s[{s.s_dilation_h()}]")
+            self._emit(f"s_sub_i32 s[{s.s_in_diff_hi()}], s[{s.s_tmp(1)}], s[{s.s_tmp()}]")
+            self._emit(f"s_mul_i32 s[{s.s_dilation_w_x()}], s[{s.s_dilation_w()}], s[{s.s_tmp(3)}]")
+            self._emit(f"s_mul_i32 s[{s.s_dilation_w_x()}], s[{s.s_dilation_w_x()}], -1")
+
+        self._emit_empty_line()
+
+        self._emit(f"s_mov_b32 s[{s.s_p_out(2)}], 0xffffffff")
+        if w_flag_cnt < nk_per_thread:
+            self._emit(f"v_bfe_u32 v[{v.v_wei_flag(w_flag_cnt)}], v[{v.v_wei_tmp_pack()}], {w_flag_cnt}, 1")
+            w_flag_cnt = w_flag_cnt + 1
+        self._emit(f"s_mov_b32 s[{s.s_p_out(3)}], 0x27000")
+        for i_w in range(w_flag_cnt, nk_per_thread):
+            self._emit(f"v_bfe_u32 v[{v.v_wei_flag(i_w)}], v[{v.v_wei_tmp_pack()}], {i_w}, 1")
+
+    def emit_kernel_fma_main_loop(self):
+        s = self.sgpr
+        v = self.vgpr
+        k = self.karg
+
+        data_byte = amdgpu_precision_data_byte(self.tunable.precision)
+        k_pack = self.get_k_pack()
+
+        m_move_slice_window = self.get_macro_move_slice_window()
+        m_move_slice_window_accumulate = self.get_macro_move_slice_window_accumulate()
+
+        def move_slice_window_b():
+            '''
+            in nhwc we only need to call one move-slice-window
+            '''
+            if self.tunable.nxe != 0:
+                with self._deferred_context():
+                    self._emit(m_move_slice_window(
+                                *(s.s_p_in(), s.s_in_c_itr()) if self.tunable.tensor_a_pass_through else (s.s_in_offset(),),
+                                v.v_wei_os(),
+                                s.s_move_slice_k_stride_c(),
+                                s.s_gemm_k_num_c(),
+                                s.s_flag_need_acc_yx()))
+                return self._get_deferred()
+            else:
+                with self._deferred_context():
+                    self._emit(m_move_slice_window(
+                                s.s_p_in() if self.tunable.tensor_a_pass_through else s.s_in_offset(),
+                                v.v_wei_os(),
+                                s.s_move_slice_k_stride_c()))
+                return self._get_deferred()
+
+        def move_slice_window_a():
+            return ''
+
+        def move_slice_window_acc():
+            if self.tunable.nxe == 0:
+                return ''
+            else:
+                with self._deferred_context():
+                    if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG:
+                        self._emit(m_move_slice_window_accumulate(
+                                    *(s.s_p_in(), s.s_in_c_itr(), s.s_gemm_k_num_c()) if self.tunable.tensor_a_pass_through else (s.s_in_offset(),),
+                                    v.v_in_os(),
+                                    v.v_in_ihi_list(),
+                                    v.v_in_iwi_list(),
+                                    v.v_in_flag(),
+                                    s.s_flag_need_acc_yx(),
+                                    s.s_move_slice_k_ix(),
+                                    s.s_x(),
+                                    s.s_in_diff_hi(),
+                                    s.s_in_diff_wi(),
+                                    s.s_dilation_h(),
+                                    s.s_dilation_w(),
+                                    s.s_dilation_w_x(),
+                                    s.s_hi(),
+                                    s.s_wi(),
+                                    v.v_tmp(),
+                                    s.s_tmp()))
+                    else:
+                        self._emit(m_move_slice_window_accumulate(
+                                    *(s.s_p_in(), s.s_in_c_itr()) if self.tunable.tensor_a_pass_through else (s.s_in_offset(),),
+                                    v.v_wei_os(),
+                                    s.s_c(),
+                                    s.s_gemm_k_num_c(),
+                                    v.v_in_os(),
+                                    v.v_in_ihi_list(),
+                                    v.v_in_iwi_list(),
+                                    v.v_in_flag(),
+                                    v.v_in_flag_n(),
+                                    s.s_flag_need_acc_yx(),
+                                    s.s_move_slice_k_ix(),
+                                    s.s_x(),
+                                    s.s_in_diff_hi(),
+                                    s.s_in_diff_wi(),
+                                    s.s_dilation_h(),
+                                    s.s_dilation_w(),
+                                    s.s_dilation_w_x(),
+                                    s.s_hi(),
+                                    s.s_wi(),
+                                    v.v_tmp(),
+                                    s.s_tmp()))
+                return self._get_deferred()
+
+        if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS:
+            fctrl = ctrl_fma_main_loop_t()
+            fctrl.thread_m = self.tunable.thread_tile_m
+            fctrl.thread_n = self.tunable.thread_tile_n
+            fctrl.unroll_k = self.tunable.gemm_k_per_block
+            fctrl.label_prefix = self.name()
+            fctrl.gemm_m_repeat = self.tunable.gemm_m_repeat
+            fctrl.gemm_m_level0_cluster = self.tunable.gemm_m_level0_cluster
+            fctrl.gemm_m_level1_cluster = self.tunable.gemm_m_level1_cluster
+            fctrl.gemm_n_repeat = self.tunable.gemm_n_repeat
+            fctrl.gemm_n_level0_cluster = self.tunable.gemm_n_level0_cluster
+            fctrl.gemm_n_level1_cluster = self.tunable.gemm_n_level1_cluster
+            fctrl.lds_single_size = self.tunable.lds_single            # in bytes, should be a power of 2
+            fctrl.lds_buffer_num = self.tunable.lds_buffer_num
+
+            # functor
+            fctrl.global_load_a_functor = self.global_load_wei
+            fctrl.global_load_b_functor = self.global_load_in
+            fctrl.shared_store_a_functor = self.shared_store_wei
+            fctrl.shared_store_b_functor = self.shared_store_in
+            fctrl.shared_load_a_functor = inst_ds_read_t(self.tunable.thread_sub_tile_m * data_byte)
+            fctrl.shared_load_b_functor = inst_ds_read_t(self.tunable.thread_sub_tile_n * data_byte)
+            fctrl.move_slice_window_a_functor = move_slice_window_a
+            fctrl.move_slice_window_b_functor = move_slice_window_b
+
+            # symbol type
+            fctrl.v_a = v.v_a
+            fctrl.v_b = v.v_b
+            fctrl.v_c = v.v_c
+            fctrl.v_gld_a = v.v_gld_a
+            fctrl.v_gld_b = v.v_gld_b
+            fctrl.v_sld_a_os = v.v_sld_a_os
+            fctrl.v_sld_b_os = v.v_sld_b_os
+            fctrl.v_sst_a_os = v.v_sst_a_os
+            fctrl.v_sst_b_os = v.v_sst_b_os
+            fctrl.s_kitr = s.s_kitr
+            fctrl.s_knum = s.s_knum
+
+            fma_main_loop = fma_main_loop_t(self.mc, fctrl)
+            fma_main_loop.emit()
+        else:
+            a = self.agpr
+            fctrl = ctrl_mfma_main_loop_t()
+            ctrl_xdlops_mapping = get_ctrl_xdlops_mapping_from_wave_tile_fp32(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block,
+                    self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k,
+                    self.tunable.wave_repeat_m, self.tunable.wave_repeat_n,
+                    self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE)
+            fctrl.cxm = ctrl_xdlops_mapping
+            fctrl.unroll_k = self.tunable.gemm_k_per_block
+            fctrl.label_prefix = self.name()
+            fctrl.lds_single_size = self.tunable.lds_single            # in bytes, should be a power of 2
+            fctrl.lds_buffer_num = self.tunable.lds_buffer_num
+            fctrl.local_prefetch_num = self.tunable.local_prefetch_num
+            fctrl.interleave = self.tunable.fma_interleave
+
+            # functor
+            # fctrl.global_load_a_functor = self.global_load_wei
+            # fctrl.global_load_b_functor = self.global_load_in
+            # fctrl.shared_store_a_functor = self.shared_store_wei
+            # fctrl.shared_store_b_functor = self.shared_store_in
+            fctrl.global_load_a_functor = self.global_load_in
+            fctrl.global_load_b_functor = self.global_load_wei
+            fctrl.shared_store_a_functor = self.shared_store_in
+            fctrl.shared_store_b_functor = self.shared_store_wei
+
+            # ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths()
+            fctrl.lds_k_pack = k_pack
+
+            share_load_packed = k_pack if self.tunable.tensor_a_pass_through or self.tunable.tensor_b_pass_through else 1
+
+            if ctrl_xdlops_mapping.wave_step_m == 1:
+                fctrl.shared_load_a_functor = inst_ds_read_t(data_byte * share_load_packed)   # xdlops always uses a single load from LDS
+            else:
+                assert ctrl_xdlops_mapping.wave_step_m == 2, "currently only wave_step_m == 2 is supported"
+                fctrl.shared_load_a_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte * share_load_packed, k_pack * ctrl_xdlops_mapping.wave_tile_m * data_byte, sym_t(self.vgpr.v_tmp(4)))
+
+            if ctrl_xdlops_mapping.wave_step_n == 1:
+                fctrl.shared_load_b_functor = inst_ds_read_t(data_byte * share_load_packed)   # xdlops always uses a single load from LDS
+            else:
+                assert ctrl_xdlops_mapping.wave_step_n == 2, "currently only wave_step_n == 2 is supported"
+                fctrl.shared_load_b_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte * share_load_packed, k_pack * ctrl_xdlops_mapping.wave_tile_n * data_byte, sym_t(self.vgpr.v_tmp(5)))
+            fctrl.move_slice_window_a_functor = move_slice_window_a
+            fctrl.move_slice_window_b_functor = move_slice_window_b
+            fctrl.move_slice_window_accumule_functor = move_slice_window_acc if self.tunable.nxe != 0 else None
+
+            # symbol type
+            fctrl.v_a = v.v_a if not self.tunable.tensor_a_pass_through else None
+            fctrl.v_b = v.v_b if not self.tunable.tensor_b_pass_through else None
+            fctrl.a_c = a.a_c
+            fctrl.v_gld_a = v.v_gld_a
+            fctrl.v_gld_b = v.v_gld_b
+            fctrl.v_gld_a_gpf = v.v_gld_a_gpf if self.tunable.global_prefetch_a_num == 2 else None
+            fctrl.v_gld_b_gpf = v.v_gld_b_gpf if self.tunable.global_prefetch_b_num == 2 else None
+            fctrl.v_gld_a_num = self.tunable.num_vgpr_global_load_a
+            fctrl.v_gld_b_num = self.tunable.num_vgpr_global_load_b
+            fctrl.v_sld_a_os = v.v_sld_a_os if not self.tunable.tensor_a_pass_through else None
+            fctrl.v_sld_b_os = v.v_sld_b_os if not self.tunable.tensor_b_pass_through else None
+            fctrl.v_sst_a_os = v.v_sst_a_os if not self.tunable.tensor_a_pass_through else None
+            fctrl.v_sst_b_os = v.v_sst_b_os if not self.tunable.tensor_b_pass_through else None
+            fctrl.s_kitr = s.s_kitr
+            fctrl.s_knum = s.s_knum
+            fctrl.pass_through_a = self.tunable.tensor_a_pass_through
+            fctrl.pass_through_b = self.tunable.tensor_b_pass_through
+            fctrl.pass_through_a_v_pack = self.get_k_pack()
+            fctrl.pass_through_b_v_pack = self.get_k_pack()
+
+            fctrl.pass_through_a_interleave_gld = 1 if self.tunable.tensor_a_pass_through_interleave_gld else 0
+            fctrl.pass_through_b_interleave_gld = 1 if self.tunable.tensor_b_pass_through_interleave_gld else 0
+
+            mfma_main_loop = mfma_main_loop_t(self.mc, fctrl)
+            mfma_main_loop.emit()
+
+
+    def emit_kernel_epilogue(self):
+        s = self.sgpr
+        v = self.vgpr
+        #label_out = f"L_{self.name()}_out"
+
+        ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths()
+        ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 = self.get_cluster_lengths()
+
+        if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS:
+            # if self.tunable.nxe != 0:
+            #     self._emit(self.coalescing_store(v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_in(), v.v_in_os(), None,
+            #         s.s_in_stride_c0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_stride_c(), s.s_tmp(), v.v_in_flag()))
+            # else:
+            #     self._emit(self.coalescing_store(v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_in(), v.v_in_os(), None,
+            #         s.s_in_stride_c0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_stride_c(), s.s_tmp()))
+            assert False
+        else:
+            a = self.agpr
+            self._emit(self.coalescing_store(a.a_c(), v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_out(), v.v_out_os(), None,
+                    None, s.s_out_stride_wo(),
+                    s.s_tmp(), v.v_out_flag(), s.s_dim_mr(), v.v_out_inb(), s.s_block_gtc_inb(), v.v_co_sub_m_index(), v.v_tmp()))
+
+        self._emit_front(f"{self.label_out}:")
+
+    def emit_kernel_symbol(self):
+        self.karg.emit()
+        self._emit_empty_line()
+        self.sgpr.emit()
+        self._emit_empty_line()
+        self.vgpr.emit()
+        self._emit_empty_line()
+        if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS:
+            self.agpr.emit()
+            self._emit_empty_line()
+
+    def emit_kernel_header(self):
+        kernel_name = self.name()
+        self._emit('.text')
+        if self.mc.arch_config.code_object == AMDGPU_CODEOBJECT_V3:
+            self._emit('.globl {}'.format(kernel_name))
+        self._emit('.p2align 8')
+        if self.mc.arch_config.code_object == AMDGPU_CODEOBJECT_V3:
+            self._emit('.type {},@function'.format(kernel_name))
+        if self.mc.arch_config.code_object == AMDGPU_CODEOBJECT_V2:
+            self._emit('.amdgpu_hsa_kernel {}'.format(kernel_name))
+        self._emit('{}:'.format(kernel_name))
+
+    def emit_kernel_body(self):
+        self.emit_kernel_prologue()
+        self.emit_kernel_fma_main_loop()
+        self.emit_kernel_epilogue()
+    def emit_kernel_end(self):
+        self._emit('s_endpgm')
+    def emit_kernel_footer(self):
+        self._emit_empty_line()
+
+    def emit_kernel_amd_kernel_code_t(self):
+        amd_kernel_code_t(self.mc, self.get_kernel_info()).emit()
diff --git a/igemm/algo/mfma_main_loop.py b/igemm/algo/mfma_main_loop.py
index 13f7d61c..57b2edbc 100644
--- a/igemm/algo/mfma_main_loop.py
+++ b/igemm/algo/mfma_main_loop.py
@@ -30,6 +30,9 @@
 from .mfma import *
 from .xdlops_mapping import *
 from .nop import *
+import re
+
+MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT = 1   # advance the last wait for ds_read by one mfma slot
 
 class ctrl_mfma_main_loop_t(object):
     def __init__(self):
@@ -53,13 +56,18 @@ def __init__(self):
         self.shared_load_b_functor = None
         self.move_slice_window_a_functor = None
         self.move_slice_window_b_functor = None
+        self.move_slice_window_accumule_functor = None
 
         # symbol type
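+        # note: when a tensor is marked pass-through, its v_a/v_b below stays None and
+        # emit_single_pass_through() feeds the mfma instructions directly from the
+        # global-load registers (v_gld_a/v_gld_b), skipping the LDS round trip.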
         self.v_a = None
         self.v_b = None
         self.a_c = None
         self.v_gld_a = None
+        self.v_gld_a_gpf = None    # used when a is pass-through and not interleaved, as the global prefetch register
+        self.v_gld_a_num = 1
         self.v_gld_b = None
+        self.v_gld_b_gpf = None    # used when b is pass-through and not interleaved, as the global prefetch register
+        self.v_gld_b_num = 1
         self.v_sld_a_os = None
         self.v_sld_b_os = None
         self.v_sst_a_os = None
@@ -67,6 +75,18 @@ def __init__(self):
         self.s_kitr = None
         self.s_knum = None
 
+        # the fields below are in units of pixels; data_type bytes are not considered
+        self.lds_k_pack = 1
+        self.lds_pad_m = 0      # how many pixels to pad per m row
+        self.lds_pad_n = 0      # how many pixels to pad per n row
+
+        self.pass_through_a = 0                 # a tensor does not use LDS
+        self.pass_through_b = 0                 # b tensor does not use LDS
+        self.pass_through_a_v_pack = 1          # a pass-through tensor may have a v_pack, indicating a vector load
+        self.pass_through_b_v_pack = 1
+        self.pass_through_a_interleave_gld = 1
+        self.pass_through_b_interleave_gld = 1
+
 class mfma_main_loop_t(mc_base_t):
     '''
     '''
@@ -74,7 +94,431 @@ def __init__(self, mc, ctrl):
         mc_base_t.__init__(self, mc)
         self.ctrl = ctrl
         assert type(ctrl) is ctrl_mfma_main_loop_t
+
+    def emit_single_pass_through(self):
+        '''
+        one side of the A/B tensors does not use LDS; used for skinny gemms.
+        a/b -> p/q, where the p side passes through (bypasses) LDS and the q side is normal
+        '''
+
+        p_idx = 0 if self.ctrl.pass_through_a else 1
+        q_idx = p_idx ^ 1
+        ctrl = self.ctrl
+
+        label_mfma_body = 'L_{}_mfma_body'.format(self.ctrl.label_prefix)
+        label_mfma_finishing = 'L_{}_mfma_finishing'.format(self.ctrl.label_prefix)
+        label_mfma_end = 'L_{}_mfma_end'.format(self.ctrl.label_prefix)
+
+        f_gld_p = [ctrl.global_load_a_functor, ctrl.global_load_b_functor][p_idx]
+        f_gld_q = [ctrl.global_load_a_functor, ctrl.global_load_b_functor][q_idx]
+        f_sst_p = [ctrl.shared_store_a_functor, ctrl.shared_store_b_functor][p_idx]
+        f_sst_q = [ctrl.shared_store_a_functor, ctrl.shared_store_b_functor][q_idx]
+
+        f_sld_p = [ctrl.shared_load_a_functor, ctrl.shared_load_b_functor][p_idx]
+        f_sld_q = [ctrl.shared_load_a_functor, ctrl.shared_load_b_functor][q_idx]
+
+        f_move_slice_window_p = [ctrl.move_slice_window_a_functor, ctrl.move_slice_window_b_functor][p_idx]
+        f_move_slice_window_q = [ctrl.move_slice_window_a_functor, ctrl.move_slice_window_b_functor][q_idx]
+        f_move_slice_window_acc = ctrl.move_slice_window_accumule_functor
+
+        v_gld_p = [ctrl.v_gld_a, ctrl.v_gld_b][p_idx]
+        v_gld_q = [ctrl.v_gld_a, ctrl.v_gld_b][q_idx]
+
+        v_gld_p_gpf = [ctrl.v_gld_a_gpf, ctrl.v_gld_b_gpf][p_idx]
+        v_gld_p_num = [ctrl.v_gld_a_num, ctrl.v_gld_b_num][p_idx]
+
+        a_c = ctrl.a_c
+        v_q = [ctrl.v_a, ctrl.v_b][q_idx]
+        v_sld_q_os = [ctrl.v_sld_a_os, ctrl.v_sld_b_os][q_idx]
+        v_sst_q_os = [ctrl.v_sst_a_os, ctrl.v_sst_b_os][q_idx]
+
+        s_kitr = ctrl.s_kitr
+        s_knum = ctrl.s_knum
+        cxm = ctrl.cxm
+
+        data_byte = amdgpu_precision_data_byte(ctrl.data_type)
+
+        lds_width_m = data_byte * cxm.wave_tile_m * cxm.wave_step_m * cxm.waves_per_m() * cxm.wave_repeat_m
+        lds_width_n = data_byte * cxm.wave_tile_n * cxm.wave_step_n * cxm.waves_per_n() * cxm.wave_repeat_n
+        lds_single_size = ctrl.lds_single_size
+
+        lds_width_q = [lds_width_m, lds_width_n][q_idx]
+
+        # used as the base when composing ds_read offset:x values
+        lds_base_m = 0
+        lds_base_n = 0
+        assert ctrl.unroll_k % cxm.block_k() == 0
+        unroll_k = ctrl.unroll_k
+        k_per_inst = cxm.block_k()
+
+        pad_m = ctrl.lds_pad_m
+        pad_n = ctrl.lds_pad_n
+
+        lds_base_q = [lds_base_m, lds_base_n][q_idx]
+        pad_q = [pad_m, pad_n][q_idx]
+
+        num_v_p = [cxm.inst_mfma.num_v_a, cxm.inst_mfma.num_v_b][p_idx]
+        num_v_q = [cxm.inst_mfma.num_v_a, cxm.inst_mfma.num_v_b][q_idx]
+        wave_step_p = [cxm.wave_step_m, cxm.wave_step_n][p_idx]
+        wave_step_q = [cxm.wave_step_m, cxm.wave_step_n][q_idx]
+        wave_repeat_p = [cxm.wave_repeat_m, cxm.wave_repeat_n][p_idx]
+        wave_repeat_q = [cxm.wave_repeat_m, cxm.wave_repeat_n][q_idx]
+
+        p_interleave_gld = [ctrl.pass_through_a_interleave_gld, ctrl.pass_through_b_interleave_gld][p_idx]
+
+        # assert wave_repeat_q == 2, "currently the side that needs LDS must have repeat 2; the following limitation seems to have a BUG"
+
+        v_pack_p = [ctrl.pass_through_a_v_pack, ctrl.pass_through_b_v_pack][p_idx]
+        v_pack_q = [ctrl.pass_through_a_v_pack, ctrl.pass_through_b_v_pack][q_idx]
+        assert v_pack_p == v_pack_q, "currently p and q must use the same v_pack"
+
+        assert unroll_k % (v_pack_p * k_per_inst) == 0
+        unroll_k_slot = unroll_k // (v_pack_p * k_per_inst)
+
+        def global_load_p():
+            with self._deferred_context():
+                self._emit(f_gld_p())
+            return self._get_deferred()
+
+        def global_load_q():
+            with self._deferred_context():
+                self._emit(f_gld_q())
+            return self._get_deferred()
+
+        def move_slice_window_pq():
+            with self._deferred_context():
+                if f_move_slice_window_p:
+                    self._emit(f_move_slice_window_p())
+                if f_move_slice_window_q:
+                    self._emit(f_move_slice_window_q())
+            return self._get_deferred()
+
+        def move_slice_window_acc():
+            with self._deferred_context():
+                self._emit(f_move_slice_window_acc())
+            return self._get_deferred()
+
+        def call_mbb(mbb):
+            return machine_basic_block_call(self, mbb)
+
+        # parse the global load of the p tensor into a list of single loads
+        mbb_gld_p = create_machine_basic_block(global_load_p())
+        mbb_gld_q = create_machine_basic_block(global_load_q(), merge_mbb = 1)
+
+        mbb_p_clear = 1 if mbb_gld_p[0].mc_inst(-1).type() == MC_INST_TYPE_LEGACY_MACRO else 0
+        mbb_q_clear = 1 if mbb_gld_q[0].mc_inst(-1).type() == MC_INST_TYPE_LEGACY_MACRO else 0
+
+        if mbb_p_clear == 1:
+            # hack on v_clear_nc
+            v_clear_nc_strs = mbb_gld_p[0].mc_inst(-1).inst_str
+            v_clear_nc_list = re.split(r'[,\s]+', v_clear_nc_strs)
+            assert len(v_clear_nc_list) == 3 and v_clear_nc_list[0] == '.v_clear_nc'
+            num_gld_p = int(v_clear_nc_list[2])     # TODO: check number
+            assert num_gld_p % (len(mbb_gld_p) - mbb_p_clear) == 0
+            num_gld_p_per_issue = num_gld_p // (len(mbb_gld_p) - mbb_p_clear)
+            def emit_v_clear_nc_p(i):
+                with self._deferred_context():
+                    self._emit(f".v_clear_nc {v_gld_p(i * num_gld_p_per_issue) if p_interleave_gld else v_gld_p_gpf(i * num_gld_p_per_issue)}, {num_gld_p_per_issue}")
+                return self._get_deferred()
+
+            mbb_gld_p_wrapper = list()
+            for i in range(len(mbb_gld_p) - mbb_p_clear):
+                mbb_gld_p_wrapper += create_machine_basic_block(emit_v_clear_nc_p(i) + '\n' + call_mbb(mbb_gld_p[i+1]), merge_mbb = 1)
+
+            mbb_gld_p = mbb_gld_p_wrapper
+            mbb_p_clear = 0
+
+        num_p_issue = len(mbb_gld_p) - mbb_p_clear
+        num_q_issue = len(mbb_gld_q) - mbb_q_clear
+
+        mbb_msw_pq = create_machine_basic_block(move_slice_window_pq(), merge_mbb = 1) if (f_move_slice_window_p or f_move_slice_window_q) else list()
+        mbb_msw_acc = create_machine_basic_block(move_slice_window_acc(), merge_mbb = 1) if f_move_slice_window_acc else list()
+
+        def mapped_ioffset(i_k, width_byte, pad_pixel, 
offset = 0): + k_pack = self.ctrl.lds_k_pack + i_k0 = i_k // k_pack + i_kp = i_k % k_pack + return i_k0 * (width_byte * k_pack + pad_pixel * data_byte) + i_kp * data_byte + offset * k_pack + + def mi_q(i_k, offset = 0): + return mapped_ioffset(i_k, lds_width_q, pad_q, offset) + + def mfma_step_pxq_vk(i_k, i_repeat_p, i_repeat_q, i_v, i_local_buffer_q = 0): + # v_pack is in k direction, hence c_index stay the same across different i_v + mfma = cxm.inst_mfma + num_agpr_per_issue = mfma.num_a_c + with self._deferred_context(): + for i_step_q in range(wave_step_q): + for i_step_p in range(wave_step_p): + if p_idx == 0: + c_index = i_repeat_p * wave_step_p * wave_step_q * wave_repeat_q * num_agpr_per_issue + \ + i_repeat_q * wave_step_p * wave_step_q * num_agpr_per_issue + \ + i_step_p * wave_step_q * num_agpr_per_issue + \ + i_step_q * num_agpr_per_issue + else: + c_index = i_repeat_q * wave_step_q * wave_step_p * wave_repeat_p * num_agpr_per_issue + \ + i_repeat_p * wave_step_q * wave_step_p * num_agpr_per_issue + \ + i_step_q * wave_step_p * num_agpr_per_issue + \ + i_step_p * num_agpr_per_issue + c_index_end = c_index + num_agpr_per_issue - 1 + + p_index = i_k * wave_repeat_p * wave_step_p * v_pack_p * num_v_p + \ + i_repeat_p * wave_step_p * v_pack_p * num_v_p + \ + i_step_p * v_pack_p * num_v_p + \ + i_v * num_v_p + + q_index = i_local_buffer_q * wave_step_q * wave_repeat_q * v_pack_q * num_v_q + \ + i_repeat_q * wave_step_q * v_pack_q * num_v_q + \ + i_step_q * v_pack_q * num_v_q + \ + i_v * num_v_q + self._emit(mfma(a_c((c_index, c_index_end)), v_gld_p(p_index), v_q(q_index), a_c((c_index, c_index_end))) + f" ; repeat:{i_repeat_p}x{i_repeat_q}, step:{i_step_p}x{i_step_q}, k:{i_k}, v:{i_v}, num_a_c:{num_agpr_per_issue}") + return self._get_deferred() + + def mfma_loop(): + mfma = cxm.inst_mfma + + repeat_q_thread_offset = wave_step_q * num_v_q * v_pack_p + local_buffer_q = wave_repeat_q * repeat_q_thread_offset + mfma_v_pack_slot = unroll_k_slot * wave_repeat_p * wave_repeat_q # TODO: not consider step + cnt_mfma_v_pack_slot = 0 + + def first_sld(): + # when start of mfma main loop, do this load + with self._deferred_context(): + for i in range(wave_repeat_q): + self._emit(f_sld_q(v_q(i * repeat_q_thread_offset), v_sld_q_os(), lds_base_q + mi_q(0, i * (lds_width_q // 2)))) + if ctrl.local_prefetch_num == 2: + # always load a single piece of repeat + self._emit(f_sld_q(v_q(wave_step_q * wave_repeat_q * v_pack_q * num_v_q), v_sld_q_os(), lds_base_q + mi_q(1 * v_pack_p * k_per_inst, 0))) + return self._get_deferred() + + mbb_first_sld = create_machine_basic_block(first_sld()) + + def mfma_per_k_slot(i_k, i_mfma_v_pack_slot, is_last_fma): + ''' + k slot is unroll_k / k_per_inst + pattern: + prefetch:1, repeat:1 (phase:1) + 0 0 0 + i_k i_r load_i_r load_i_buf load_i_k lgkmcnt need_load + 0 0 0 0 1 0 + 1 0 0 0 2 0 + 2 0 0 0 3 0 + 3 0 0 0 4 0 x + + prefetch:2, repeat:1 (phase:1) + 0 0 0 + 0 1 1 + i_k i_r load_i_r load_i_buf load_i_k lgkmcnt need_load + 0 0 0 0 2 1 + 1 0 0 1 3 1 + 2 0 0 0 4 1 x + 3 0 0 1 5 0 x + + prefetch:1, repeat:2 (phase:2) + 0 0 0 + 1 0 0 + i_k i_r load_i_r load_i_buf load_i_k lgkmcnt need_load + 0 0 0 0 1 1 + 0 1 1 0 1 1 + 1 0 0 0 2 1 + 1 1 1 0 2 1 + 2 0 0 0 3 1 + 2 1 0 0 3 1 + 3 0 1 0 4 1 x + 3 1 1 0 4 0 x + + prefetch:2, repeat:2 (phase:3) + 0 0 0 + 1 0 0 + 0 1 1 + i_k i_r load_i_r load_i_buf load_i_k lgkmcnt need_load + 0 0 1 1 1 2 + 0 1 0 0 2 2 + 1 0 1 0 2 2 + 1 1 0 1 3 2 + 2 0 1 1 3 2 + 2 1 0 0 4 2 x + 3 0 1 0 4 1 x + 3 1 0 1 5 0 x + ''' + pref = 
ctrl.local_prefetch_num + rept = wave_repeat_q + phase = pref + rept - 1 # idx before entering main loop + + i_r_sequence = [ x & (rept - 1) for x in range(pref * rept)] + i_b_sequence = [(x >> (rept - 1)) & (pref - 1) for x in range(pref * rept)] + + i_local_buffer_q = i_k & 1 if pref == 2 else 0 + i_k_sst_q = i_k == (unroll_k_slot - ctrl.local_prefetch_num) + # print(f"i_k:{i_k}, i_k_sst_q:{i_k_sst_q}") + gld_p_per_k = wave_repeat_p * wave_step_p + cnt_mfma = 0 + def try_do_gld_per_slot(i_slot): + if is_last_fma: + if p_interleave_gld: + mbb_gld_p_per_k = mbb_gld_p[len(mbb_gld_p) - gld_p_per_k : ] if i_k == 0 else list() + else: + mbb_gld_p_per_k = list() + mbb_gld_per_k = mbb_gld_p_per_k + else: + if p_interleave_gld: + if i_k == 0: + mbb_gld_p_per_k = mbb_gld_p[len(mbb_gld_p) - gld_p_per_k : ] + else: + start_p_idx = mbb_p_clear if i_k == 1 else ((i_k - 1) * gld_p_per_k + mbb_p_clear) # always no clear + mbb_gld_p_per_k = mbb_gld_p[start_p_idx : i_k * gld_p_per_k + mbb_p_clear ] + else: + mbb_gld_p_per_k = mbb_gld_p if i_k == 0 else list() + mbb_gld_per_k = ((mbb_gld_p_per_k + mbb_msw_pq + mbb_msw_acc + mbb_gld_q) if p_interleave_gld else (mbb_gld_q + mbb_gld_p_per_k)) \ + if i_k == 0 else mbb_gld_p_per_k + num_gld_slot_per_k = wave_repeat_p * wave_repeat_q * v_pack_p + num_gld_per_slot = utility_next_mul(len(mbb_gld_per_k), num_gld_slot_per_k) // num_gld_slot_per_k + for i_gld in range(num_gld_per_slot): + current_gld = i_slot * num_gld_per_slot + i_gld + if current_gld < len(mbb_gld_per_k): + self._emit(call_mbb(mbb_gld_per_k[current_gld])) + + def do_sst_q(): + # print(f"do_sst_q, i_k:{i_k}") + if ctrl.lds_buffer_num == 1: + self._emit(f"s_barrier") + self._emit(f_sst_q()) + if ctrl.lds_buffer_num != 1: + self._emit(f"v_xor_b32 v[{v_sst_q_os()}], {hex(lds_single_size)}, v[{v_sst_q_os()}]") + + def do_sld_q(i_v, i_r): + # interleave into different v_pack + i_idx = i_k * rept + i_r + i_idx_mod = (i_idx + phase) % (pref * rept) + i_idx_int = (i_idx + phase) // (pref * rept) + + # print(f" ==i_r_sequence:{i_r_sequence}, i_b_sequence:{i_b_sequence}, i_idx:{i_idx}, mod:{i_idx_mod}, int:{i_idx_int}") + + load_i_r = i_r_sequence[i_idx_mod] + load_i_b = i_b_sequence[i_idx_mod] + load_i_k = i_idx_int * pref + load_i_b + + if i_v == (v_pack_p - 1) and load_i_k < unroll_k_slot: + the_str = f' ; i_r:{load_i_r}, i_b:{load_i_b}, i_k:{load_i_k}' + self._emit(f_sld_q(v_q(load_i_b * local_buffer_q + load_i_r * repeat_q_thread_offset), v_sld_q_os(), lds_base_q + mi_q(load_i_k * v_pack_p * k_per_inst, load_i_r * (lds_width_q // 2) )) + the_str) + + + if i_k == 0: + if not p_interleave_gld and not is_last_fma: + self._emit(move_slice_window_pq()) + for mbb_1st in mbb_first_sld[1:]: + self._emit(call_mbb(mbb_1st)) + if not p_interleave_gld and not is_last_fma: + self._emit(move_slice_window_acc()) + + for i_rp in range(wave_repeat_p): + # cnt_p_load = cnt_p_load + 1 + for i_rq in range(wave_repeat_q): + num_lgkmcnt = (pref + rept - 2) - ((pref - 1 + i_rq) if i_k == (unroll_k_slot-1) else 0) + if not p_interleave_gld: + vmcnt_str = "vmcnt(0)" if i_k == 0 and i_rp == 0 and i_rq == 0 else \ + ( f"vmcnt({f_gld_p.get_issues()})" if num_lgkmcnt == 0 and not is_last_fma else "") + else: + if i_rq != 0 and wave_repeat_q != 1: + vmcnt_str = "" + else: + if i_k == 0: + vmcnt_str = f'vmcnt({num_p_issue - 1 - gld_p_per_k})' + else: + if not is_last_fma: + vmcnt_str = f'vmcnt({num_p_issue + num_q_issue - 2})' + else: + vmcnt_str = f'vmcnt({num_p_issue - i_k - 1})' + + if 
MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT and num_lgkmcnt == 0 and p_interleave_gld:
+                        # we need a chance to place the last lgkmcnt wait earlier
+                        # assert vmcnt_str == ""
+                        if is_last_fma:
+                            self._emit(f's_waitcnt lgkmcnt(0) ; vmcnt_str:{vmcnt_str}')
+                        else:
+                            # self._emit(f"; __ vmcnt_str:{vmcnt_str}")
+                            pass
+                    else:
+                        self._emit(f's_waitcnt lgkmcnt({num_lgkmcnt}) {vmcnt_str}')
+                    if num_lgkmcnt == 0 and not p_interleave_gld and not is_last_fma:
+                        # self._emit(move_slice_window_acc())
+                        do_sst_q()
+                    if i_k == 0 and i_rp == 0 and i_rq == 0:
+                        if not p_interleave_gld and v_gld_p_gpf:
+                            # move the prefetched buffer into place
+                            for i_pnum in range(v_gld_p_num):
+                                self._emit(f"v_mov_b32 v[{v_gld_p(i_pnum)}], v[{v_gld_p_gpf(i_pnum)}]")
+
+                    for i_v in range(v_pack_p):
+                        self._emit(mfma_step_pxq_vk(i_k, i_rp, i_rq, i_v, i_local_buffer_q))
+                        if MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT and p_interleave_gld:
+                            if (i_mfma_v_pack_slot == mfma_v_pack_slot - 2) and (v_pack_p == 1 or i_v == (v_pack_p // 2) - 1):
+                                assert i_rq == 0
+                                if not is_last_fma:
+                                    self._emit(f's_waitcnt lgkmcnt(0) vmcnt({num_p_issue - gld_p_per_k})')
+                                do_sst_q()
+                        do_sld_q(i_v, i_rq)     # not emitted on the last ds wait, hence it never co-exists with the last-ds-wait emission
+                        #if not is_last_fma:
+                        try_do_gld_per_slot(cnt_mfma)
+                        cnt_mfma = cnt_mfma + 1
+                    assert i_mfma_v_pack_slot < mfma_v_pack_slot, f'i_mfma_v_pack_slot:{i_mfma_v_pack_slot}, mfma_v_pack_slot:{mfma_v_pack_slot}'
+                    i_mfma_v_pack_slot = i_mfma_v_pack_slot + 1
+
+            if not is_last_fma and i_k == (unroll_k_slot - 1):
+                self._emit(f's_waitcnt lgkmcnt(0)')
+                self._emit(f"s_barrier")
+                self._emit(call_mbb(mbb_first_sld[0]))
+                self._emit(f"s_sub_i32 s[{s_kitr()}], s[{s_kitr()}], {unroll_k}")
+                self._emit(f"s_cmp_gt_i32 s[{s_kitr()}], 0")
+                self._emit(f"s_cbranch_scc1 {label_mfma_body}")
+            return i_mfma_v_pack_slot
+
+        self._emit(call_mbb(mbb_first_sld[0]))
+        self._emit(f"s_sub_i32 s[{s_kitr()}], s[{s_knum()}], {unroll_k}")
+        self._emit(f"s_cmp_gt_i32 s[{s_kitr()}], 0")
+        self._emit(f"s_cbranch_scc0 {label_mfma_end}")
+        self._emit_empty_line()
+
+        self._emit_front(f"{label_mfma_body}:")
+        self._emit(f"; do fma accumulate with unroll {unroll_k}, mfma_v_pack_slot:{mfma_v_pack_slot}")
+
+        for i_k in range(unroll_k_slot):
+            cnt_mfma_v_pack_slot = mfma_per_k_slot(i_k, cnt_mfma_v_pack_slot, False)
+
+        self._emit_front(f"{label_mfma_end}:")
+        cnt_mfma_v_pack_slot = 0
+        for i_k in range(unroll_k_slot):
+            cnt_mfma_v_pack_slot = mfma_per_k_slot(i_k, cnt_mfma_v_pack_slot, True)
+
+        # start emitting: first load the q tensor, then the p tensor.
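+        # note on the first wait below (illustrative numbers, not from a real config):
+        # vector memory loads retire in issue order, so with q issued before p,
+        # num_q_issue == 4, num_p_issue == 8 and wave_repeat_p * wave_step_p == 2,
+        # s_waitcnt vmcnt(8 - 2) == vmcnt(6) guarantees that all four q loads plus the
+        # first k-slot's two p loads have landed before q is stored to LDS and the
+        # first mfma consumes its p operands; the other six p loads stay in flight.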
+ self._emit(f"; start MFMA loop, wave tile:{cxm.wave_tile_m}x{cxm.wave_tile_n}, repeat:{cxm.wave_repeat_m}x{cxm.wave_repeat_n}, step:{cxm.wave_step_m}x{cxm.wave_step_n}" +\ + f", k_pack:{self.ctrl.lds_k_pack}, p_issue:{num_p_issue}, q_issue:{num_q_issue}, local_prefetch_num:{ctrl.local_prefetch_num}") + + self._emit(f".v_clear_acc_c {a_c()}, {cxm.total_acc_c()}") + # self._emit(f"; make sure acc WAR harzard, at least 1 nop for src_c") + + self._emit(f"s_waitcnt vmcnt({f_gld_p.get_issues() - ((wave_repeat_p * wave_step_p) if p_interleave_gld else 0)})") + self._emit(f_sst_q()) + self._emit_empty_line() + # if not p_interleave_gld: + # self._emit(move_slice_window_pq()) + # self._emit(move_slice_window_acc()) + + self._emit(f"s_waitcnt lgkmcnt(0)") + self._emit(f"s_barrier") + self._emit_empty_line() + + mfma_loop() + + nop = emit_nop_t(self.mc) + nop(cxm.inst_mfma.get_nop_count_mfma_acc_raw()) # solve dependency + + def emit(self): + if self.ctrl.pass_through_a ^ self.ctrl.pass_through_b: + return self.emit_single_pass_through() + label_mfma_body = 'L_{}_mfma_body'.format(self.ctrl.label_prefix) label_mfma_finishing = 'L_{}_mfma_finishing'.format(self.ctrl.label_prefix) label_mfma_end = 'L_{}_mfma_end'.format(self.ctrl.label_prefix) @@ -89,6 +533,7 @@ def emit(self): f_move_slice_window_a = self.ctrl.move_slice_window_a_functor f_move_slice_window_b = self.ctrl.move_slice_window_b_functor + f_move_slice_window_acc = self.ctrl.move_slice_window_accumule_functor v_a = self.ctrl.v_a v_b = self.ctrl.v_b @@ -119,6 +564,22 @@ def emit(self): unroll_k = self.ctrl.unroll_k k_per_inst = cxm.block_k() + pad_m = self.ctrl.lds_pad_m + pad_n = self.ctrl.lds_pad_n + + def mapped_ioffset(i_k, width_byte, pad_pixel, offset = 0): + k_pack = self.ctrl.lds_k_pack + i_k0 = i_k // k_pack + i_kp = i_k % k_pack + return i_k0 * (width_byte * k_pack + pad_pixel * data_byte) + i_kp * data_byte + offset * k_pack + + # mi = mapped_ioffset + def mi_m(i_k, offset = 0): + return mapped_ioffset(i_k, lds_width_m, pad_m, offset) + + def mi_n(i_k, offset = 0): + return mapped_ioffset(i_k, lds_width_n, pad_n, offset) + def mfma_step_mxn(i_repeat_m, i_repeat_n, i_local_buffer_m = 0, i_local_buffer_n = 0): local_buffer_m = cxm.inst_mfma.num_v_a * cxm.wave_step_m * cxm.wave_repeat_m local_buffer_n = cxm.inst_mfma.num_v_b * cxm.wave_step_n * cxm.wave_repeat_n @@ -152,6 +613,8 @@ def mfma_loop_repeat_1x1_lp2(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -165,21 +628,21 @@ def mfma_loop_repeat_1x1_lp2(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst))) # lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst))) # lds_width_n * k_per_inst def do_unroll_k_1x1_sub(): unroll_k_sub = (unroll_k // k_per_inst) // 2 - 1 for i_k in range(unroll_k_sub): self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 0, 0)) - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * 
k_per_inst)) - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst)) + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst))) # (2*i_k+2) * lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst))) # (2*i_k+2) * lds_width_n * k_per_inst self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst))) # (2*i_k+3) * lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst))) # (2*i_k+3) * lds_width_n * k_per_inst do_unroll_k_1x1_sub() self._emit(f_move_slice_window_b()) @@ -189,6 +652,8 @@ def do_unroll_k_1x1_sub(): self._emit_empty_line() + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f's_waitcnt lgkmcnt(0)') self._emit(f"s_barrier") self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})") @@ -214,15 +679,14 @@ def do_unroll_k_1x1_sub(): self._emit_front(f"{label_mfma_finishing}:") self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit_front(f"{label_mfma_end}:") self._emit("s_waitcnt lgkmcnt(0)") self._emit("s_barrier") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst))) # k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst))) # k_per_inst * lds_width_n do_unroll_k_1x1_sub() self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 0, 0)) @@ -243,13 +707,13 @@ def do_interleave_unroll_k_sub(): for i_k in range(unroll_k_sub): self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 0, 0)) - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst))) # (2*i_k+2) * k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst))) # (2*i_k+2) * k_per_inst * lds_width_n self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst))) # (2*i_k+3) * k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst))) # (2*i_k+3) * k_per_inst * lds_width_n return self._get_deferred() def do_interleave_gload_and_move_slice_window(): @@ -291,6 +755,8 @@ def do_interleave_share_store(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + 
self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -304,8 +770,8 @@ def do_interleave_share_store(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst) )) # k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst) )) # k_per_inst * lds_width_n if (unroll_k // k_per_inst) // 2 - 1 != 0: @@ -318,6 +784,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(se_sub.lower(interleave_pattern=INTERLEAVE_PTN_0)) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1)) else: mbb_list_last = [create_machine_basic_block(do_interleave_unroll_k_last(), group_mbb_by_end_of_inst_op="v_mfma"), @@ -325,6 +793,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(do_interleave_gload_and_move_slice_window()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1)) # Label: finishing of fma body @@ -338,8 +808,8 @@ def do_interleave_share_store(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst))) # k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst))) # k_per_inst * lds_width_n self._emit(do_interleave_unroll_k_sub()) self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 0, 0)) @@ -359,6 +829,8 @@ def mfma_loop_repeat_2x2_lp2(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) #self._emit(f"v_xor_b32 v[{v_sst_b_os()}], {hex(lds_single_size)}, v[{v_sst_b_os()}] ; switch double buffer b store") #self._emit(f"v_xor_b32 v[{v_sst_a_os()}], {hex(lds_single_size)}, v[{v_sst_a_os()}] ; switch double buffer a store") @@ -374,8 +846,8 @@ def mfma_loop_repeat_2x2_lp2(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2 + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 def do_unroll_k_sub(): unroll_k_sub = (unroll_k // k_per_inst) // 2 - 1 @@ -385,65 +857,65 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k == 0 else 5})') 
self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_n * k_per_inst else: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") # (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt({3 if i_k == 0 else 5})') self._emit(mfma_step_mxn(0, 1, 0, 0)) if i_k == 0: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst + lds_width_n // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst + lds_width_m // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(k_per_inst, lds_width_n // 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_n * k_per_inst + lds_width_n // 2 + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(k_per_inst, lds_width_m // 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_m * k_per_inst + lds_width_m // 2 else: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_m * k_per_inst self._emit_empty_line() # 3rd fma self._emit(f's_waitcnt lgkmcnt({4 if i_k == 0 else 5})') self._emit(mfma_step_mxn(1, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_n * k_per_inst self._emit_empty_line() # 4th fma self._emit(mfma_step_mxn(1, 1, 0, 0)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * 
k_per_inst + lds_width_n // 2) + \ - f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2)) + \ + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") # (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2 self._emit_empty_line() self._emit(f"; k iteration : {2 * i_k + 1}") # 1st fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ \ - f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2) )+ \ + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") # (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(0, 1, 1, 1)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst) ) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") # (2*i_k+3) * lds_width_m * k_per_inst self._emit_empty_line() # 3rd fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(1, 0, 1, 1)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") # (2*i_k+3) * lds_width_n * k_per_inst self._emit_empty_line() # 4th fma self._emit(mfma_step_mxn(1, 1, 1, 1)) - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n//2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst, lds_width_n//2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") # (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n//2 if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (unroll_k // k_per_inst - 1) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{unroll_k // k_per_inst - 1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((unroll_k // k_per_inst - 1) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{unroll_k // k_per_inst - 1} into local buffer {1}, repeat {1}") # (unroll_k // k_per_inst - 1) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() do_unroll_k_sub() @@ -457,6 +929,8 @@ def do_unroll_k_sub(): self._emit_empty_line() # 2nd fma + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f's_waitcnt lgkmcnt(0)') self._emit(f"s_barrier") self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})") @@ -525,8 +999,8 @@ def do_unroll_k_sub(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), 
v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2 + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 do_unroll_k_sub() self._emit(f"; k iteration : {unroll_k - 2}") # 1st fma @@ -585,65 +1059,65 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k == 0 else 5})') self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_n * k_per_inst else: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") # (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt({3 if i_k == 0 else 5})') self._emit(mfma_step_mxn(0, 1, 0, 0)) if i_k == 0: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst + lds_width_n // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst + lds_width_m // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(k_per_inst ,lds_width_n // 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_n * k_per_inst + lds_width_n // 2 + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(k_per_inst , lds_width_m // 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_m * k_per_inst + lds_width_m // 2 else: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_m * k_per_inst self._emit_empty_line() # 3rd fma self._emit(f's_waitcnt lgkmcnt({4 if i_k == 0 else 5})') self._emit(mfma_step_mxn(1, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} 
into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_n * k_per_inst self._emit_empty_line() # 4th fma self._emit(mfma_step_mxn(1, 1, 0, 0)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2) + \ - f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2)) + \ + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") # (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2 self._emit_empty_line() self._emit(f"; k iteration : {2 * i_k + 1}") # 1st fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ \ - f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ \ + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") # (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(0, 1, 1, 1)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") # (2*i_k+3) * lds_width_m * k_per_inst self._emit_empty_line() # 3rd fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(1, 0, 1, 1)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") self._emit_empty_line() # 4th fma self._emit(mfma_step_mxn(1, 1, 1, 1)) - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n//2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst, lds_width_n//2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") # (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n//2 if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (unroll_k // k_per_inst - 1) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{unroll_k // k_per_inst - 1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((unroll_k // k_per_inst - 1) * k_per_inst, lds_width_m//2)) + f" ; load i_k:{unroll_k // k_per_inst - 1} into local buffer {1}, repeat {1}") self._emit_empty_line() return 
self._get_deferred() @@ -714,6 +1188,8 @@ def do_interleave_share_store(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -723,8 +1199,8 @@ def do_interleave_share_store(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2 + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 if (unroll_k // k_per_inst) // 2 - 1 != 0: mbb_list_sub = [create_machine_basic_block(do_interleave_unroll_k_sub(), group_mbb_by_end_of_inst_op="v_mfma"), @@ -736,6 +1212,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(se_sub.lower(interleave_pattern=INTERLEAVE_PTN_0)) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) else: @@ -744,6 +1222,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(do_interleave_gload_and_move_slice_window()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) @@ -763,8 +1243,8 @@ def do_interleave_share_store(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2 + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 self._emit(do_interleave_unroll_k_sub()) self._emit(f"; k iteration : {unroll_k - 2}") @@ -816,6 +1296,8 @@ def mfma_loop_repeat_2x2(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"v_xor_b32 v[{v_sst_b_os()}], {hex(lds_single_size)}, v[{v_sst_b_os()}] ; switch double buffer b store") self._emit(f"v_xor_b32 v[{v_sst_a_os()}], {hex(lds_single_size)}, v[{v_sst_a_os()}] ; switch double buffer a store") @@ -831,12 +1313,13 @@ def mfma_loop_repeat_2x2(): self._emit(f"; do fma accumulate with unroll {unroll_k // k_per_inst}") 
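A note on the recurring rewrite in the hunks above: every open-coded LDS offset of the form lds_base_m + i * lds_width_m * k_per_inst + lds_width_m // 2 becomes a call such as mi_m(i * k_per_inst, lds_width_m // 2), with the old expression preserved in a trailing comment at each call site. The mi_m/mi_n definitions are not part of these hunks, so the following is only a plausible sketch of the contract the call sites imply; the one hard constraint visible here is that with lds_k_pack == 1 the helper must reduce to the old inline arithmetic.

    # hypothetical reconstruction, not code from this patch
    def make_mi(lds_width, k_pack=1):
        def mi(k_idx, extra=0):
            # whole k_pack groups advance by a full k_pack-wide LDS row; the
            # remainder stays inside the packed row; 'extra' carries the
            # intra-slice byte offset (e.g. lds_width // 2 for repeat 1)
            return (k_idx // k_pack) * lds_width * k_pack + (k_idx % k_pack) + extra
        return mi

    lds_width_m = 1024                     # illustrative value only
    mi_m = make_mi(lds_width_m)            # k_pack == 1 case
    # with k_pack == 1 this reproduces the old open-coded offset:
    assert mi_m(3, lds_width_m // 2) == 3 * lds_width_m + lds_width_m // 2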
    self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m))
    self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n))
-   self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 ))
-   self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 ))
+   self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2
+   self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2
-   self._emit(f".itr_k = 0")
-   self._emit(f".rept {unroll_k // k_per_inst - 1}")
-   with self._indent_context():
+   # self._emit(f".itr_k = 0")
+   # self._emit(f".rept {unroll_k // k_per_inst - 1}")
+   #with self._indent_context():
+   for i_k in range( unroll_k // k_per_inst - 1):
        # 1st fma
        self._emit(f's_waitcnt lgkmcnt(2)')
        self._emit(mfma_step_mxn(0, 0))
@@ -846,21 +1329,25 @@ def mfma_loop_repeat_2x2():
        self._emit(mfma_step_mxn(0, 1))
        # 3rd fma
-       self._emit(f_sld_a(v_a(), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}'))
+       # self._emit(f_sld_a(v_a(), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}'))
+       self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((i_k+1) * k_per_inst)))
        self._emit(f's_waitcnt lgkmcnt(1)')
        self._emit(mfma_step_mxn(1, 0))
        # 4th fma
-       self._emit(f_sld_b(v_b(), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}'))
+       # self._emit(f_sld_b(v_b(), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}'))
+       self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((i_k+1)* k_per_inst)))
        self._emit(mfma_step_mxn(1, 1))
        self._emit_empty_line()
        # last
-       self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}+{lds_width_n//2}'))
-       self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}+{lds_width_m//2}'))
-       self._emit('.itr_k = .itr_k + 1')
+       # self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}+{lds_width_n//2}'))
+       # self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}+{lds_width_m//2}'))
+       self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((i_k+1) * k_per_inst, lds_width_n//2)))
+       self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((i_k+1) * k_per_inst, lds_width_m//2)))
+       # self._emit('.itr_k = .itr_k + 1')
-   self._emit(f".endr")
+   # self._emit(f".endr")
    self._emit_empty_line()
    self._emit(f"; last unroll")
    self._emit(f"v_xor_b32 v[{v_sld_b_os()}], {lds_single_size}, v[{v_sld_b_os()}] ; switch double buffer b load")
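Why the .rept/.endr expansion above had to become a Python-level for loop: the assembler form can only fold offsets that are linear in the .itr_k symbol, while mi_m((i_k+1) * k_per_inst) needs the k_pack division performed at code-generation time. A hedged mini-illustration, not code from this patch (emit stands in for self._emit, and the ds_read operands are placeholders):

    def emit_with_rept(emit, n, lds_width_m):
        emit(".itr_k = 0")
        emit(f".rept {n}")
        # offsets may only be linear in the assembler symbol .itr_k
        emit(f"    ds_read_b32 v[0], v[1] offset:(.itr_k+1)*{lds_width_m}")
        emit("    .itr_k = .itr_k + 1")
        emit(".endr")

    def emit_with_loop(emit, n, mi_m, k_per_inst):
        for i_k in range(n):
            # arbitrary Python arithmetic per iteration, e.g. mi_m's k_pack folding
            emit(f"    ds_read_b32 v[0], v[1] offset:{mi_m((i_k + 1) * k_per_inst)}")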
@@ -875,6 +1362,8 @@ def mfma_loop_repeat_2x2():
        self._emit(mfma_step_mxn(0, 1))
        # wait global and store to LDS
+       if f_move_slice_window_acc != None:
+           self._emit(f_move_slice_window_acc())
        self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})")
        self._emit(f_sst_b())
        self._emit(f"s_waitcnt vmcnt(0)")
@@ -920,12 +1409,13 @@ def mfma_loop_repeat_2x2():
    self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m))
    self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n))
-   self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 ))
-   self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 ))
+   self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 )))
+   self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2 )))
-   self._emit(f".itr_k = 0")
-   self._emit(f".rept {unroll_k // k_per_inst - 1}")
-   with self._indent_context():
+   # self._emit(f".itr_k = 0")
+   # self._emit(f".rept {unroll_k // k_per_inst - 1}")
+   # with self._indent_context():
+   for i_k in range(unroll_k // k_per_inst - 1):
        # 1st fma
        self._emit('s_waitcnt lgkmcnt(2)')
        self._emit(mfma_step_mxn(0, 0))
@@ -935,18 +1425,22 @@ def mfma_loop_repeat_2x2():
        self._emit(mfma_step_mxn(0, 1))
        # 3rd fma
-       self._emit(f_sld_a(v_a(), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}'))
+       # self._emit(f_sld_a(v_a(), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}'))
+       self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((i_k+1)* k_per_inst)))
        self._emit('s_waitcnt lgkmcnt(1)')
        self._emit(mfma_step_mxn(1, 0))
        # 4th fma
-       self._emit(f_sld_b(v_b(), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}'))
+       # self._emit(f_sld_b(v_b(), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}'))
+       self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((i_k+1)* k_per_inst)))
        self._emit(mfma_step_mxn(1, 1))
        self._emit_empty_line()
        # last
-       self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}+{lds_width_n//2}'))
-       self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}+{lds_width_m//2}'))
+       #self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}+{lds_width_n//2}'))
+       #self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}+{lds_width_m//2}'))
+       self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((i_k+1) * k_per_inst, lds_width_n//2)))
+       self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((i_k+1) * k_per_inst, lds_width_m//2)))
-       self._emit('.itr_k = .itr_k + 1')
-   self._emit('.endr')
+       # self._emit('.itr_k = .itr_k + 1')
+   # self._emit('.endr')
    self._emit_empty_line()
@@ -978,6 +1472,8 @@ def mfma_loop_repeat_2x1_lp2():
        # right after clear acc
        self._emit(f_move_slice_window_b())
        self._emit(f_move_slice_window_a())
+       if f_move_slice_window_acc != None:
+           self._emit(f_move_slice_window_acc())
        self._emit(f"s_waitcnt lgkmcnt(0)")
        self._emit(f"s_barrier")
@@ -991,7 +1487,7 @@ def mfma_loop_repeat_2x1_lp2():
    self._emit(f"; do fma accumulate with unroll {unroll_k}")
    self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n))
    self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m))
-   self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 ))
+   self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2
    def do_unroll_k_sub():
        unroll_k_sub = (unroll_k // k_per_inst) // 2 - 1
@@ -1001,33 +1497,33 @@ def do_unroll_k_sub():
            self._emit(f's_waitcnt lgkmcnt({1 if i_k == 0 else 2})')
            self._emit(mfma_step_mxn(0, 0, 0, 0))
            if i_k == 0:
-               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \
-                           f" ; load i_k:{1} into local buffer {1}, repeat {0}")
-               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \
-                           f" ; load i_k:{1} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \
+                           f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_n *
k_per_inst + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_m * k_per_inst if unroll_k_sub == 1: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst + lds_width_m // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(k_per_inst, lds_width_m // 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_m * k_per_inst + lds_width_m // 2 elif i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst) + \ - f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 ) + \ - f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst)) + \ + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") # (2*i_k+1) * lds_width_m * k_per_inst + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2 )) + \ + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") # (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 else: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst) + \ - f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst)) + \ + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") # (2*i_k+1) * lds_width_m * k_per_inst self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(1, 0, 0, 0)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") else: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 ) + \ + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() self._emit(f"; k iteration : {(2 * i_k + 1) * k_per_inst}") @@ -1035,11 +1531,11 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 
1 else 3})')
            self._emit(mfma_step_mxn(0, 0, 1, 1))
            if i_k == unroll_k_sub - 1:
-               self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
-               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
+               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}")
            else:
-               self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}")
+               self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}")
            self._emit_empty_line()
            # 2nd fma
@@ -1047,11 +1543,11 @@ def do_unroll_k_sub():
            self._emit(mfma_step_mxn(1, 0, 1, 1))
            if i_k == unroll_k_sub - 1: # v_b attention!
-               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
-               self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}")
+               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}")
            else:
-               self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
-               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
+               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
            self._emit_empty_line()
        do_unroll_k_sub()
@@ -1065,6 +1561,8 @@ def do_unroll_k_sub():
        self._emit_empty_line()
        # 2nd fma
+       if f_move_slice_window_acc != None:
+           self._emit(f_move_slice_window_acc())
        self._emit(f's_waitcnt lgkmcnt(0)')
        self._emit(f"s_barrier")
        self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})")
@@ -1106,7 +1604,7 @@ def do_unroll_k_sub():
    self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n))
    self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m))
-   self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 ))
+   self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2 )))
    do_unroll_k_sub()
    self._emit(f"; k iteration : {unroll_k - 2 * k_per_inst}")
    # 1st fma
@@ -1146,20 +1644,20 @@ def do_interleave_unroll_k_sub():
self._emit(f's_waitcnt lgkmcnt({1 if i_k == 0 else 2})') self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") if unroll_k_sub == 1: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst + lds_width_m // 2 ) + \ + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(k_per_inst,lds_width_m // 2 )) + \ f" ; load i_k:{1} into local buffer {1}, repeat {1}") elif i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst) + \ + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 ) + \ + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst) + \ + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") self._emit_empty_line() @@ -1167,12 +1665,12 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(1, 0, 0, 0)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") else: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 ) + \ + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() self._emit(f"; k iteration : {(2 * i_k + 1) * k_per_inst}") @@ -1180,11 +1678,11 @@ def do_interleave_unroll_k_sub(): 
            self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})')
            self._emit(mfma_step_mxn(0, 0, 1, 1))
            if i_k == unroll_k_sub - 1:
-               self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
-               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
+               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}")
            else:
-               self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}")
+               self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}")
            self._emit_empty_line()
            # 2nd fma
@@ -1192,11 +1690,11 @@ def do_interleave_unroll_k_sub():
            self._emit(mfma_step_mxn(1, 0, 1, 1))
            if i_k == unroll_k_sub - 1: # v_b attention!
-               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
-               self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}")
+               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}")
            else:
-               self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
-               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
+               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
            self._emit_empty_line()
        return self._get_deferred()
@@ -1254,6 +1752,8 @@ def do_interleave_share_store():
        # right after clear acc
        self._emit(f_move_slice_window_b())
        self._emit(f_move_slice_window_a())
+       if f_move_slice_window_acc != None:
+           self._emit(f_move_slice_window_acc())
        self._emit(f"s_waitcnt lgkmcnt(0)")
        self._emit(f"s_barrier")
@@ -1263,7 +1763,7 @@ def do_interleave_share_store():
        self._emit(f"; do fma accumulate with unroll {unroll_k}")
        self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n))
        self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m))
-       self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 ))
+       self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m
// 2 ))) if (unroll_k // k_per_inst) // 2 - 1 != 0: mbb_list_sub = [create_machine_basic_block(do_interleave_unroll_k_sub(), group_mbb_by_end_of_inst_op="v_mfma"), @@ -1276,6 +1776,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(se_sub.lower(interleave_pattern=INTERLEAVE_PTN_0)) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) else: @@ -1284,6 +1786,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(do_interleave_gload_and_move_slice_window()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) @@ -1298,7 +1802,7 @@ def do_interleave_share_store(): self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2 ))) self._emit(do_interleave_unroll_k_sub()) self._emit(f"; k iteration : {unroll_k - 2 * k_per_inst}") # 1st fma @@ -1333,6 +1837,8 @@ def mfma_loop_repeat_1x2_lp2(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -1346,7 +1852,7 @@ def mfma_loop_repeat_1x2_lp2(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 ))) def do_unroll_k_sub(): unroll_k_sub = (unroll_k // k_per_inst) // 2 - 1 @@ -1356,20 +1862,20 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({1 if i_k == 0 else 2})') self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") if unroll_k_sub == 1: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{1} into local buffer 
{1}, repeat {1}") elif i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") self._emit_empty_line() @@ -1377,12 +1883,12 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 1, 0, 0)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") else: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() self._emit(f"; k iteration : {(2 * i_k + 1) * k_per_inst}") @@ -1390,11 +1896,11 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 0, 1, 1)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") else: - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load 
i_k:{2*i_k+2} into local buffer {0}, repeat {0}")
+               self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}")
            self._emit_empty_line()
            # 2nd fma
@@ -1402,11 +1908,11 @@ def do_unroll_k_sub():
            self._emit(mfma_step_mxn(0, 1, 1, 1))
            if i_k == unroll_k_sub - 1: # v_b attention!
-               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
-               self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n // 2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}")
+               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst, lds_width_n // 2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}")
            else:
-               self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
-               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
+               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
            self._emit_empty_line()
        do_unroll_k_sub()
@@ -1420,6 +1926,8 @@ def do_unroll_k_sub():
        self._emit_empty_line()
        # 2nd fma
+       if f_move_slice_window_acc != None:
+           self._emit(f_move_slice_window_acc())
        self._emit(f's_waitcnt lgkmcnt(0)')
        self._emit(f"s_barrier")
        self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})")
@@ -1461,7 +1969,7 @@ def do_unroll_k_sub():
    self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m))
    self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n))
-   self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 ))
+   self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 )))
    do_unroll_k_sub()
    self._emit(f"; k iteration : {unroll_k - 2 * k_per_inst}")
    # 1st fma
@@ -1500,20 +2008,20 @@ def do_interleave_unroll_k_sub():
            self._emit(f's_waitcnt lgkmcnt({1 if i_k == 0 else 2})')
            self._emit(mfma_step_mxn(0, 0, 0, 0))
            if i_k == 0:
-               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \
+               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \
                            f" ; load i_k:{1} into local buffer {1}, repeat {0}")
-               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \
+               self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \
                            f" ; load i_k:{1} into local buffer {1}, repeat {0}")
                if unroll_k_sub == 1:
-                   self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst + lds_width_n // 2 ) + \
+                   self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(k_per_inst, lds_width_n // 2 )) + \
                                f" ; load i_k:{1}
into local buffer {1}, repeat {1}") elif i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") self._emit_empty_line() @@ -1521,12 +2029,12 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 1, 0, 0)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") else: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() self._emit(f"; k iteration : {(2 * i_k + 1) * k_per_inst}") @@ -1534,11 +2042,11 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 0, 1, 1)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") else: - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * 
lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}")
+               self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}")
            self._emit_empty_line()
            # 2nd fma
@@ -1546,11 +2054,11 @@ def do_interleave_unroll_k_sub():
            self._emit(mfma_step_mxn(0, 1, 1, 1))
            if i_k == unroll_k_sub - 1: # v_b attention!
-               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
-               self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n // 2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}")
+               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}")
            else:
-               self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
-               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
+               self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}")
+               self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}")
            self._emit_empty_line()
        return self._get_deferred()
@@ -1606,6 +2114,8 @@ def do_interleave_share_store():
        # right after clear acc
        self._emit(f_move_slice_window_b())
        self._emit(f_move_slice_window_a())
+       if f_move_slice_window_acc != None:
+           self._emit(f_move_slice_window_acc())
        self._emit(f"s_waitcnt lgkmcnt(0)")
        self._emit(f"s_barrier")
@@ -1615,7 +2125,7 @@ def do_interleave_share_store():
        self._emit(f"; do fma accumulate with unroll {unroll_k}")
        self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m))
        self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n))
-       self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 ))
+       self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 )))
        if (unroll_k // k_per_inst) // 2 - 1 != 0:
            mbb_list_sub = [create_machine_basic_block(do_interleave_unroll_k_sub(), group_mbb_by_end_of_inst_op="v_mfma"),
@@ -1628,6 +2138,8 @@ def do_interleave_share_store():
            se_last = create_scheduler(self.mc, mbb_list_last)
            self._emit(se_sub.lower(interleave_pattern=INTERLEAVE_PTN_0))
+           if f_move_slice_window_acc != None:
+               self._emit(f_move_slice_window_acc())
            mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1     # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma
            self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start))
        else:
@@ -1636,6 +2148,8 @@ def do_interleave_share_store():
            se_last = create_scheduler(self.mc, mbb_list_last)
            self._emit(do_interleave_gload_and_move_slice_window())
+           if f_move_slice_window_acc != None:
+               self._emit(f_move_slice_window_acc())
            mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1     # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma
            self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start))
@@ -1650,7 +2164,7 @@ def do_interleave_share_store():
        self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m))
        self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n))
-       self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 ))
+       self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 )))
        self._emit(do_interleave_unroll_k_sub())
        self._emit(f"; k iteration : {unroll_k - 2 * k_per_inst}")
        # 1st fma
@@ -1676,7 +2190,7 @@ def do_interleave_share_store():
        # start emit
-       self._emit(f"; start MFMA loop, {cxm.wave_tile_m}x{cxm.wave_tile_n} wave tile with {cxm.wave_repeat_m}x{cxm.wave_repeat_n} repeat, {cxm.wave_step_m}x{cxm.wave_step_n} step")
+       self._emit(f"; start MFMA loop, {cxm.wave_tile_m}x{cxm.wave_tile_n} wave tile with {cxm.wave_repeat_m}x{cxm.wave_repeat_n} repeat, {cxm.wave_step_m}x{cxm.wave_step_n} step, k_pack:{self.ctrl.lds_k_pack}")
        self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})")
        self._emit(f_sst_b())
diff --git a/igemm/algo/shared_memory.py b/igemm/algo/shared_memory.py
index 7b55cffe..fcda13c6 100644
--- a/igemm/algo/shared_memory.py
+++ b/igemm/algo/shared_memory.py
@@ -844,3 +844,86 @@ def get_issues(self):
        with self._deferred_context():
            self.emit()
        return self.issue_cnt
+
+class ctrl_3d_shared_store_t(object):
+    '''
+    d0 x d1 x dp (d pack)
+    '''
+    def __init__(self):
+        self.length_d0 = 1        # if d0 is 1, it is indeed 1d access
+        self.length_d1 = 1
+        self.length_dp = 1
+        self.stride_d0 = 1        # stride
+        self.stride_d1 = 1        # if stride_d1 is set, each d1 may have its own stride
+        self.precision = 'fp32'   # 'fp32', 'fp16', ...
+        self.src_order = 0        # 0-d0,d1, 1-d1,d0
+        self.need_transpose = 1
+        self.v_tmp = None         # used when order is 1 and consider shuffle
+
+    def serialize(self):
+        return f"length_d0:{self.length_d0}, length_d1:{self.length_d1}, length_dp:{self.length_dp}, stride_d0:{self.stride_d0}, stride_d1:{self.stride_d1}, precision:{self.precision}, src_order:{self.src_order}"
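A hedged usage sketch of the control block just defined, ahead of the store macro that consumes it (the mc object and the register helpers are assumptions, not shown in this patch): for the src_order = 0 nhwc case the three lengths describe a per-thread gemm_k x gemm_m/n x k_pack slice, and length_dp is the k_pack factor that becomes the ds_write vector width.

    # illustrative values only, assuming fp32 (data_byte = 4) and k_pack = 4
    ctrl = ctrl_3d_shared_store_t()
    ctrl.length_d0 = 1        # one gemm_k slice per thread -> the 2d fast path
    ctrl.length_d1 = 4        # four gemm_m elements per thread
    ctrl.length_dp = 4        # k_pack -> one 16-byte ds_write per element
    ctrl.stride_d1 = 4 * 4    # k_pack * data_byte: m-neighbors are one packed group apart
    sst_a = macro_igemm_3d_shared_store_t(mc, ctrl, inline = True)   # 'mc' assumed in scope
    # at emission time, something like: self._emit(sst_a(v_gld_a(), v_sst_a_os()))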
+
+class macro_igemm_3d_shared_store_t(macro_base_t):
+    '''
+    this is indeed for
+    0: gemm_k * gemm_m/n * k_pack, src_order = 0
+    1: gemm_m/n * gemm_k * k_pack, src_order = 1 (unsupported)
+    we always want to use k_pack as the vector store width
+    '''
+    def __init__(self, mc, ctrl, inline = False):
+        assert type(ctrl) is ctrl_3d_shared_store_t
+        macro_base_t.__init__(self, mc, inline)
+        self.ctrl = ctrl
+        self.issue_cnt = 0
+        self.declare_arg("v_src")
+        self.declare_arg("v_sst_os")
+    def name(self):
+        ctrl = self.ctrl
+        if ctrl.precision == "fp32":
+            bits_str = 'b32'
+        elif ctrl.precision in ("fp16", "bf16"):
+            bits_str = 'b16'
+        else:
+            assert False
+
+        return f".v_sst_so{ctrl.src_order}_{ctrl.length_d0}x{ctrl.length_d1}x{ctrl.length_dp}_{bits_str}" + \
+                f"_st{ctrl.stride_d0}x{ctrl.stride_d1}"
+
+    def expr(self):
+        ctrl = self.ctrl
+        assert ctrl.precision == 'fp32', "to be supported"
+        data_byte = amdgpu_precision_data_byte(ctrl.precision)
+        issue_cnt = 0
+
+        if ctrl.length_d0 == 1 or ctrl.length_d1 == 1:
+            # this is indeed a 2d case.
+            ds_write = inst_ds_write_t(ctrl.length_dp * data_byte)
+            if ctrl.length_d0 == 1 and ctrl.length_d1 == 1:
+                # further, 1d case
+                self._emit(ds_write(f'{self.v_sst_os()}', f'{self.v_src()}'))
+                issue_cnt += ds_write.get_issues()
+
+            else:
+                length_d = ctrl.length_d0 if ctrl.length_d0 != 1 else ctrl.length_d1
+                stride_d = ctrl.stride_d0 if ctrl.length_d0 != 1 else ctrl.stride_d1
+                if length_d % 2 == 0 and data_byte == 4 and ctrl.length_dp in (1, 2):
+                    ds_write2 = inst_ds_write2_likely_t(self.mc, 2, ctrl.length_dp * data_byte, stride_d)
+                    for i_d in range(length_d // 2):
+                        self._emit(ds_write2(f'{self.v_sst_os()}', f'{self.v_src()}+{2 * i_d*ctrl.length_dp}', 2 * i_d * stride_d))
+                        issue_cnt += ds_write2.get_issues(2 * i_d * stride_d)
+                else:
+                    # nhwc: almost all cases go here
+                    for i_d in range(length_d):
+                        self._emit(ds_write(f'{self.v_sst_os()}', f'{self.v_src()}+{i_d*ctrl.length_dp}', i_d * stride_d))
+                        issue_cnt += ds_write.get_issues()
+        else:
+            assert False, "not implemented yet"
+
+        self.issue_cnt = issue_cnt
+
+    def get_issues(self):
+        # assert False, "to be implemented"
+        # return self.ctrl.length_d0
+        with self._deferred_context():
+            self.emit()
+        return self.issue_cnt
diff --git a/igemm/algo/xdlops_mapping.py b/igemm/algo/xdlops_mapping.py
index 728d0259..f9a5c46b 100755
--- a/igemm/algo/xdlops_mapping.py
+++ b/igemm/algo/xdlops_mapping.py
@@ -282,15 +282,23 @@ def serialize(self):
    ctrl_xdlops_mapping_t( 128, 128, 32, 64, 1, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32),
    ctrl_xdlops_mapping_t( 128, 64 , 32, 8 , 1, 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32),
    ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 4, 2, 1, 1, 1, v_mfma_f32_32x32x2f32),
+   ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 4, 1, 2, 1, 1, v_mfma_f32_32x32x2f32),
    ctrl_xdlops_mapping_t( 64 , 128, 8 , 32, 1, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32),
    ctrl_xdlops_mapping_t( 64 , 128, 32, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32),
    ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32),
    ctrl_xdlops_mapping_t( 64 , 128, 32, 32, 2, 4, 1, 2, 1, 1, v_mfma_f32_32x32x2f32),
+   ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 2, 2, 2, 1, 1, v_mfma_f32_32x32x2f32),
+   ctrl_xdlops_mapping_t( 128, 64 , 64, 32, 1, 2, 1, 2, 1, 1, v_mfma_f32_32x32x1f32),
+   ctrl_xdlops_mapping_t( 64 , 128, 32, 32, 2, 2, 2, 2, 1, 1, v_mfma_f32_32x32x2f32),
+   ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 1, 2, 2, 2, 1, v_mfma_f32_32x32x2f32),
    ctrl_xdlops_mapping_t( 128, 32 , 32, 8 , 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
    ctrl_xdlops_mapping_t( 128, 32 , 16, 16, 4, 4, 2, 2, 1, 1, v_mfma_f32_16x16x4f32),
+   ctrl_xdlops_mapping_t( 128, 32 , 32, 32, 2, 4, 1, 1, 1, 1, v_mfma_f32_32x32x2f32),
+   ctrl_xdlops_mapping_t( 128, 32 , 32, 32, 2, 2, 2, 1, 1, 1, v_mfma_f32_32x32x2f32),
    ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
    ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32),
    ctrl_xdlops_mapping_t( 32 , 128, 16, 16, 4, 4, 2, 2, 1, 1, v_mfma_f32_16x16x4f32),
+   ctrl_xdlops_mapping_t( 32 , 128, 32, 32, 2, 2, 1, 2, 1, 1, v_mfma_f32_32x32x2f32),
    ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32),
    ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 4, 4, 2, 2, 1, 1, v_mfma_f32_16x16x4f32),
    ctrl_xdlops_mapping_t( 64 , 64 , 32, 32, 2, 4, 1, 1, 1, 1, v_mfma_f32_32x32x2f32), # this is not as good as 16x16x4
@@ -302,15 +310,20 @@ def serialize(self):
    ctrl_xdlops_mapping_t( 16 , 128, 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), # need to re-design coalescing,
or do irregular gemm ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 1, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 32 , 16, 16, 4, 4, 2, 1, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 64 , 32 , 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 64 , 48 , 16, 16, 4, 4, 1, 3, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 64 , 32 , 32, 32, 2, 2, 1, 1, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 1, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 64 , 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 4, 4, 1, 1, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 4, 2, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), #ctrl_xdlops_mapping_t( 256, 4 , 64, 4 , 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm #ctrl_xdlops_mapping_t( 4 , 256, 4 , 64, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm ctrl_xdlops_mapping_t( 64 , 16 , 64, 4 , 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 16 , 16, 16, 4, 4, 1, 1, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 64 , 16 , 16, 16, 4, 2, 2, 1, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 64 , 16 , 64, 16, 1, 1, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 16 , 64 , 4 , 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 16 , 64 , 16, 16, 4, 4, 1, 1, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 16 , 64 , 16, 16, 4, 2, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), @@ -369,42 +382,65 @@ def __init__(self, mc, ctrl): mc_base_t.__init__(self, mc) assert type(ctrl) is ctrl_xdlops_mapping_t self.ctrl = ctrl - def get_gemm_index_for_src_matrix(self, v_gemm_in, v_gemm_im, v_thread_id, v_tmp4): + def get_gemm_index_for_src_matrix(self, v_gemm_in, v_gemm_im, v_thread_id, v_tmp4, **options): ''' notice! 
+        def get_dict_with_default(some_dict, key, default_value):
+            if key in some_dict:
+                return some_dict[key]
+            return default_value
         ctrl = self.ctrl
         assert ctrl.block_n() == ctrl.block_m() and ctrl.block_k() * ctrl.block_n() * ctrl.block_n_per_wave() * ctrl.block_m_per_wave() == AMDGPU_WAVE_SIZE
+        k_pack = get_dict_with_default(options, "k_pack", 1)
+        v_pack = get_dict_with_default(options, "v_pack", 1)
+        assert v_pack in (1, k_pack), 'currently v_pack must be 1 or equal to k_pack'
         with self._deferred_context():
-            self._emit(f"; xdlops mapping, get source matrix gemm index")
+            self._emit(f"; xdlops mapping, get source matrix gemm index, k_pack:{k_pack}, v_pack:{v_pack}")
            self._emit(f"v_and_b32 v[{v_gemm_in}], {ctrl.block_n() - 1}, v[{v_thread_id}] ; block_n index ")
            self._emit(f"v_and_b32 v[{v_gemm_im}], {ctrl.block_m() - 1}, v[{v_thread_id}] ; block_m index ")
+            if k_pack != 1:
+                self._emit(f"v_lshlrev_b32 v[{v_gemm_in}], {utility_log2(k_pack)}, v[{v_gemm_in}] ; shift left k_pack:{k_pack}")
+                self._emit(f"v_lshlrev_b32 v[{v_gemm_im}], {utility_log2(k_pack)}, v[{v_gemm_im}] ; shift left k_pack:{k_pack}")
            self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.block_n())}, v[{v_thread_id}]")
            if ctrl.block_k() != 1:
                self._emit(f"v_and_b32 v[{v_tmp4} + 0], {ctrl.block_k() - 1}, v[{v_thread_id}] ; block_k_per_wave index")
-                self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n)}, v[{v_gemm_in}]")
-                self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m)}, v[{v_gemm_im}]")
+                if k_pack != 1:
+                    if v_pack == 1:
+                        self._emit(f"v_and_b32 v[{v_tmp4} + 1], {k_pack - 1}, v[{v_tmp4} + 0] ; and k_pack:{k_pack}")
+                        self._emit(f"v_lshrrev_b32 v[{v_tmp4} + 0], {utility_log2(k_pack)}, v[{v_tmp4} + 0] ; shift right k_pack:{k_pack}")
+                        self._emit(f"v_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 1], v[{v_gemm_in}] ; or k_pack:{k_pack}")
+                        self._emit(f"v_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 1], v[{v_gemm_im}] ; or k_pack:{k_pack}")
+                        self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n * k_pack)}, v[{v_gemm_in}]")
+                        self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m * k_pack)}, v[{v_gemm_im}]")
+                    else:
+                        self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n * k_pack)}, v[{v_gemm_in}]")
+                        self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m * k_pack)}, v[{v_gemm_im}]")
+                else:
+                    self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n)}, v[{v_gemm_in}]")
+                    self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m)}, v[{v_gemm_im}]")
                self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.block_k())}, v[{v_thread_id}]")
            if ctrl.block_n_per_wave() != 1:
                self._emit(f"v_and_b32 v[{v_tmp4} + 0], {ctrl.block_n_per_wave() - 1}, v[{v_thread_id}] ; block_n_per_wave index")
-                self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.block_n())}, v[{v_gemm_in}]")
+                self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0],
{utility_log2(ctrl.block_n() * k_pack)}, v[{v_gemm_in}]") self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.block_n_per_wave())}, v[{v_thread_id}]") if ctrl.block_m_per_wave() != 1: self._emit(f"v_and_b32 v[{v_tmp4} + 1], {ctrl.block_m_per_wave() - 1}, v[{v_thread_id}] ; block_m_per_wave index") - self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 1], {utility_log2(ctrl.block_m())}, v[{v_gemm_im}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 1], {utility_log2(ctrl.block_m() * k_pack)}, v[{v_gemm_im}]") self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.block_m_per_wave())}, v[{v_thread_id}]") if ctrl.waves_per_n() != 1: self._emit(f"v_and_b32 v[{v_tmp4} + 2], {ctrl.waves_per_n() - 1}, v[{v_thread_id}] ; waves_per_n index") - self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 2], {utility_log2(ctrl.wave_tile_n * ctrl.wave_step_n)}, v[{v_gemm_in}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 2], {utility_log2(ctrl.wave_tile_n * ctrl.wave_step_n * k_pack)}, v[{v_gemm_in}]") self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.waves_per_n())}, v[{v_thread_id}]") if ctrl.waves_per_m() != 1: self._emit(f"v_and_b32 v[{v_tmp4} + 3], {ctrl.waves_per_m() - 1}, v[{v_thread_id}] ; waves_per_m index") - self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 3], {utility_log2(ctrl.wave_tile_m * ctrl.wave_step_m)}, v[{v_gemm_im}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 3], {utility_log2(ctrl.wave_tile_m * ctrl.wave_step_m * k_pack)}, v[{v_gemm_im}]") # self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.waves_per_n())}, v[{v_thread_id}]") self._emit_empty_line() return self._get_deferred() diff --git a/igemm/codegen/macro.py b/igemm/codegen/macro.py index 81163c2c..ec057c9c 100644 --- a/igemm/codegen/macro.py +++ b/igemm/codegen/macro.py @@ -36,6 +36,7 @@ def __init__(self, mc, inline = False): mc_base_t.__init__(self, mc) self.arg_list = list() self.inline = inline + self.expr_cnt = 0 def name(self): return 'n/a macro' def is_inline(self): @@ -71,10 +72,13 @@ def __call__(self, *args): setattr(self, self.arg_list[i], sym_t(args[i])) elif type(args[i]) is sym_t: setattr(self, self.arg_list[i], args[i]) + elif args[i] is None: + setattr(self, self.arg_list[i], None) # 2nd, do the emit with self._deferred_context(): self.expr() + self.expr_cnt += 1 # last, restore arg to default value. for a in self.arg_list: @@ -90,3 +94,4 @@ def emit(self): if not self.is_inline(): with self._emit_macro_indented(".macro {} {}".format(self.name(), ' '.join(self.arg_list))): self.expr() + self.expr_cnt += 1 diff --git a/igemm/codegen/mbb.py b/igemm/codegen/mbb.py index a3f83473..bedf58d0 100644 --- a/igemm/codegen/mbb.py +++ b/igemm/codegen/mbb.py @@ -74,6 +74,8 @@ def get_mc_inst_type(inst_str): return MC_INST_TYPE_SHARE_MEM if mc_inst_is_global_mem(inst_op): return MC_INST_TYPE_GLOBAL_MEM + if mc_inst_is_legacy_macro(inst_op): + return MC_INST_TYPE_LEGACY_MACRO return MC_INST_TYPE_OTHER class mc_inst_t(object): @@ -133,6 +135,17 @@ def dump(self): print(self()) print('-----------------------------------') +def machine_basic_block_call(p, mbb): + ''' + to pretty print mbb, the indent + currently p can not be mc_base_t directly. 
+    it must be some child class that provides _deferred_context
+    '''
+    mbb_lines = mbb().split('\n')
+    with p._deferred_context():
+        for line in mbb_lines:
+            p._emit(line)
+    return p._get_deferred()
+
 def create_machine_basic_block(multi_line_inst_str, **option):
     ''' a post-hoc analysis and construction of mbb, based only on string parsing.
@@ -145,6 +158,7 @@ def create_machine_basic_block(multi_line_inst_str, **option):
     option:
         group_mbb_by_end_of_inst_op : str, group several mc_inst into one mbb; each mbb ends with this inst op
+        merge_mbb : int, if set, do not split into multiple mbb
     '''
     class parse_mbb_list_t(object):
         STATE_NORMAL = 0
@@ -183,6 +197,25 @@ def is_mbb_start_cmp_and_exec_block(self, current_index, istrs_list):
                         return True
                     return False
             return False
+
+        def is_mbb_start_bfe_and_cmpx_block(self, current_index, istrs_list):
+            assert type(istrs_list) is list
+            current_istr = istrs_list[current_index]
+            current_inst_op = get_mc_inst_op(current_istr)
+            if current_inst_op.startswith('v_bfe_u32'):
+                for next_index in range(current_index+1, len(istrs_list)):
+                    next_istr = istrs_list[next_index]
+                    next_mc_inst = create_mc_inst(next_istr)
+                    next_inst_op = get_mc_inst_op(next_istr)
+                    if not next_mc_inst:
+                        continue
+                    if next_inst_op.startswith('v_cmp'):
+                        return True
+                return False
+            return False

         def is_mbb_start(self, istr):
             _istr = istr.strip()
@@ -206,6 +239,8 @@ def get_dict_with_default(dictionary, key, default_value):
             return dictionary[key]
         else:
             return default_value
+
+    merge_mbb = get_dict_with_default(option, "merge_mbb", 0)
     group_mbb_by_end_of_inst_op = get_dict_with_default(option, "group_mbb_by_end_of_inst_op", "")

     def match_group_mbb_by_end_of_inst_op(inst_op):
@@ -228,11 +263,17 @@ def match_group_mbb_by_end_of_inst_op(inst_op):
             if len(istrs) == 0:
                 return None
+
             for i, istr in enumerate(istrs):
                 mc_inst = create_mc_inst(istr)
                 if not mc_inst:
                     continue
+
+                # merge every instruction into a single mbb
+                if merge_mbb:
+                    mc_inst_buffer.append(mc_inst)
+                    continue
+
                 # early pass rule
                 if self.is_mbb_start_macro_c_clear(i, istrs):
                     '''
@@ -265,7 +306,8 @@ def match_group_mbb_by_end_of_inst_op(inst_op):
                         mc_inst_buffer.append(mc_inst)
                 else:
                     if state == self.STATE_NORMAL:
-                        if self.is_mbb_start(istr) or self.is_mbb_start_cmp_and_exec_block(i, istrs):
+                        if self.is_mbb_start(istr) or self.is_mbb_start_cmp_and_exec_block(i, istrs) \
+                                or self.is_mbb_start_bfe_and_cmpx_block(i, istrs):
                             mc_inst_buffer.clear()
                             mc_inst_buffer.append(mc_inst)
                             state = self.STATE_PARSING_MBB
@@ -273,7 +315,12 @@ def match_group_mbb_by_end_of_inst_op(inst_op):
                         else:
                             mbbs.append(machine_basic_block_t(copy.copy([mc_inst])))
                     else:
                         if self.is_mbb_start(istr):
-                            assert False, f'not support recursive start/end for now, with {i}:{istr}, {istrs}'
+                            assert i > 1
+                            if self.is_mbb_start_bfe_and_cmpx_block(i - 1, istrs):
+                                # TODO: this requires that bfe and cmpx have no other instructions in between
+                                pass
+                            else:
+                                assert False, f'recursive start/end not supported for now, with {i}:{istr}, {istrs}'
                         if self.is_mbb_end(istr):
                             mc_inst_buffer.append(mc_inst)
                             mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer)))
diff --git a/igemm/codegen/mc.py b/igemm/codegen/mc.py
index fcd9b36c..8fd6c396 100644
--- a/igemm/codegen/mc.py
+++ b/igemm/codegen/mc.py
@@ -29,6 +29,11 @@
 from copy import deepcopy
 import subprocess

+# NOTE: if either of the following is set to True, pass '-V 0' to conv_driver,
+# since the computed result can never be correct
+MC_DEBUG_IGNORE_LDS_IO = False
+MC_DEBUG_IGNORE_GLOBAL_IO = False
+
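These two switches let the emitter drop whole classes of instructions so a kernel can be timed with its LDS or global-memory traffic removed (hence the NOTE: the result can never be correct, so verification must be skipped). A minimal standalone sketch of the filtering rule that the emit() change below implements — the helper name here is illustrative only, not part of the diff:

    def _should_emit(line, ignore_lds=False, ignore_global=False):
        # drop LDS traffic (ds_read/ds_write/s_barrier) and/or global traffic
        # (buffer_load and the vmcnt waits that pair with it)
        ignore_list = []
        if ignore_lds:
            ignore_list += ['ds_read', 'ds_write', 's_barrier']
        if ignore_global:
            ignore_list += ['buffer_load', 's_waitcnt vmcnt']
        return not any(line.strip().startswith(op) for op in ignore_list)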
class mc_get_version_t(object): def __init__(self): self.called = 0 @@ -147,7 +152,24 @@ def __del__(self): self.close() def emit(self, s): if self.f: - self.f.write(self.indent() + s + '\n') + if MC_DEBUG_IGNORE_LDS_IO or MC_DEBUG_IGNORE_GLOBAL_IO: + s2 = s.split('\n') + ignore_list = list() + if MC_DEBUG_IGNORE_LDS_IO: + ignore_list.extend(['ds_read', 'ds_write', 's_barrier']) + # ignore_list.extend(['ds_write']) + if MC_DEBUG_IGNORE_GLOBAL_IO: + ignore_list.extend(['buffer_load', 's_waitcnt vmcnt']) + for iss, ss in enumerate(s2): + need_emit = True + for i in ignore_list: + if ss.strip().startswith(i): + need_emit = False + break + if need_emit: + self.f.write((self.indent() if iss == 0 else '') + ss + '\n') + else: + self.f.write(self.indent() + s + '\n') def emit_license(self): ''' diff --git a/igemm/codegen/scheduler.py b/igemm/codegen/scheduler.py index a06790b0..7e588b2f 100644 --- a/igemm/codegen/scheduler.py +++ b/igemm/codegen/scheduler.py @@ -96,7 +96,7 @@ def mbb_is_macro_c_clear(mbb): ''' if mbb.length() == 1: if mbb.mc_inst().type() == MC_INST_TYPE_LEGACY_MACRO: - if get_mc_inst_op(mbb.mc_inst()).startswith('.v_clear_nc'): + if get_mc_inst_op(mbb.mc_inst().inst_str).startswith('.v_clear_nc'): return True return False @@ -138,7 +138,7 @@ def mbb_is_macro_c_clear(mbb): #else: # break assert num_gmem != 0, f"no global mem in this instructino list, please check" - assert num_v_c_clear in (0, 1) + # assert num_v_c_clear in (0, 1) num_gmem += num_v_c_clear # second decide how many global mem to interleave per interval diff --git a/igemm/igemm_codegen_driver.py b/igemm/igemm_codegen_driver.py index 2d1d93c6..5cb8db11 100755 --- a/igemm/igemm_codegen_driver.py +++ b/igemm/igemm_codegen_driver.py @@ -49,7 +49,10 @@ def __init__(self, mc, tunable_dicts): for tdd in tunable_dicts: assert tdd['direction'] == 'fwd' # gtc fwd - kernel_list.extend([igemm_fwd_gtc_t(mc_asm_printer_t(mc.emitter, mc.arch_config), igemm_gtc_tunable_parameter_t(td)) for td in tunable_dicts]) + if 'tensor_layout' in tunable_dicts[0] and tunable_dicts[0]['tensor_layout'] == 'nhwc': + kernel_list.extend([igemm_fwd_gtc_nhwc_t(mc_asm_printer_t(mc.emitter, mc.arch_config), igemm_gtc_tunable_parameter_t(td)) for td in tunable_dicts]) + else: + kernel_list.extend([igemm_fwd_gtc_t(mc_asm_printer_t(mc.emitter, mc.arch_config), igemm_gtc_tunable_parameter_t(td)) for td in tunable_dicts]) elif tunable_dicts[0]['direction'] == 'bwd': for tdd in tunable_dicts: diff --git a/retieve_perf_data.py b/retieve_perf_data.py new file mode 100644 index 00000000..3b301470 --- /dev/null +++ b/retieve_perf_data.py @@ -0,0 +1,77 @@ +import re + +def get_log_lines(log_file_name): + with open(log_file_name, 'r') as f: + log_lines = f.readlines() + return log_lines + +def get_runtime(log_file_name): + txt_lines = get_log_lines(log_file_name) + min_cost = [] + min_kernel = [] + sel_kernel = [] + sel_costs = [] + + for each_line in txt_lines: + res_str = re.search(r'(?<=fastest:).*', each_line) + if res_str: + res_str = re.search(r'(?<=tflops:)\d+\.?\d*', each_line) + if res_str: + driver_cost = float(res_str.group()) + print(f"driver_cost={driver_cost}") + min_cost.append(driver_cost) + res_str = re.search(r'kernel_name:', each_line) + if res_str: + min_kernel.append(each_line.split(":")[-1][:-1]) + print(each_line.split(":")[-1][:-1]) + res_str = re.search(r'selected kernel:', each_line) + if res_str: + sel_kernel.append(each_line.split(":")[-1][:-1]) + print(each_line.split(":")[-1][:-1]) + + res_str = re.search(r'(?<=selected 
cost:)\d+\.?\d*', each_line) + if res_str: + sel_cost = float(res_str.group()) + print(f"driver_cost={sel_cost}") + sel_costs.append(sel_cost) + + with open("./wrw_model.csv", "w") as f: + for num_cost in min_cost: + f.write(f"{num_cost}") + f.write("\n") + for sel_cost in sel_costs: + f.write(f"{sel_cost}") + f.write("\n") + for kernel_name in min_kernel: + f.write(f"{kernel_name}") + f.write("\n") + for s_kernel_name in sel_kernel: + f.write(f"{s_kernel_name}") + f.write("\n") + +def get_kernel_name(log_file_name): + txt_lines = get_log_lines(log_file_name) + section_lines = [] + store_line = 0 + kernel_names = [] + for each_line in txt_lines: + conv_line = re.match(r"(?!#)./out/conv_driver.exe conv", each_line) + if store_line: + kernel_name = re.match(r"kernel:", each_line) + if kernel_name: + name = each_line.split(':')[-1] + print(f"\t\"{name[:-1]}\",") + kernel_names.append(name) + if conv_line: + store_line = 1 + conv_end_line = re.match(r"min cost:", each_line) + if conv_end_line: + break + + return kernel_names + + +if __name__ == '__main__': + #names = get_kernel_name("./wrw_model.log") + #print(len(names)) + get_runtime("./fwd_fp32_nhwc_models.log") diff --git a/script/gtc_conv_model.sh b/script/gtc_conv_model.sh index b2a7ffba..95e272f8 100755 --- a/script/gtc_conv_model.sh +++ b/script/gtc_conv_model.sh @@ -1,14 +1,33 @@ #!/bin/sh -if [ $# -ne 1 ] +if [ $# -lt 1 ] then - echo "please give this script a direction" - echo "now I use bwd as default" - DIR=bwd + DIR=bwd else DIR=$1 fi -export IGEMM_HSACO=out/igemm_${DIR}_gtc_gfx908.hsaco + +if [ $# -eq 2 ] +then + LAYOUT=$2 +else + LAYOUT="nchw" +fi + +if [ "${LAYOUT}" = "nchw" ] +then + LAYOUT_HSACO="" + LAYOUT_ARG="" +elif [ "${LAYOUT}" = "nhwc" ] +then + LAYOUT_HSACO="_nhwc" + LAYOUT_ARG="--in_layout NHWC --fil_layout NHWC --out_layout NHWC" +else + echo "wrong layout: ${LAYOUT}" + exit 1 +fi +echo IGEMM_HSACO=out/igemm_${DIR}_gtc_gfx908${LAYOUT_HSACO}.hsaco +export IGEMM_HSACO=out/igemm_${DIR}_gtc_gfx908${LAYOUT_HSACO}.hsaco export IGEMM_GPU_NAIVE_CONV_HSACO=out/naive_conv.hsaco export IGEMM_SCLK_MHZ=1283 export IGEMM_LOG_FASTEST_CONFIG=1 @@ -29,7 +48,8 @@ else fi # only forward support gemm_k_padding -if [ $FORW = 1 ] +#if [ $FORW = 1 ] +if [ 0 = 1 ] then ./out/conv_driver.exe conv -n 64 -c 3 -H 224 -W 224 -k 64 -y 7 -x 7 -p 3 -q 3 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ./out/conv_driver.exe conv -n 128 -c 3 -H 299 -W 299 -k 32 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW @@ -44,172 +64,171 @@ then ./out/conv_driver.exe conv -n 64 -c 512 -H 14 -W 14 -k 512 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 32 -F $FORW ./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 512 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 32 -F $FORW #exit 1 - fi #resnext101 -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 2048 -H 7 -W 7 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 
-F $FORW -# ./out/conv_driver.exe conv -n 64 -c 3 -H 224 -W 224 -k 64 -y 7 -x 7 -p 3 -q 3 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 2048 -H 7 -W 7 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +# ./out/conv_driver.exe conv -n 64 -c 3 -H 224 -W 224 -k 64 -y 7 -x 7 -p 3 -q 3 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} #inception4 batch_size=128 -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 320 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 448 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 160 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 160 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 
-v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 320 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 32 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 320 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 448 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 256 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 256 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -# ./out/conv_driver.exe conv -n 128 -c 3 -H 299 -W 299 -k 32 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 32 -H 147 -W 147 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 32 -H 149 -W 149 -k 32 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 384 -H 8 -W 8 -k 384 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 384 -H 8 -W 8 -k 384 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 448 -H 8 -W 8 -k 384 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 48 -H 35 -W 35 -k 64 -y 5 -x 5 -p 2 -q 2 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 64 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 64 -H 73 -W 73 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 160 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 80 -H 73 -W 73 -k 192 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW +./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} 
+./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 320 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 448 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 160 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 160 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 320 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 32 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 320 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 448 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 256 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 256 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +# ./out/conv_driver.exe conv -n 128 -c 3 -H 299 -W 299 -k 32 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 32 -H 147 -W 147 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 32 -H 149 -W 149 -k 32 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 
-j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 384 -H 8 -W 8 -k 384 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 384 -H 8 -W 8 -k 384 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 448 -H 8 -W 8 -k 384 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 48 -H 35 -W 35 -k 64 -y 5 -x 5 -p 2 -q 2 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 64 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 64 -H 73 -W 73 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 160 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 80 -H 73 -W 73 -k 192 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} #inception3 batch_size=64 -./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 160 -H 73 -W 73 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 224 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 224 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 35 -W 35 -k 224 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 71 -W 71 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 224 -H 17 -W 17 -k 224 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 224 -H 17 -W 17 -k 256 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 224 -H 35 -W 35 -k 256 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 17 -W 17 -k 256 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 17 -W 17 -k 320 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -# ./out/conv_driver.exe conv -n 64 -c 3 -H 299 -W 299 -k 32 -y 3 
-x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 32 -H 147 -W 147 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 32 -H 149 -W 149 -k 32 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 320 -H 17 -W 17 -k 320 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 96 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 256 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 256 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 448 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 448 -H 8 -W 8 -k 512 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 8 -W 8 -k 256 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 8 -W 8 -k 256 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 147 -W 147 -k 96 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 64 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 64 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 96 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW +./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 160 -H 73 -W 73 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 224 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 224 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 35 -W 35 -k 224 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW 
${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 71 -W 71 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 224 -H 17 -W 17 -k 224 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 224 -H 17 -W 17 -k 256 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 224 -H 35 -W 35 -k 256 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 17 -W 17 -k 256 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 17 -W 17 -k 320 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +# ./out/conv_driver.exe conv -n 64 -c 3 -H 299 -W 299 -k 32 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 32 -H 147 -W 147 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 32 -H 149 -W 149 -k 32 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 320 -H 17 -W 17 -k 320 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 96 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 256 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 256 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 448 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 448 -H 8 -W 8 -k 512 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 8 -W 8 -k 256 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 8 -W 8 -k 256 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 147 -W 147 -k 96 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 64 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 64 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 96 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} #resnet50 -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe 
conv -n 64 -c 128 -H 28 -W 28 -k 128 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 128 -H 28 -W 28 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 2048 -H 7 -W 7 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 14 -W 14 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 14 -W 14 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 128 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -# ./out/conv_driver.exe conv -n 64 -c 3 -H 230 -W 230 -k 64 -y 7 -x 7 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 256 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 7 -W 7 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 7 -W 7 -k 512 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 128 -H 28 -W 28 -k 128 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 128 -H 28 -W 28 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 2048 -H 7 -W 7 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 14 -W 14 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 14 -W 14 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 128 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +# ./out/conv_driver.exe conv -n 64 -c 3 -H 230 -W 230 -k 64 -y 7 -x 7 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} 
+./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 256 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 7 -W 7 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 7 -W 7 -k 512 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} #from v4r1_origin_conv.sh -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1024 -H 17 -W 17 -k 1024 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 34 -W 34 -k 256 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 35 -W 35 -k 128 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 512 -H 14 -W 14 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 256 -H 28 -W 28 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 528 -H 14 -W 14 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 528 -H 14 -W 14 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 48 -H 7 -W 7 -k 128 -y 5 -x 5 -p 2 -q 2 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW +#./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 1024 -H 17 -W 17 -k 1024 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 64 -c 256 -H 34 -W 34 -k 256 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 128 -H 35 -W 35 -k 128 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW 
${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 512 -H 14 -W 14 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 256 -H 28 -W 28 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 528 -H 14 -W 14 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 528 -H 14 -W 14 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 48 -H 7 -W 7 -k 128 -y 5 -x 5 -p 2 -q 2 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} #mask rcnn -./out/conv_driver.exe conv -n 2 -c 256 -H 12 -W 18 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 1024 -H 34 -W 84 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 1024 -H 40 -W 52 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 256 -H 100 -W 104 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 256 -H 10 -W 20 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 64 -H 71 -W 83 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 64 -H 59 -W 57 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 4 -c 256 -H 14 -W 14 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 4 -c 256 -H 28 -W 28 -k 256 -y 2 -x 2 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 3 -c 256 -H 28 -W 28 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 1 -c 256 -H 32 -W 64 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 1 -c 64 -H 17 -W 17 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 256 -H 12 -W 18 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 1024 -H 34 -W 84 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 1024 -H 40 -W 52 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 
-v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 256 -H 100 -W 104 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 256 -H 10 -W 20 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 64 -H 71 -W 83 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 64 -H 59 -W 57 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 4 -c 256 -H 14 -W 14 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 4 -c 256 -H 28 -W 28 -k 256 -y 2 -x 2 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 3 -c 256 -H 28 -W 28 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 1 -c 256 -H 32 -W 64 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 1 -c 64 -H 17 -W 17 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW #retina net bs=16 -./out/conv_driver.exe conv -n 16 -c 256 -H 12 -W 12 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 16 -c 256 -H 134 -W 77 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 16 -c 256 -H 71 -W 101 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 16 -c 256 -H 12 -W 12 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 16 -c 256 -H 134 -W 77 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 16 -c 256 -H 71 -W 101 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW diff --git a/test/inference/build.sh b/test/inference/build.sh new file mode 100644 index 00000000..11fa1612 --- /dev/null +++ b/test/inference/build.sh @@ -0,0 +1,14 @@ +#!/bin/sh +# to launch from top of generator +ARCH=gfx1030 +rm -rf out +mkdir out + +/opt/rocm/hip/bin/hipcc --amdgpu-target=$ARCH -Idriver -std=c++14 -lpthread test/inference/test_inference.cpp -o out/test_inference.exe || exit 1 +/opt/rocm/llvm/bin/clang++ -x assembler -target amdgcn--amdhsa -mcpu=$ARCH -mcumode -Itest/inference/kernel/fp16/ test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm -o out/igemm_fwd_btm_nhwc_fp16.hsaco || exit 1 +/opt/rocm/llvm/bin/clang++ -x assembler -target amdgcn--amdhsa -mcpu=$ARCH -mcumode -Itest/inference/kernel/int8/ test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8.asm -o out/igemm_fwd_btm_nhwc_int8.hsaco || exit 1 +/opt/rocm/hip/bin/hipcc -x hip --cuda-gpu-arch=$ARCH --cuda-device-only -c -O3 driver/gpu_naive_conv/naive_conv.cpp -o out/naive_conv.hsaco + + + + diff --git a/test/inference/igemm_fwd_btm_nhwc.h b/test/inference/igemm_fwd_btm_nhwc.h new file mode 100644 index 00000000..6707f476 --- /dev/null +++ b/test/inference/igemm_fwd_btm_nhwc.h @@ -0,0 +1,341 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020-2021 Advanced Micro Devices, Inc. 
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#ifndef __DIRECT_CONV_DRIVER_H
+#define __DIRECT_CONV_DRIVER_H
+
+#include <hip/hip_runtime.h>
+#include <hip/hip_ext.h>
+#include <iostream>
+#include <string>
+#include <vector>
+#include <numeric>
+#include <algorithm>
+#include <unistd.h>
+
+#ifndef HIP_CALL
+#define HIP_CALL(call)                                                         \
+    do {                                                                       \
+        hipError_t err = call;                                                 \
+        if (err != hipSuccess) {                                               \
+            printf("[hiperror](%d) fail to call %s,(%s)", (int)err, #call,     \
+                   hipGetErrorString(err));                                    \
+            exit(1);                                                           \
+        }                                                                      \
+    } while (0)
+#endif
+
+static inline size_t gpu_conv_out_size(size_t in_size, size_t pad,
+                                       size_t dilation, size_t ksize,
+                                       size_t stride) {
+    return (in_size + 2 * pad - dilation * (ksize - 1) - 1) / stride + 1;
+}
+
+typedef struct {
+    void * p_in;
+    void * p_wei;
+    void * p_out;
+    uint32_t hi;
+    uint32_t wi;
+    uint32_t n;
+    uint32_t k_per_group;
+    uint32_t c_per_group;
+    uint32_t ho;
+    uint32_t wo;
+    uint32_t sy;
+    uint32_t sx;
+    uint32_t dy;
+    uint32_t dx;
+    uint32_t py;
+    uint32_t px;
+    uint32_t fy;
+    uint32_t fx;
+    uint32_t group;
+    uint32_t batch_m;
+    uint32_t stride_m;
+    uint32_t magic_0;
+    uint32_t magic_1;
+    uint32_t magic_2;
+    uint32_t shift_pack_0;
+} __attribute__((packed)) igemm_fwd_btm_2d_karg_t;
+
+static inline void dump_igemm_fwd_btm_2d_karg(igemm_fwd_btm_2d_karg_t * karg)
+{
+    std::cout<<"p_in:"<<karg->p_in<<", ";
+    std::cout<<"p_wei:"<<karg->p_wei<<", ";
+    std::cout<<"p_out:"<<karg->p_out<<", ";
+    std::cout<<"hi:"<<karg->hi<<", ";
+    std::cout<<"wi:"<<karg->wi<<", ";
+    std::cout<<"n:"<<karg->n<<", ";
+    std::cout<<"k_per_group:"<<karg->k_per_group<<", ";
+    std::cout<<"c_per_group:"<<karg->c_per_group<<", ";
+    std::cout<<"ho:"<<karg->ho<<", ";
+    std::cout<<"wo:"<<karg->wo<<", ";
+    std::cout<<"sy:"<<karg->sy<<", ";
+    std::cout<<"sx:"<<karg->sx<<", ";
+    std::cout<<"dy:"<<karg->dy<<", ";
+    std::cout<<"dx:"<<karg->dx<<", ";
+    std::cout<<"py:"<<karg->py<<", ";
+    std::cout<<"px:"<<karg->px<<", ";
+    std::cout<<"fy:"<<karg->fy<<", ";
+    std::cout<<"fx:"<<karg->fx<<", ";
+    std::cout<<"group:"<<karg->group<<", ";
+    std::cout<<"batch_m:"<<karg->batch_m<<", ";
+    std::cout<<"stride_m:"<<karg->stride_m<<", ";
+    std::cout<<"magic_0:"<<karg->magic_0<<", ";
+    std::cout<<"magic_1:"<<karg->magic_1<<", ";
+    std::cout<<"magic_2:"<<karg->magic_2<<", ";
+    std::cout<<"shift_pack_0:"<<karg->shift_pack_0<<std::endl;
+}
+
+typedef struct {
+    std::string kernel_name;
+    uint32_t m_per_block;
+    uint32_t n_per_block;
+    uint32_t k_per_block;
+    uint32_t block_size;
+    uint32_t occupancy;
+} igemm_fwd_btm_kernel_info_t;
+
+class igemm_fwd_btm_t {
+public:
+    int num_cu;
+    igemm_fwd_btm_t(){
+        hipDeviceProp_t dev_prop;
+        int dev;
+        HIP_CALL(hipGetDevice(&dev));
+        HIP_CALL(hipGetDeviceProperties(&dev_prop, dev));
+        num_cu = dev_prop.multiProcessorCount;
+        if(dev_prop.gcnArch >= 1000)
+            num_cu *= 2;
+    }
+    ~igemm_fwd_btm_t(){}
+    std::string get_kernel_name(const igemm_fwd_btm_kernel_info_t *kernel_info) {
+        return kernel_info->kernel_name;
+    }
+
+    bool is_valid(const args_t *arg, igemm_fwd_btm_kernel_info_t * kernel_info)
+    {
+        size_t hi = arg->get_int("in_h");
+        size_t wi = arg->get_int("in_w");
+        size_t n  = arg->get_int("batchsize");
+        size_t k  = arg->get_int("out_channels");
+        size_t c  = arg->get_int("in_channels");
+
+        size_t sy = arg->get_int("conv_stride_h");
+        size_t sx = arg->get_int("conv_stride_w");
+        size_t dy = arg->get_int("dilation_h");
+        size_t dx = arg->get_int("dilation_w");
+        size_t py = arg->get_int("pad_h");
+        size_t px = arg->get_int("pad_w");
+        size_t fy = arg->get_int("fil_h");
+        size_t fx = arg->get_int("fil_w");
+        size_t ho = gpu_conv_out_size(hi, py, dy, fy, sy);
+        size_t wo = gpu_conv_out_size(wi, px, dx, fx, sx);
+        size_t group = arg->get_int("group_count");
+
+        assert(group != 0 && c % group == 0 && k % group == 0);
+
+        size_t k_per_group = k / group;
+        size_t c_per_group = c / group;
+
+        if(c_per_group != kernel_info->k_per_block)
+            return false;
+
+        if(k_per_group % kernel_info->n_per_block != 0)
+            return false;
+
+        return true;
+    }
+
+    result_t run(const args_t *arg, hipModule_t module, igemm_fwd_btm_kernel_info_t * kernel_info,
+                 void *p_in, void *p_wei, void *p_out,
+                 int warmup, int repeat, const driverDataType_t& data_type) {
+        if(!is_valid(arg, kernel_info)){
+            result_t result;
+            result.return_code = -1;
+            return result;
+        }
+        size_t hi = arg->get_int("in_h");
+        size_t wi = arg->get_int("in_w");
+        size_t n  = arg->get_int("batchsize");
+        size_t k  = arg->get_int("out_channels");
+        size_t c  = arg->get_int("in_channels");
+
+        size_t sy = arg->get_int("conv_stride_h");
+        size_t sx = arg->get_int("conv_stride_w");
+        size_t dy = arg->get_int("dilation_h");
+        size_t dx = arg->get_int("dilation_w");
+        size_t py = arg->get_int("pad_h");
+        size_t px = arg->get_int("pad_w");
+        size_t fy = arg->get_int("fil_h");
+        size_t fx = arg->get_int("fil_w");
+        size_t ho = gpu_conv_out_size(hi, py, dy, fy, sy);
+        size_t wo = gpu_conv_out_size(wi, px, dx, fx, sx);
+        size_t group = arg->get_int("group_count");
+
+        assert(group != 0 && c % group == 0 && k % group == 0);
+
+        size_t k_per_group = k / group;
+        size_t c_per_group = c / group;
+        igemm_fwd_btm_2d_karg_t karg;
+        karg.p_in  = p_in;
+        karg.p_wei = p_wei;
+        karg.p_out = p_out;
+        karg.hi = static_cast<uint32_t>(hi);
+        karg.wi = static_cast<uint32_t>(wi);
+        karg.n  = static_cast<uint32_t>(n);
+        karg.k_per_group = static_cast<uint32_t>(k_per_group);
+        karg.c_per_group = static_cast<uint32_t>(c_per_group);
+        karg.ho = static_cast<uint32_t>(ho);
+        karg.wo = static_cast<uint32_t>(wo);
+        karg.sy = static_cast<uint32_t>(sy);
+        karg.sx = static_cast<uint32_t>(sx);
+        karg.dy = static_cast<uint32_t>(dy);
+        karg.dx = static_cast<uint32_t>(dx);
+        karg.py = static_cast<uint32_t>(py);
+        karg.px = static_cast<uint32_t>(px);
+        karg.fy = static_cast<uint32_t>(fy);
+        karg.fx = static_cast<uint32_t>(fx);
+        karg.group = static_cast<uint32_t>(group);
+        size_t karg_size = sizeof(karg);
+
+        hipFunction_t kernel_func;
+        HIP_CALL(hipModuleGetFunction(&kernel_func, module, kernel_info->kernel_name.c_str()));
+
+        int block_size = kernel_info->block_size;
+        int num_gemm_m = (ho * wo + kernel_info->m_per_block - 1) / kernel_info->m_per_block;
+        int num_gemm_n = (k_per_group + kernel_info->n_per_block - 1) / kernel_info->n_per_block;
+
+        int grid_size = kernel_info->occupancy * num_cu;
+        grid_size = env_get_int("GRID_SIZE", grid_size);
+        if(grid_size % num_gemm_n == 0){
+            int grids_for_m = grid_size / num_gemm_n;
+            karg.batch_m  = (num_gemm_m + grids_for_m - 1) / grids_for_m;
+            karg.stride_m = kernel_info->m_per_block * grids_for_m;
+        }else{
+            grid_size = num_gemm_m * num_gemm_n;
+            karg.batch_m  = 1;
+            karg.stride_m = 0;
+        }
+
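+        // The three magic/shift pairs below replace runtime u32 division by fx,
+        // wo and num_gemm_n inside the kernel: for a divisor d, magic_div_u32_gen
+        // picks (magic, shift) such that n / d == (mul_hi(n, magic) + n) >> shift,
+        // which is exactly the sequence the .mdiv_u32_* macros in the fp16 assembly
+        // compute; shift_pack_0 packs the three shift amounts (8 bits each) into
+        // a single dword kernel argument.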
+        magic_div_u32_t mdiv_0 = magic_div_u32_gen(fx);
+        magic_div_u32_t mdiv_1 = magic_div_u32_gen(wo);
+        magic_div_u32_t mdiv_2 = magic_div_u32_gen(num_gemm_n);
+        karg.magic_0 = mdiv_0.magic;
+        karg.magic_1 = mdiv_1.magic;
+        karg.magic_2 = mdiv_2.magic;
+        karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, 0);
+
+        // printf("launch fwd block:%d, grid:%d\n", block_size, grid_size);
+        // dump_igemm_fwd_btm_2d_karg(&karg);
+
+        auto launch_fwd = [&]() -> float {
+            void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg,
+                              HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size,
+                              HIP_LAUNCH_PARAM_END};
+            float ms = 0.f;
+
+            hipEvent_t start;
+            hipEvent_t stop;
+            hipEventCreate(&start);
+            hipEventCreate(&stop);
+
+            // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem
+            HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, n, group,
+                                              block_size, 1, 1, 0, 0, NULL,
+                                              (void **)&config, start, stop));
+
+            hipEventSynchronize(stop);
+            hipEventElapsedTime(&ms, start, stop);
+            hipEventDestroy(start);
+            hipEventDestroy(stop);
+
+            return ms;
+        };
+
+        for (int i = 0; i < warmup; i++) {
+            launch_fwd();
+        }
+
+        std::vector<float> duration_list;
+        for (int i = 0; i < repeat; i++) {
+            float d = launch_fwd();
+            duration_list.push_back(d);
+        }
+
+        // remove min and max from list, then average the rest
+        auto imin = std::min_element(begin(duration_list), end(duration_list));
+        duration_list.erase(imin);
+        auto imax = std::max_element(begin(duration_list), end(duration_list));
+        duration_list.erase(imax);
+        assert(duration_list.size() == static_cast<size_t>(repeat - 2));
+        float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), 0.f) / duration_list.size();
+
+        usleep(1000 * 1);
+
+        result_t result;
+        result.return_code = 0;
+        result.duration_ms = avg_duration;
+        result.kernel_name = kernel_info->kernel_name;
+        return result;
+    }
+};
+
+#endif
diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm
new file mode 100644
index 00000000..ff8d2218
--- /dev/null
+++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm
@@ -0,0 +1,493 @@
+; pay attention to register bank of v_c, v_b
+.macro .fma_1x16_fp16 v_c, v_a, v_b
+    v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ]
+    v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ]
+    v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ]
+    v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ]
+    v_dot2c_f32_f16 v[\v_c+4 ], v[\v_a], v[\v_b+4 ]
+    v_dot2c_f32_f16 v[\v_c+5 ], v[\v_a], v[\v_b+5 ]
+    v_dot2c_f32_f16 v[\v_c+6 ], v[\v_a], v[\v_b+6 ]
+    v_dot2c_f32_f16 v[\v_c+7 ], v[\v_a], v[\v_b+7 ]
+    v_dot2c_f32_f16 v[\v_c+8 ], v[\v_a], v[\v_b+8 ]
+    v_dot2c_f32_f16 v[\v_c+9 ], v[\v_a], v[\v_b+9 ]
+    v_dot2c_f32_f16 v[\v_c+10], v[\v_a], v[\v_b+10]
+    v_dot2c_f32_f16 v[\v_c+11], v[\v_a], v[\v_b+11]
+    v_dot2c_f32_f16 v[\v_c+12], v[\v_a], v[\v_b+12]
+    v_dot2c_f32_f16 v[\v_c+13], v[\v_a], v[\v_b+13]
+    v_dot2c_f32_f16 v[\v_c+14], v[\v_a], v[\v_b+14]
+    v_dot2c_f32_f16 v[\v_c+15], v[\v_a], v[\v_b+15]
+.endm
+
+.macro .fma_1x8_fp16 v_c, v_a, v_b
+    v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ]
+    v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ]
+    v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ]
+    v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ]
+    v_dot2c_f32_f16 v[\v_c+4 ], v[\v_a], v[\v_b+4 ]
+    v_dot2c_f32_f16 v[\v_c+5 ], v[\v_a], v[\v_b+5 ]
+    v_dot2c_f32_f16 v[\v_c+6 ], v[\v_a], v[\v_b+6 ]
+    v_dot2c_f32_f16 v[\v_c+7 ], v[\v_a], v[\v_b+7 ]
+.endm
+
+.macro .fma_1x4_fp16 v_c, 
v_a, v_b + v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] +.endm + +.macro .mdiv_u32_ss s_quot s_numer s_magic s_shift s_tmp + s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer] + s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer] + s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift] +.endm + +.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp + .mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp + s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot] + s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp] +.endm + +.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp + v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer] + v_add_nc_u32 v[\v_tmp], v[\v_tmp], v[\v_numer] + v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp] +.endm + +.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp + .mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp + v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot] + v_sub_nc_u32 v[\v_rem], v[\v_numer], v[\v_tmp] +.endm + +.macro .v_clear_nc vid, num + _v = \vid + .rept \num + v_mov_b32 v[_v], 0 + _v = _v + 1 + .endr +.endm + +.include "igemm_fwd_btm_nhwc_fp16_128x004.asm" +.include "igemm_fwd_btm_nhwc_fp16_128x016.asm" +.include "igemm_fwd_btm_nhwc_fp16_256x004.asm" +.include "igemm_fwd_btm_nhwc_fp16_256x016.asm" +.include "igemm_fwd_btm_nhwc_fp16_256x008.asm" +.include "igemm_fwd_btm_nhwc_fp16_384x004.asm" +.include "igemm_fwd_btm_nhwc_fp16_512x004.asm" +.include "igemm_fwd_btm_nhwc_fp16_512x008.asm" +.include "igemm_fwd_btm_nhwc_fp16_1024x008.asm" + +.amdgpu_metadata +--- +amdhsa.version: [ 1, 0 ] +amdhsa.kernels: + - .name: igemm_fwd_btm_nhwc_fp16_128x4x16_r2 + .symbol: igemm_fwd_btm_nhwc_fp16_128x4x16_r2.kd + .sgpr_count: 64 + .vgpr_count: 88 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [64, 1, 1] + .max_flat_workgroup_size: 64 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, 
.value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_128x16x16_r3 + .symbol: igemm_fwd_btm_nhwc_fp16_128x16x16_r3.kd + .sgpr_count: 60 + .vgpr_count: 74 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 13056 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_256x4x16_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_256x4x16_r1.kd + .sgpr_count: 60 + .vgpr_count: 112 + 
.kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 13056 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_256x16x16_r3 + .symbol: igemm_fwd_btm_nhwc_fp16_256x16x16_r3.kd + .sgpr_count: 60 + .vgpr_count: 112 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 13056 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, 
.value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_256x8x16_r2 + .symbol: igemm_fwd_btm_nhwc_fp16_256x8x16_r2.kd + .sgpr_count: 64 + .vgpr_count: 128 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, 
.value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_256x8x8_r2 + .symbol: igemm_fwd_btm_nhwc_fp16_256x8x8_r2.kd + .sgpr_count: 64 + .vgpr_count: 124 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [64, 1, 1] + .max_flat_workgroup_size: 64 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_384x4x16_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_384x4x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 114 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : 
[128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_512x4x16_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_512x4x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 140 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} 
+ - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_512x8x16_r2 + .symbol: igemm_fwd_btm_nhwc_fp16_512x8x16_r2.kd + .sgpr_count: 64 + .vgpr_count: 188 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, 
.value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_512x8x8_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_512x8x8_r1.kd + .sgpr_count: 64 + .vgpr_count: 124 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_1024x8x8_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_1024x8x8_r1.kd + .sgpr_count: 64 + .vgpr_count: 212 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, 
.is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} +... 
+.end_amdgpu_metadata diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_1024x008.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_1024x008.asm new file mode 100644 index 00000000..d7b69e86 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_1024x008.asm @@ -0,0 +1,1266 @@ +;---------------------------------------------------------------- +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 64 +.set v_ax, 65 +.set v_ay, 97 +.set v_ib, 129 +.set v_b, 130 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 162 +.set v_in_ihi, 170 +.set v_in_iwi, 178 +.set v_in_flag, 186 +.set v_out_os, 194 +.set v_out_flag, 202 +.set v_tid, 210 +.set v_end, 212 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_1024x8x8_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_1024x8x8_r1,@function +igemm_fwd_btm_nhwc_fp16_1024x8x8_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 16 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 15, v0 ; yx + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + 
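For readability, the thread-to-weight mapping computed in this prologue, summarized as comments (an annotation consistent with the surrounding instructions, not new code):

; annotation: weight-load mapping for the 128-thread workgroup (k_n_dword = 8)
;   v_wei_ik = tid >> 4  -> one of 8 output channels (gemm_n) per thread
;   v_wei_ie = tid & 15  -> one of 16 y*x filter taps (the gemm_k direction)
;   v_sst_b_os           -> LDS store offset: ie * (k_n_dword*4*4) + ik*4
;   v_sld_b_os           -> LDS load offset: starts at 0, advances by s_sld_b_stride per k-step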
v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 10 ; 1024 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + ; s_lshl_b32 s[s_wei_offset], s[s_c], 4+1 ; 16x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ; v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ;s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 
16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + ;v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + 
v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+4,v_in_ihi+4,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+4], s[s_stride_h], v[v_in_ihi+4] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+4], v[v_in_ihi+4], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+4], s[s_stride_w], v[v_in_iwi+4] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+4], v[v_in_iwi+4], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+5,v_in_ihi+5,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+4] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+4], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_mul_lo_u32 v[v_in_os+4], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+5], s[s_stride_h], v[v_in_ihi+5] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+5], v[v_in_ihi+5], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+5], s[s_stride_w], v[v_in_iwi+5] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+5], v[v_in_iwi+5], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+5], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_mul_lo_u32 v[v_in_os+5], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+6,v_in_ihi+6,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+6], s[s_stride_h], v[v_in_ihi+6] + + v_sub_nc_i32 v[v_in_ihi+6], v[v_in_ihi+6], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+6], s[s_stride_w], v[v_in_iwi+6] + + v_sub_nc_i32 v[v_in_iwi+6], v[v_in_iwi+6], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+7,v_in_ihi+7,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+7], s[s_stride_h], v[v_in_ihi+7] + + v_sub_nc_i32 v[v_in_ihi+7], v[v_in_ihi+7], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+7], s[s_stride_w], v[v_in_iwi+7] + + v_sub_nc_i32 v[v_in_iwi+7], v[v_in_iwi+7], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+6], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_mul_lo_u32 v[v_in_os+6], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+7] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+7], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 
v[v_in_flag+7], 0, v[v_in_flag+7] + v_mul_lo_u32 v[v_in_os+7], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+4], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+4], 1, v[v_out_os+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+4] + v_cndmask_b32 v[v_out_flag+4], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+5], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+5], 1, v[v_out_os+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+5] + v_cndmask_b32 v[v_out_flag+5], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+6], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+6], 1, v[v_out_os+6] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+6] + v_cndmask_b32 v[v_out_flag+6], 0, 1 + + v_mul_lo_u32 v[v_out_os+7], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+7], 1, v[v_out_os+7] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+7] + v_cndmask_b32 v[v_out_flag+7], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(8) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + 
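For orientation before the main loop: per thread, one 8-deep k-step of this kernel accumulates an 8x8 (output-pixel x output-channel) f32 tile from fp16 inputs via v_dot2c_f32_f16, which adds one 2-wide fp16 dot product into an f32 accumulator per instruction. A scalar C++ sketch of that step, with the dword packing and register banking simplified (a model assumed from the .fma_1x8_fp16 rounds below, not the kernel's actual data layout):

    // Scalar model of one k-step of igemm_fwd_btm_nhwc_fp16_1024x8x8_r1.
    // a[m][kk] models v_ax (4 dwords = 8 fp16 per output pixel), b[kk][n]
    // models the weight tile read back from LDS into v_b; the kk += 2 step
    // mirrors one v_dot2c_f32_f16 (two fp16 products per instruction).
    // _Float16 is a GCC/Clang extension type.
    void fma_step_model(float c[8][8], const _Float16 a[8][8], const _Float16 b[8][8]) {
        for (int m = 0; m < 8; ++m)                 // 8 output pixels per thread
            for (int n = 0; n < 8; ++n)             // 8 output channels (k_n_dword)
                for (int kk = 0; kk < 8; kk += 2)   // gemm_k depth 8
                    c[m][n] += (float)a[m][kk]     * (float)b[kk][n]
                             + (float)a[m][kk + 1] * (float)b[kk + 1][n];
    }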
.v_clear_nc v_c, 64 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_iwi+4], s[s_tmp], v[v_in_iwi+4] + v_add_nc_u32 v[v_in_iwi+5], s[s_tmp], v[v_in_iwi+5] + v_add_nc_u32 v[v_in_iwi+6], s[s_tmp], v[v_in_iwi+6] + v_add_nc_u32 v[v_in_iwi+7], s[s_tmp], v[v_in_iwi+7] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + v_add_nc_u32 v[v_in_os+4], s[s_tmp+1], v[v_in_os+4] + v_add_nc_u32 v[v_in_os+5], s[s_tmp+1], v[v_in_os+5] + v_add_nc_u32 v[v_in_os+6], s[s_tmp+1], v[v_in_os+6] + v_add_nc_u32 v[v_in_os+7], s[s_tmp+1], v[v_in_os+7] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] + v_add_nc_i32 v[v_in_ihi+4], s[s_dilation_h], v[v_in_ihi+4] + v_add_nc_i32 v[v_in_ihi+5], s[s_dilation_h], v[v_in_ihi+5] + v_add_nc_i32 v[v_in_ihi+6], s[s_dilation_h], v[v_in_ihi+6] + v_add_nc_i32 v[v_in_ihi+7], s[s_dilation_h], v[v_in_ihi+7] +igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], 
v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + + ;--- end move slice window + + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .v_clear_nc v_ay+16, 8 + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx4 v[v_ay+16:v_ay+19], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx4 v[v_ay+20:v_ay+23], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+24, 8 + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx4 v[v_ay+24:v_ay+27], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx4 v[v_ay+28:v_ay+31], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ax + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ax + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ax +12, v_b + 0 + + .fma_1x8_fp16 v_c+32, v_ax +16, v_b + 0 + .fma_1x8_fp16 v_c+40, v_ax +20, v_b + 0 + .fma_1x8_fp16 v_c+48, v_ax +24, v_b + 0 + .fma_1x8_fp16 v_c+56, v_ax +28, v_b + 0 + + .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ax + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ax + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ax +13, v_b + 8 + + .fma_1x8_fp16 v_c+32, v_ax +17, v_b + 8 + .fma_1x8_fp16 v_c+40, v_ax +21, v_b + 8 + .fma_1x8_fp16 v_c+48, v_ax +25, v_b + 8 + .fma_1x8_fp16 v_c+56, v_ax +29, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ax + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ax +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ax +14, v_b +16 + + .fma_1x8_fp16 v_c+32, v_ax +18, v_b +16 + .fma_1x8_fp16 v_c+40, v_ax +22, v_b +16 + .fma_1x8_fp16 v_c+48, v_ax +26, v_b +16 + .fma_1x8_fp16 v_c+56, v_ax +30, v_b +16 + + .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ax + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ax +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ax +15, v_b +24 + + .fma_1x8_fp16 v_c+32, v_ax +19, v_b +24 + .fma_1x8_fp16 v_c+40, v_ax +23, v_b +24 + .fma_1x8_fp16 v_c+48, v_ax +27, v_b +24 + .fma_1x8_fp16 v_c+56, v_ax +31, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + 
ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_iwi+4], s[s_tmp], v[v_in_iwi+4] + v_add_nc_u32 v[v_in_iwi+5], s[s_tmp], v[v_in_iwi+5] + v_add_nc_u32 v[v_in_iwi+6], s[s_tmp], v[v_in_iwi+6] + v_add_nc_u32 v[v_in_iwi+7], s[s_tmp], v[v_in_iwi+7] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + v_add_nc_u32 v[v_in_os+4], s[s_tmp+1], v[v_in_os+4] + v_add_nc_u32 v[v_in_os+5], s[s_tmp+1], v[v_in_os+5] + v_add_nc_u32 v[v_in_os+6], s[s_tmp+1], v[v_in_os+6] + v_add_nc_u32 v[v_in_os+7], s[s_tmp+1], v[v_in_os+7] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] + v_add_nc_i32 v[v_in_ihi+4], s[s_dilation_h], v[v_in_ihi+4] + v_add_nc_i32 v[v_in_ihi+5], s[s_dilation_h], v[v_in_ihi+5] + v_add_nc_i32 v[v_in_ihi+6], s[s_dilation_h], v[v_in_ihi+6] + v_add_nc_i32 v[v_in_ihi+7], s[s_dilation_h], v[v_in_ihi+7] +igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 
1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax +0:v_ax +3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .v_clear_nc v_ax+16, 8 + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+24, 8 + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + + .fma_1x8_fp16 v_c+32, v_ay +16, v_b + 0 + .fma_1x8_fp16 v_c+40, v_ay +20, v_b + 0 + .fma_1x8_fp16 v_c+48, v_ay +24, v_b + 0 + .fma_1x8_fp16 v_c+56, v_ay +28, v_b + 0 + + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + .fma_1x8_fp16 v_c+32, v_ay +17, v_b + 8 + .fma_1x8_fp16 v_c+40, v_ay +21, v_b + 8 + .fma_1x8_fp16 v_c+48, v_ay +25, v_b + 8 + .fma_1x8_fp16 v_c+56, v_ay +29, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + + .fma_1x8_fp16 v_c+32, v_ay +18, v_b +16 + .fma_1x8_fp16 v_c+40, v_ay +22, v_b +16 + .fma_1x8_fp16 v_c+48, v_ay +26, v_b +16 + .fma_1x8_fp16 v_c+56, v_ay +30, v_b +16 + + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + .fma_1x8_fp16 v_c+32, v_ay +19, v_b +24 + .fma_1x8_fp16 v_c+40, v_ay +23, v_b +24 + .fma_1x8_fp16 v_c+48, v_ay +27, v_b +24 + .fma_1x8_fp16 v_c+56, v_ay +31, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] 
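+ ; note: the v_ax -> v_ay copy continues below for all 32 A dwords, so the
+ ; shared tail at fma_end_1 only ever reads the v_ay buffer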
+ v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + + v_mov_b32 v[v_ay +16], v[v_ax +16] + v_mov_b32 v[v_ay +17], v[v_ax +17] + v_mov_b32 v[v_ay +18], v[v_ax +18] + v_mov_b32 v[v_ay +19], v[v_ax +19] + v_mov_b32 v[v_ay +20], v[v_ax +20] + v_mov_b32 v[v_ay +21], v[v_ax +21] + v_mov_b32 v[v_ay +22], v[v_ax +22] + v_mov_b32 v[v_ay +23], v[v_ax +23] + v_mov_b32 v[v_ay +24], v[v_ax +24] + v_mov_b32 v[v_ay +25], v[v_ax +25] + v_mov_b32 v[v_ay +26], v[v_ax +26] + v_mov_b32 v[v_ay +27], v[v_ax +27] + v_mov_b32 v[v_ay +28], v[v_ax +28] + v_mov_b32 v[v_ay +29], v[v_ax +29] + v_mov_b32 v[v_ay +30], v[v_ax +30] + v_mov_b32 v[v_ay +31], v[v_ax +31] + +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], 
s[s_stride_w], v[v_in_iwi+2] + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_add_nc_u32 v[v_in_flag+4], s[s_ib_stride], v[v_in_flag+3] + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+4,v_in_ihi+4,v_in_flag+4,s_magic_1,s_shift_m1,s_wo,v_in_os+4 + v_add_nc_u32 v[v_in_flag+5], s[s_ib_stride], v[v_in_flag+4] + .v_clear_nc v_ax+16, 4 + v_mul_u32_u24 v[v_in_ihi+4], s[s_stride_h], v[v_in_ihi+4] + v_sub_nc_i32 v[v_in_ihi+4], v[v_in_ihi+4], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+4], s[s_stride_w], v[v_in_iwi+4] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+4], v[v_in_iwi+4], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+5,v_in_ihi+5,v_in_flag+5,s_magic_1,s_shift_m1,s_wo,v_in_os+5 + v_add_nc_u32 v[v_in_flag+6], s[s_ib_stride], v[v_in_flag+5] + v_mul_u32_u24 v[v_in_os+4], s[s_wi], v[v_in_ihi+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_add_nc_u32 v[v_in_os+4], v[v_in_iwi+4], v[v_in_os+4] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_mul_lo_u32 v[v_in_os+4], s[s_in_stride_wi], v[v_in_os+4] + + v_mul_u32_u24 v[v_in_ihi+5], s[s_stride_h], v[v_in_ihi+5] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+5], v[v_in_ihi+5], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+5], s[s_stride_w], v[v_in_iwi+5] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+5], v[v_in_iwi+5], s[s_pad_w] + + + + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+5], s[s_wi], v[v_in_ihi+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_add_nc_u32 v[v_in_os+5], v[v_in_iwi+5], v[v_in_os+5] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_mul_lo_u32 v[v_in_os+5], s[s_in_stride_wi], v[v_in_os+5] + + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+6,v_in_ihi+6,v_in_flag+6,s_magic_1,s_shift_m1,s_wo,v_in_os+6 + v_add_nc_u32 v[v_in_flag+7], s[s_ib_stride], v[v_in_flag+6] + v_mul_lo_u32 v[v_in_ihi+6], s[s_stride_h], v[v_in_ihi+6] + v_sub_nc_i32 v[v_in_ihi+6], v[v_in_ihi+6], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+6], s[s_stride_w], v[v_in_iwi+6] + v_sub_nc_i32 
v[v_in_iwi+6], v[v_in_iwi+6], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+7,v_in_ihi+7,v_in_flag+7,s_magic_1,s_shift_m1,s_wo,v_in_os+7 + v_mul_lo_u32 v[v_in_ihi+7], s[s_stride_h], v[v_in_ihi+7] + v_sub_nc_i32 v[v_in_ihi+7], v[v_in_ihi+7], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+7], s[s_stride_w], v[v_in_iwi+7] + v_sub_nc_i32 v[v_in_iwi+7], v[v_in_iwi+7], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+6], s[s_wi], v[v_in_ihi+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_add_nc_u32 v[v_in_os+6], v[v_in_iwi+6], v[v_in_os+6] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_mul_lo_u32 v[v_in_os+6], s[s_in_stride_wi], v[v_in_os+6] + + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+7], s[s_wi], v[v_in_ihi+7] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + v_add_nc_u32 v[v_in_os+7], v[v_in_iwi+7], v[v_in_os+7] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + v_mul_lo_u32 v[v_in_os+7], s[s_in_stride_wi], v[v_in_os+7] + + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + + .fma_1x8_fp16 v_c+32, v_ay +16, v_b + 0 + .fma_1x8_fp16 v_c+40, v_ay +20, v_b + 0 + .fma_1x8_fp16 v_c+48, v_ay +24, v_b + 0 + .fma_1x8_fp16 v_c+56, v_ay +28, v_b + 0 + + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + .fma_1x8_fp16 v_c+32, v_ay +17, v_b + 8 + .fma_1x8_fp16 v_c+40, v_ay +21, v_b + 8 + .fma_1x8_fp16 v_c+48, v_ay +25, v_b + 8 + .fma_1x8_fp16 v_c+56, v_ay +29, v_b + 8 + + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + + .fma_1x8_fp16 v_c+32, v_ay +18, v_b +16 + .fma_1x8_fp16 v_c+40, v_ay +22, v_b +16 + .fma_1x8_fp16 v_c+48, v_ay +26, v_b +16 + .fma_1x8_fp16 v_c+56, v_ay +30, v_b +16 + + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + .fma_1x8_fp16 v_c+32, v_ay +19, v_b +24 + .fma_1x8_fp16 v_c+40, v_ay +23, v_b +24 + .fma_1x8_fp16 v_c+48, v_ay +27, v_b +24 + .fma_1x8_fp16 v_c+56, v_ay +31, v_b +24 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_pack_b32_f16 v[v_c_buf+0], 
v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cvt_f16_f32 v[v_c +16], v[v_c +16] + v_cvt_f16_f32 v[v_c +17], v[v_c +17] + v_cvt_f16_f32 v[v_c +18], v[v_c +18] + v_cvt_f16_f32 v[v_c +19], v[v_c +19] + v_cvt_f16_f32 v[v_c +20], v[v_c +20] + v_cvt_f16_f32 v[v_c +21], v[v_c +21] + v_cvt_f16_f32 v[v_c +22], v[v_c +22] + v_cvt_f16_f32 v[v_c +23], v[v_c +23] + + v_cvt_f16_f32 v[v_c +24], v[v_c +24] + v_cvt_f16_f32 v[v_c +25], v[v_c +25] + v_cvt_f16_f32 v[v_c +26], v[v_c +26] + v_cvt_f16_f32 v[v_c +27], v[v_c +27] + v_cvt_f16_f32 v[v_c +28], v[v_c +28] + v_cvt_f16_f32 v[v_c +29], v[v_c +29] + v_cvt_f16_f32 v[v_c +30], v[v_c +30] + v_cvt_f16_f32 v[v_c +31], v[v_c +31] + + + v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] + v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19] + v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21] + v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23] + + v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25] + v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27] + v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29] + v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + v_cvt_f16_f32 v[v_c +32], v[v_c +32] + v_cvt_f16_f32 v[v_c +33], v[v_c +33] + v_cvt_f16_f32 v[v_c +34], v[v_c +34] + v_cvt_f16_f32 v[v_c +35], v[v_c +35] + v_cvt_f16_f32 v[v_c +36], v[v_c +36] + v_cvt_f16_f32 v[v_c +37], v[v_c +37] + v_cvt_f16_f32 v[v_c +38], v[v_c +38] + v_cvt_f16_f32 v[v_c +39], v[v_c +39] + + v_cvt_f16_f32 v[v_c +40], v[v_c +40] + v_cvt_f16_f32 v[v_c +41], v[v_c +41] + v_cvt_f16_f32 v[v_c +42], v[v_c +42] + v_cvt_f16_f32 v[v_c +43], v[v_c +43] + v_cvt_f16_f32 v[v_c +44], v[v_c +44] + v_cvt_f16_f32 v[v_c +45], v[v_c +45] + v_cvt_f16_f32 v[v_c +46], v[v_c +46] + v_cvt_f16_f32 v[v_c +47], v[v_c +47] + + + v_pack_b32_f16 v[v_c_buf+16], v[v_c+32], v[v_c+33] + v_pack_b32_f16 v[v_c_buf+17], v[v_c+34], v[v_c+35] + v_pack_b32_f16 v[v_c_buf+18], v[v_c+36], v[v_c+37] + v_pack_b32_f16 v[v_c_buf+19], v[v_c+38], v[v_c+39] + + v_pack_b32_f16 v[v_c_buf+20], v[v_c+40], v[v_c+41] + v_pack_b32_f16 v[v_c_buf+21], v[v_c+42], v[v_c+43] + v_pack_b32_f16 v[v_c_buf+22], v[v_c+44], v[v_c+45] + v_pack_b32_f16 v[v_c_buf+23], v[v_c+46], v[v_c+47] + + v_cmpx_le_u32 1, v[v_out_flag+4] + global_store_dwordx4 v[v_out_os+4], v[v_c_buf+16:v_c_buf+19], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+5] + global_store_dwordx4 v[v_out_os+5], v[v_c_buf+20:v_c_buf+23], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cvt_f16_f32 v[v_c +48], v[v_c +48] + v_cvt_f16_f32 v[v_c +49], v[v_c +49] + v_cvt_f16_f32 v[v_c +50], v[v_c +50] + v_cvt_f16_f32 v[v_c +51], v[v_c +51] + v_cvt_f16_f32 v[v_c +52], v[v_c +52] 
+ v_cvt_f16_f32 v[v_c +53], v[v_c +53] + v_cvt_f16_f32 v[v_c +54], v[v_c +54] + v_cvt_f16_f32 v[v_c +55], v[v_c +55] + + v_cvt_f16_f32 v[v_c +56], v[v_c +56] + v_cvt_f16_f32 v[v_c +57], v[v_c +57] + v_cvt_f16_f32 v[v_c +58], v[v_c +58] + v_cvt_f16_f32 v[v_c +59], v[v_c +59] + v_cvt_f16_f32 v[v_c +60], v[v_c +60] + v_cvt_f16_f32 v[v_c +61], v[v_c +61] + v_cvt_f16_f32 v[v_c +62], v[v_c +62] + v_cvt_f16_f32 v[v_c +63], v[v_c +63] + + + v_pack_b32_f16 v[v_c_buf+24], v[v_c+48], v[v_c+49] + v_pack_b32_f16 v[v_c_buf+25], v[v_c+50], v[v_c+51] + v_pack_b32_f16 v[v_c_buf+26], v[v_c+52], v[v_c+53] + v_pack_b32_f16 v[v_c_buf+27], v[v_c+54], v[v_c+55] + + v_pack_b32_f16 v[v_c_buf+28], v[v_c+56], v[v_c+57] + v_pack_b32_f16 v[v_c_buf+29], v[v_c+58], v[v_c+59] + v_pack_b32_f16 v[v_c_buf+30], v[v_c+60], v[v_c+61] + v_pack_b32_f16 v[v_c_buf+31], v[v_c+62], v[v_c+63] + + v_cmpx_le_u32 1, v[v_out_flag+6] + global_store_dwordx4 v[v_out_os+6], v[v_c_buf+24:v_c_buf+27], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+7] + global_store_dwordx4 v[v_out_os+7], v[v_c_buf+28:v_c_buf+31], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + .v_clear_nc v_c, 64 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + v_add_nc_u32 v[v_out_os+4], s[s_out_stride], v[v_out_os+4] + v_add_nc_u32 v[v_out_os+5], s[s_out_stride], v[v_out_os+5] + v_add_nc_u32 v[v_out_os+6], s[s_out_stride], v[v_out_os+6] + v_add_nc_u32 v[v_out_os+7], s[s_out_stride], v[v_out_os+7] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+4] + v_cndmask_b32 v[v_out_flag+4], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+5] + v_cndmask_b32 v[v_out_flag+5], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+6] + v_cndmask_b32 v[v_out_flag+6], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+7] + v_cndmask_b32 v[v_out_flag+7], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_1024x8x8_r1 + .amdhsa_group_segment_fixed_size 2048 + 
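+ ; 2048 bytes = 1 * 4 * 4 * 128, the LDS budget noted above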
.amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 212 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x004.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x004.asm new file mode 100644 index 00000000..8b7df108 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x004.asm @@ -0,0 +1,666 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 8 +.set v_ax, 9 +.set v_ay, 25 +.set v_ib, 41 +.set v_b, 42 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 74 +.set v_in_ihi, 76 +.set v_in_iwi, 78 +.set v_in_flag, 80 +.set v_out_os, 82 +.set v_out_flag, 84 +.set v_tid, 86 +.set v_end, 88 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_128x4x16_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_128x4x16_r2,@function +igemm_fwd_btm_nhwc_fp16_128x4x16_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 64 + + ; calculate wei offset, 4x16, 4 for k, 16 for yxc, 8 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 15, v0 + s_lshl_b32 
s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 7, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 2 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 7 ; 128 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 2 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 
16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + ; v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 
s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 8 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end + +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], 
v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ax + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ax + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ax +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ax +11, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ax +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ax + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ax +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ax + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ax +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ax + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ax +15, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + + 
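+ ; accounting sketch (inferred): each .fma_1x4_fp16 step consumes one packed
+ ; A dword (two fp16 values) against a 4-dword B row, so the 8 A dwords per
+ ; buffer cover 16 gemm-k, matching the s_sub_i32 s[s_kitr], 16 below; v_c
+ ; appears to hold fp32 accumulators, given the v_cvt_f16_f32 epilogue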
.fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_body + +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + ; v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + 
global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + + s_waitcnt lgkmcnt(0) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+2:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + .v_clear_nc v_c, 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 
v[v_out_flag+1], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_body +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 64 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_128x4x16_r2 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 88 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x016.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x016.asm new file mode 100644 index 00000000..055d0cd9 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x016.asm @@ -0,0 +1,584 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 + +.set s_block_ib, 2 ; bx, ho*wo +.set s_ka, 0 +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 39 +.set s_in_diff_hi, 40 +.set s_in_diff_wi, 41 +.set s_dilation_w_x, 42 +.set s_move_slice_k_ix, 43 + +.set s_kitr, 1 +.set s_wei_offset, 44 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 45 +.set s_br, 46 + +.set s_tmp, 48 +.set s_end, 54 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 16 +.set v_a, 17 +.set v_ib, 25 +.set v_b, 26 +.set v_gld_a, 58 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+12 +.set v_wei_ix_list, v_b+15 +.set v_wei_flag, v_b+18 +.set v_wei_os, v_b+21 +.set v_tmp, v_b+24 +.set v_wei_ik, v_a +.set v_wei_ic, v_a+1 +.set v_wei_ie, v_a+2 +.set v_wei_flag_ik, v_a+3 +.set v_sst_b_os, v_a+4 +.set v_in_os, 66 +.set v_in_ihi, 67 +.set v_in_iwi, 68 +.set v_in_flag, 69 +.set v_out_os, 70 +.set v_out_flag, 71 +.set v_tid, 72 +.set v_end, 74 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_128x16x16_r3 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_128x16x16_r3,@function +igemm_fwd_btm_nhwc_fp16_128x16x16_r3: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + 
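+ ; the five s_load above fetch the full kernarg block up front: the input
+ ; pointer, then wei/out pointers, the 16-dword conv descriptor
+ ; (k_hi..k_group), batch_m/stride_m plus magic_0/magic_1, and finally
+ ; magic_2 with shift_pack_0 -- offsets follow the .set k_* table above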
v_mov_b32 v[v_tid], v0 + + ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 4 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_mov_b32 s[s_tmp], 17*4 * 4 ; 17dword per row, 4 row + v_and_b32 v[v_tmp+5], 7, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + s_lshl_b32 s[s_block_ib], s[s_block_ib], 7 ; 128 half + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to 17 + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 3, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 2+1 ; 4x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 4, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+2], s[s_wei_offset], v[v_wei_os+1] + v_add_nc_u32 v[v_wei_ie], 4, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + .mdiv_u32_rem_vs v_wei_ix_list+2,v_wei_iy_list+2,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+2] + v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+2] + v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag+2] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+2] + global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 32*17*4 ; stride for wei sst offset. 
8 thread for k, each thread store 4 c, hence 8*4=32 + ; 17 dword per row + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_gld_a, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_gld_a+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + v_add_nc_u32 v[v_sst_b_os+2], s[s_tmp+5], v[v_sst_b_os+1] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + + s_mov_b32 s[s_sld_b_stride], 17*8*4 + + s_waitcnt vmcnt(2) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:17*0 offset1:17*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:17*2 offset1:17*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:17*0 offset1:17*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:17*2 offset1:17*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+2] + ds_write2_b32 v[v_sst_b_os+2], v[v_gld_b+8], v[v_gld_b+9], offset0:17*0 offset1:17*1 + ds_write2_b32 
v[v_sst_b_os+2], v[v_gld_b+10], v[v_gld_b+11], offset0:17*2 offset1:17*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + s_cmp_gt_i32 s[s_kitr], 0 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end + +L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_body: + ; accumulate im + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 + + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi], s[s_tmp], v[v_in_iwi] + v_add_nc_u32 v[v_in_os], s[s_tmp+1], v[v_in_os] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi], s[s_dilation_h], v[v_in_ihi] +igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + ;--- end move slice window + + s_waitcnt vmcnt(0) + v_mov_b32 v[v_a + 0], v[v_gld_a + 0] + v_mov_b32 v[v_a + 1], v[v_gld_a + 1] + v_mov_b32 v[v_a + 2], v[v_gld_a + 2] + v_mov_b32 v[v_a + 3], v[v_gld_a + 3] + v_mov_b32 v[v_a + 4], v[v_gld_a + 4] + v_mov_b32 v[v_a + 5], v[v_gld_a + 5] + v_mov_b32 v[v_a + 6], v[v_gld_a + 6] + v_mov_b32 v[v_a + 7], v[v_gld_a + 7] + .v_clear_nc v_gld_a, 8 + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 + s_waitcnt lgkmcnt(4) + 
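; software pipeline: lgkmcnt(4) only waits for LDS reads issued two
+ ; steps back, so each .fma_1x8_fp16 consumes b-rows that have already
+ ; landed while the most recent pair of ds_read_b128 stays in flight
+ 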
.fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 + s_waitcnt lgkmcnt(4) + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_body + +L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end: + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 + s_waitcnt vmcnt(0) + + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_mov_b32 v[v_a + 0], v[v_gld_a + 0] + v_mov_b32 v[v_a + 1], v[v_gld_a + 1] + v_mov_b32 v[v_a + 2], v[v_gld_a + 2] + v_mov_b32 v[v_a + 3], v[v_gld_a + 3] + v_mov_b32 v[v_a + 4], v[v_gld_a + 4] + v_mov_b32 v[v_a + 5], v[v_gld_a + 5] + v_mov_b32 v[v_a + 6], v[v_gld_a + 6] + v_mov_b32 v[v_a + 7], v[v_gld_a + 7] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_gld_a, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_gld_a+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + s_mov_b32 s[s_move_slice_k_ix], 0 + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 
v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 +L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + s_waitcnt lgkmcnt(2) + .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16 + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + 
v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + global_store_dwordx4 v[v_out_os], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + + .v_clear_nc v_c, 16 + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_body +L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_end: + s_endpgm + +; LDS: (16+1) * (3 * 4) * 16 * 4 = 13056 +; k pad r3 e:4 c dword +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_128x16x16_r3 + .amdhsa_group_segment_fixed_size 13056 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 74 + .amdhsa_next_free_sgpr 54 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x004.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x004.asm new file mode 100644 index 00000000..f304a586 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x004.asm @@ -0,0 +1,652 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, 
s_shift_pack_0
+.set s_shift_m2, 39
+.set s_in_stride_wi, 12
+.set s_in_stride_n, 13
+.set s_wei_stride_k, 14
+.set s_out_stride_wo, 15
+.set s_out_stride_n, 40
+.set s_in_diff_hi, 41
+.set s_in_diff_wi, 42
+.set s_dilation_w_x, 43
+.set s_move_slice_k_ix, 44
+
+.set s_kitr, 1
+.set s_wei_offset, 45
+.set s_out_stride, s_wei_offset
+.set s_sld_b_stride, 46
+.set s_br, 47
+.set s_ib_stride, 48
+.set s_block_ik, 49
+.set s_block_ib, 50
+.set s_tmp, 52
+.set s_end, 58
+
+; magic_0: x
+; magic_1: wo
+
+.set v_c, 0
+.set v_c_buf, v_c
+.set v_sld_b_os, 8
+.set v_ax, 9
+.set v_ay, 25
+.set v_ib, 41
+.set v_b, 42
+.set v_gld_b, v_b
+.set v_wei_iy_list, v_b+4
+.set v_wei_ix_list, v_b+5
+.set v_wei_flag, v_b+6
+.set v_wei_os, v_b+7
+.set v_tmp, v_b+8
+.set v_wei_ik, v_ay
+.set v_wei_ic, v_ay+1
+.set v_wei_ie, v_ay+2
+.set v_wei_flag_ik, v_ay+3
+.set v_sst_b_os, v_ay+4
+.set v_in_os, 74
+.set v_in_ihi, 76
+.set v_in_iwi, 78
+.set v_in_flag, 80
+.set v_out_os, 82
+.set v_out_flag, 84
+.set v_tid, 86
+.set v_end, 88
+
+; short wide igemv
+.text
+.globl igemm_fwd_btm_nhwc_fp16_256x4x16_r1
+.p2align 8
+
+.type igemm_fwd_btm_nhwc_fp16_256x4x16_r1,@function
+igemm_fwd_btm_nhwc_fp16_256x4x16_r1:
+ s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in
+ s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei
+ s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
+ s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m
+ s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2
+ v_mov_b32 v[v_tid], v0
+ s_mov_b32 s[s_ib_stride], 128
+
+ ; calculate wei offset, 4x32, 4 for k, 32 for yxc, 16 for yx, 2 for c
+ v_lshrrev_b32 v[v_wei_ik], 5, v0
+ s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 4 dword per row, 4 row
+ v_and_b32 v[v_tmp+5], 31, v0
+ s_lshl_b32 s[s_block_ig], s[s_block_ig], 1
+ v_and_b32 v[v_wei_ic], 1, v0
+ s_lshl_b32 s[s_block_in], s[s_block_in], 1
+ v_lshrrev_b32 v[v_tmp+4], 1, v0
+ v_mov_b32 v[v_ib], v0
+ v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_tmp+5]
+ v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x
+ v_mov_b32 v[v_sld_b_os], 0 ; load
+ v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword
+ v_and_b32 v[v_wei_ie], 15, v[v_tmp+4] ; yx
+ v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad
+
+ s_waitcnt lgkmcnt(0)
+ s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8
+ s_lshr_b32 s[s_tmp+3], s[s_k], 2
+ s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8
+ .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp
+ s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8 ; 256
+ s_lshl_b32 s[s_block_ik], s[s_tmp+4], 2
+ v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib]
+ s_mul_i32 s[s_tmp], s[s_x], s[s_c]
+ v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik]
+
+
+ v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic]
+ s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y]
+ s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half
+ s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k]
+ v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1]
+ s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5]
+ v_cmp_gt_u32 s[s_k], v[v_wei_ik]
+ s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2]
+ v_cndmask_b32 v[v_wei_flag_ik], 0, 1
+ s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0
+ v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os]
+
+ ; divide x
+ .mdiv_u32_rem_vs
v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ;v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + ;v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ; s_mov_b32 s[s_tmp+5], 128*k_n_dword*4 ; stride for wei sst offset. 32 thread for gemm_k, each thread store 4 c, hence 32*4=128 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + ;v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + ; v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + 
s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 8 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ;s_waitcnt 
vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ax + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ax + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ax +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ax +11, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ax +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ax + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ax +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ax + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ax +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ax + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ax +15, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + 
s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc 
v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + ; v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + + s_waitcnt lgkmcnt(0) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+2:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + .v_clear_nc v_c, 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], 
s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_256x4x16_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 88 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x008.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x008.asm new file mode 100644 index 00000000..0b880322 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x008.asm @@ -0,0 +1,1520 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 16 +.set v_ax, 17 +.set v_ay, 33 +.set v_ib, 49 +.set v_b, 50 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 114 +.set v_in_ihi, 116 +.set v_in_iwi, 118 +.set v_in_flag, 120 +.set v_out_os, 122 +.set v_out_flag, 124 +.set v_tid, 126 +.set v_end, 128 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_256x8x16_r2 +.p2align 8 + +.type 
igemm_fwd_btm_nhwc_fp16_256x8x16_r2,@function
+igemm_fwd_btm_nhwc_fp16_256x8x16_r2:
+ s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in
+ s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei
+ s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
+ s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m
+ s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2
+ v_mov_b32 v[v_tid], v0
+ s_mov_b32 s[s_ib_stride], 128
+
+ ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 8 for yx, 2 for c
+ v_lshrrev_b32 v[v_wei_ik], 4, v0
+ s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 8 dword per row, 4 row
+ v_and_b32 v[v_tmp+5], 15, v0
+ s_lshl_b32 s[s_block_ig], s[s_block_ig], 1
+ v_and_b32 v[v_wei_ic], 1, v0
+ s_lshl_b32 s[s_block_in], s[s_block_in], 1
+ v_lshrrev_b32 v[v_tmp+4], 1, v0
+ v_mov_b32 v[v_ib], v0
+ v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_tmp+5]
+ v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x
+ v_mov_b32 v[v_sld_b_os], 0 ; load
+ v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword
+ v_and_b32 v[v_wei_ie], 7, v[v_tmp+4] ; yx
+ v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad
+
+ s_waitcnt lgkmcnt(0)
+ s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8
+ s_lshr_b32 s[s_tmp+3], s[s_k], 3
+ s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8
+ .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp
+ s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8
+ s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3
+ v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib]
+ s_mul_i32 s[s_tmp], s[s_x], s[s_c]
+ v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik]
+
+
+ v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic]
+ s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y]
+ s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half
+ s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k]
+ v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1]
+ s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5]
+ v_cmp_gt_u32 s[s_k], v[v_wei_ik]
+ s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2]
+ v_cndmask_b32 v[v_wei_flag_ik], 0, 1
+ s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0
+ v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os]
+
+ ; divide x
+ .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp
+ v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0]
+ v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie]
+ v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0]
+ v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik]
+ v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0]
+ v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0]
+
+ .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp
+ v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1]
+ v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik]
+ v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1]
+ v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1]
+
+ v_cmpx_le_u32 1, v[v_wei_flag+0]
+ global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1]
+ s_mov_b64 exec, -1
+ v_cmpx_le_u32 1, v[v_wei_flag+1]
+ global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1]
+ s_mov_b64 exec, -1
+
+ s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 
16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + 
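; strides above are in elements; data is fp16, so shift left by 1 to
+ ; get byte strides for the global offsets below
+ 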
s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end + +L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + 
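; fold the h bound into the same flag: a pixel's load stays enabled
+ ; only when both iwi < wi and ihi < hi hold after the window move
+ 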
v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ax + 8, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ax + 9, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ax +10, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ax +11, v_b +24 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_waitcnt lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ax + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ax +12, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ax + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ax +13, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ax + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ax +14, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ax + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ax +15, v_b +56 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] 
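+ ; rejoin point for the x-wrap special case; both paths fall through
+ ; to recompute the per-pixel bounds flags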
+igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay +11, v_b +24 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_waitcnt lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ay + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ay +12, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ay + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ay +13, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ay + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ay +14, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ay + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ay +15, v_b +56 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_body + + +L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + 
+    v_mov_b32 v[v_ay +12], v[v_ax +12]
+    v_mov_b32 v[v_ay +13], v[v_ax +13]
+    v_mov_b32 v[v_ay +14], v[v_ax +14]
+    v_mov_b32 v[v_ay +15], v[v_ax +15]
+L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end_1:
+    s_waitcnt vmcnt(0)
+    s_sub_i32 s[s_batch_m], s[s_batch_m], 1
+    v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib]
+    s_cmp_gt_i32 s[s_batch_m], 0
+    s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end_not_load_next
+    ; --- start move slice for batch m
+    ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
+    ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
+    ; we will update v_in_os below, so use this as v_tmp
+    .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os
+    v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi]
+    .v_clear_nc v_ax, 4
+    v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib]
+    v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h]
+    v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi]
+    .v_clear_nc v_ax+4, 4
+    v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w]
+
+    .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1
+
+    v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi]
+    v_cndmask_b32 v[v_in_flag], 0, 1
+    v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi]
+    v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag]
+    v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os]
+
+    v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1]
+    .v_clear_nc v_ax+8, 4
+    v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h]
+    v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1]
+    .v_clear_nc v_ax+12, 4
+    v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w]
+
+    v_cmpx_le_u32 1, v[v_in_flag]
+    global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1]
+    global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16
+    s_mov_b64 exec, -1
+
+    v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, 1
+    v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1]
+    v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1]
+
+    s_mov_b32 s[s_move_slice_k_ix], 0
+
+    v_cmpx_le_u32 1, v[v_in_flag+1]
+    global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1]
+    global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16
+    s_mov_b64 exec, -1
+L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end_not_load_next:
+    ; --- end move slice for batch m
+
+    s_waitcnt lgkmcnt(8)
+
+    .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0
+    .fma_1x8_fp16 v_c+ 8, v_ay + 8, v_b + 0
+    .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8
+    .fma_1x8_fp16 v_c+ 8, v_ay + 9, v_b + 8
+    .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16
+    .fma_1x8_fp16 v_c+ 8, v_ay +10, v_b +16
+    .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24
+    .fma_1x8_fp16 v_c+ 8, v_ay +11, v_b +24
+
+    s_waitcnt lgkmcnt(0)
+    .fma_1x8_fp16 v_c+ 0, v_ay + 4, v_b +32
+    .fma_1x8_fp16 v_c+ 8, v_ay +12, v_b +32
+    .fma_1x8_fp16 v_c+ 0, v_ay + 5, v_b +40
+    .fma_1x8_fp16 v_c+ 8, v_ay +13, v_b +40
+    .fma_1x8_fp16 v_c+ 0, v_ay + 6, v_b +48
+    .fma_1x8_fp16 v_c+ 8, v_ay +14, v_b +48
+    .fma_1x8_fp16 v_c+ 0, v_ay + 7, v_b +56
+    .fma_1x8_fp16 v_c+ 8, v_ay +15, v_b +56
+
+
+    v_cvt_f16_f32 v[v_c + 0], v[v_c + 0]
+    v_cvt_f16_f32 v[v_c + 1], v[v_c + 1]
+    v_cvt_f16_f32 v[v_c + 2], v[v_c + 2]
+    v_cvt_f16_f32 v[v_c + 3], v[v_c + 3]
+    v_cvt_f16_f32 v[v_c + 4], v[v_c + 4]
+    v_cvt_f16_f32 v[v_c + 5], v[v_c + 5]
+    v_cvt_f16_f32 v[v_c + 6], v[v_c + 6]
+    v_cvt_f16_f32 v[v_c + 7], v[v_c + 7]
+
+    v_cvt_f16_f32 v[v_c + 8], v[v_c + 8]
+    v_cvt_f16_f32 v[v_c + 9], v[v_c + 9]
+    v_cvt_f16_f32 v[v_c +10], v[v_c +10]
+    v_cvt_f16_f32 v[v_c +11], v[v_c +11]
+    v_cvt_f16_f32 v[v_c +12], v[v_c +12]
+    v_cvt_f16_f32 v[v_c +13], v[v_c +13]
+    v_cvt_f16_f32 v[v_c +14], v[v_c +14]
+    v_cvt_f16_f32 v[v_c +15], v[v_c +15]
+
+
+    v_mov_b32 v[v_sld_b_os], 0 ; reset to start
+
+
+    v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1]
+    v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3]
+    v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5]
+    v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7]
+
+    v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9]
+    v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11]
+    v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13]
+    v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15]
+
+    v_cmpx_le_u32 1, v[v_out_flag]
+    global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1]
+    s_mov_b64 exec, -1
+
+    v_cmpx_le_u32 1, v[v_out_flag+1]
+    global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1]
+    s_mov_b64 exec, -1
+
+    s_cmp_le_i32 s[s_batch_m], 0
+
+    s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_end
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4
+
+    ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4
+    ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4
+    ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4
+    ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4
+    ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4
+    ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4
+    ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4
+    ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4
+
+    .v_clear_nc v_c, 16
+    v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os
+
+    v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os]
+    s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16
+    v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1]
+
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os]
+    v_cndmask_b32 v[v_out_flag], 0, 1
+    s_cmp_gt_i32 s[s_kitr], 0
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1]
+    v_cndmask_b32 v[v_out_flag+1], 0, 1
+
+    s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end
+    s_branch L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_body
+L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_end:
+    s_endpgm
+
+; LDS: 2 * 4 * 4 * 128
+; r2 4dword 4 threads
+.rodata
+.p2align 6
+.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_256x8x16_r2
+    .amdhsa_group_segment_fixed_size 4096
+    .amdhsa_user_sgpr_kernarg_segment_ptr 1
+    .amdhsa_system_sgpr_workgroup_id_x 1
+    .amdhsa_system_sgpr_workgroup_id_y 1
+    .amdhsa_system_sgpr_workgroup_id_z 1
+    .amdhsa_system_vgpr_workitem_id 0
+    .amdhsa_next_free_vgpr 128
+    .amdhsa_next_free_sgpr 58
+    .amdhsa_ieee_mode 0
+    .amdhsa_dx10_clamp 0
+    .amdhsa_wavefront_size32 1
+    .amdhsa_workgroup_processor_mode 0
+.end_amdhsa_kernel
+
+
+
+
+
+;----------------------------------------------------------------
+.set k_p_in, 0
+.set k_p_wei, 8
+.set k_p_out, 16
+.set k_hi, 24
+.set k_wi, 28
+.set k_n, 32
+.set k_k, 36
+.set k_c, 40
+.set k_ho, 44
+.set k_wo, 48
+.set k_stride_h, 52
+.set k_stride_w, 56
+.set k_dilation_h, 60
+.set k_dilation_w, 64
+.set k_pad_h, 68
+.set k_pad_w, 72
+.set k_y, 76
+.set k_x, 80
+.set k_group, 84
+.set k_batch_m, 88
+.set k_stride_m, 92
+.set k_magic_0, 96
+.set k_magic_1, 100
+.set k_magic_2, 104
+.set k_shift_pack_0, 108
+.set k_n_dword, 8
+
+.set s_ka, 0
+.set s_bx, 2 ; bx, ho*wo
+.set s_block_ig, 3 ; by, group
+.set s_block_in, 4 ; bz, batch
+.set s_p_in, 6
+.set s_p_wei, 8
+.set s_p_out, 10
+.set s_hi, 16
+.set s_wi, 17
+.set s_n, 18
+.set s_k, 19
+.set s_c, 20
+.set s_ho, 21
+.set s_wo, 22
+.set s_stride_h, 23
+.set s_stride_w, 24
+.set s_dilation_h, 25
+.set s_dilation_w, 26
+.set s_pad_h, 27
+.set s_pad_w, 28
+.set s_y, 29
+.set s_x, 30
+.set s_group, 31
+.set s_batch_m, 32
+.set s_stride_m, 33
+.set s_magic_0, 34
+.set s_magic_1, 35
+.set s_magic_2, 36
+.set s_shift_pack_0, 37
+.set s_shift_m0, 38
+.set s_shift_m1, s_shift_pack_0
+.set s_shift_m2, 39
+.set s_in_stride_wi, 12
+.set s_in_stride_n, 13
+.set s_wei_stride_k, 14
+.set s_out_stride_wo, 15
+.set s_out_stride_n, 40
+.set s_in_diff_hi, 41
+.set s_in_diff_wi, 42
+.set s_dilation_w_x, 43
+.set s_move_slice_k_ix, 44
+
+.set s_kitr, 1
+.set s_wei_offset, 45
+.set s_out_stride, s_wei_offset
+.set s_sld_b_stride, 46
+.set s_br, 47
+.set s_ib_stride, 48
+.set s_block_ik, 49
+.set s_block_ib, 50
+.set s_tmp, 52
+.set s_end, 58
+
+; magic_0: x
+; magic_1: wo
+
+.set v_c, 0
+.set v_c_buf, v_c
+.set v_sld_b_os, 32
+.set v_ax, 33
+.set v_ay, 49
+.set v_ib, 65
+.set v_b, 66
+.set v_gld_b, v_b
+.set v_wei_iy_list, v_b+8
+.set v_wei_ix_list, v_b+10
+.set v_wei_flag, v_b+12
+.set v_wei_os, v_b+14
+.set v_tmp, v_b+16
+.set v_wei_ik, v_ay
+.set v_wei_ic, v_ay+1
+.set v_wei_ie, v_ay+2
+.set v_wei_flag_ik, v_ay+3
+.set v_sst_b_os, v_ay+4
+.set v_in_os, 98
+.set v_in_ihi, 102
+.set v_in_iwi, 106
+.set v_in_flag, 110
+.set v_out_os, 114
+.set v_out_flag, 118
+.set v_tid, 122
+.set v_end, 124
+
+; short wide igemv
+.text
+.globl igemm_fwd_btm_nhwc_fp16_256x8x8_r2
+.p2align 8
+
+.type igemm_fwd_btm_nhwc_fp16_256x8x8_r2,@function
+igemm_fwd_btm_nhwc_fp16_256x8x8_r2:
+    s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in
+    s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei
+    s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
+    s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m
+    s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2
+    v_mov_b32 v[v_tid], v0
+    s_mov_b32 s[s_ib_stride], 64
+
+    ; calculate wei offset, 8x8, 8 for k, 8 for yxc, 8 for yx, 1 for c
+    v_lshrrev_b32 v[v_wei_ik], 3, v0
+    s_mov_b32 s[s_tmp], k_n_dword*4 * 4
+    v_and_b32 v[v_wei_ie], 7, v0 ; yx
+    s_lshl_b32 s[s_block_ig], s[s_block_ig], 1
+    v_mov_b32 v[v_wei_ic], 0
+    s_lshl_b32 s[s_block_in], s[s_block_in], 1
+    v_mov_b32 v[v_ib], v0
+    v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_wei_ie]
+    v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->8dword rows, no pad in this kernel
+    v_mov_b32 v[v_sld_b_os], 0 ; load
+    v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword
+    v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad
+
+    s_waitcnt lgkmcnt(0)
+    s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8
+    s_lshr_b32 s[s_tmp+3], s[s_k], 3
+    s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8
+    .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp
+    s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8 ; 256
+    s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3
+    v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib]
+    s_mul_i32 s[s_tmp], s[s_x], s[s_c]
+    v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik]
+
+    v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic]
+    s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y]
+    s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half
+    s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k]
+    v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1]
+    s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5]
+    v_cmp_gt_u32 s[s_k], v[v_wei_ik]
+    s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2]
+    v_cndmask_b32 v[v_wei_flag_ik], 0, 1
+    s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0
+    v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os]
+
+    ; divide x
+    .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp
+    v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0]
+    v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie]
+    v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0]
+    v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik]
+    v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0]
+    v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0]
+
+    .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp
+    v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1]
+    v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik]
+    v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1]
+    v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1]
+
+    v_cmpx_le_u32 1, v[v_wei_flag+0]
+    global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1]
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_wei_flag+1]
+    global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1]
+    s_mov_b64 exec, -1
+
+    s_mov_b32 s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset. 8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k
+
+    ; calculate in offset
+    s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group]
+    s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8
+    s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi]
+    s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c]
+    s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2]
+    s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n]
+    s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1
+    s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3]
+    v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0]
+
+    .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp
+    s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0]
+    v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib]
+    s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0
+    v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi]
+    .v_clear_nc v_ax, 4
+    v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h]
+    v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi]
+    .v_clear_nc v_ax+4, 4
+    v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w]
+
+    .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp
+
+    v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi]
+    v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi]
+    v_cndmask_b32 v[v_in_flag], 0, 1
+    v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi]
+    v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag]
+    v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp]
+
+    v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1]
+    .v_clear_nc v_ax+8, 4
+    v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h]
+    v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1]
+    .v_clear_nc v_ax+12, 4
+    v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w]
+
+    v_cmpx_le_u32 1, v[v_in_flag]
+    global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, 1
+    v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1]
+    v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp]
+
+    v_cmpx_le_u32 1, v[v_in_flag+1]
+    global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp
+    v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5]
+    v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2]
+
+    v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h]
+    v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2]
+
+    v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w]
+
+    .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp
+    v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3]
+
+    v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h]
+    v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3]
+
+    v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w]
+
+    v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2]
+    v_cndmask_b32 v[v_in_flag+2], 0, 1
+    v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2]
+    v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2]
+    v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp]
+
+    v_cmpx_le_u32 1, v[v_in_flag+2]
+    global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3]
+    v_cndmask_b32 v[v_in_flag+3], 0, 1
+    v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3]
+    v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3]
+    v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp]
+
+    v_cmpx_le_u32 1, v[v_in_flag+3]
+    global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    s_mul_i32 s[s_br], s[s_wo], s[s_ho]
+
+    s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group]
+    s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi]
+    s_mov_b32 s[s_move_slice_k_ix], 0
+
+    s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo]
+    s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k]
+    s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n]
+    s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1
+    s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4]
+    s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5]
+    s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1]
+    s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0
+
+    ; calculate diffs, for y, x
+    s_sub_i32 s[s_tmp+3], s[s_x], 1
+    s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3]
+    s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi]
+    s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h]
+    s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp]
+    s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3]
+    s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1
+
+
+    v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib]
+    s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo]
+
+    s_lshl_b32 s[s_out_stride], s[s_out_stride], 1
+    s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1
+
+    ; output offset
+    v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib]
+    v_lshlrev_b32 v[v_out_os], 1, v[v_out_os]
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os]
+    v_cndmask_b32 v[v_out_flag], 0, 1
+    v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5]
+
+    v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5]
+    v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1]
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1]
+    v_cndmask_b32 v[v_out_flag+1], 0, 1
+    v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4]
+
+    v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4]
+    v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2]
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2]
+    v_cndmask_b32 v[v_out_flag+2], 0, 1
+
+    v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5]
+    v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3]
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3]
+    v_cndmask_b32 v[v_out_flag+3], 0, 1
+
+    s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4
+
+    s_waitcnt vmcnt(4)
+
+    v_cmpx_le_u32 1, v[v_wei_flag+0]
+    ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1
+    ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_wei_flag+1]
+    ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1
+    ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3
+    s_mov_b64 exec, -1
+
+    .v_clear_nc v_c, 32
+
+    s_waitcnt lgkmcnt(0)
+    s_barrier
+
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4
+
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4
+
+    s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8
+    v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os
+
+    s_cmp_gt_i32 s[s_kitr], 0
+
+    s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end
+
+L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_body:
+    ; accumulate im
+
+    ; a buffer x
+    ;--- start move slice window
+    s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix]
+    s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix]
+    s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w]
+    s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi]
+    v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0]
+    v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1]
+    v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2]
+    v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3]
+    v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0]
+    v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1]
+    v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2]
+    v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3]
+    s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_acc_yx_x_end_1
+    s_mov_b32 s[s_move_slice_k_ix], 0
+    v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0]
+    v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1]
+    v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2]
+    v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3]
+igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_acc_yx_x_end_1:
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0]
+    v_cndmask_b32 v[v_in_flag+0], 0, 1
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, 1
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2]
+    v_cndmask_b32 v[v_in_flag+2], 0, 1
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3]
+    v_cndmask_b32 v[v_in_flag+3], 0, 1
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0]
+    v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2]
+    v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3]
+    v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3]
+    ;--- end move slice window
+
+    .v_clear_nc v_ay, 8
+    v_cmpx_le_u32 1, v[v_in_flag+0]
+    global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_in_flag+1]
+    global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+    .v_clear_nc v_ay+8, 8
+    v_cmpx_le_u32 1, v[v_in_flag+2]
+    global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_in_flag+3]
+    global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    s_waitcnt vmcnt(4) lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0
+    .fma_1x8_fp16 v_c+ 8, v_ax + 4, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_ax + 8, v_b + 0
+    .fma_1x8_fp16 v_c+24, v_ax +12, v_b + 0
+    .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8
+    .fma_1x8_fp16 v_c+ 8, v_ax + 5, v_b + 8
+    .fma_1x8_fp16 v_c+16, v_ax + 9, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_ax +13, v_b + 8
+
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16
+    .fma_1x8_fp16 v_c+ 8, v_ax + 6, v_b +16
+    .fma_1x8_fp16 v_c+16, v_ax +10, v_b +16
+    .fma_1x8_fp16 v_c+24, v_ax +14, v_b +16
+    .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24
+    .fma_1x8_fp16 v_c+ 8, v_ax + 7, v_b +24
+    .fma_1x8_fp16 v_c+16, v_ax +11, v_b +24
+    .fma_1x8_fp16 v_c+24, v_ax +15, v_b +24
+
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4
+
+    s_sub_i32 s[s_kitr], s[s_kitr], 8
+    v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os
+    s_cmp_gt_i32 s[s_kitr], 0
+    s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end_1
+
+    ; a buffer y
+    ;--- start move slice window
+    s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix]
+    s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix]
+    s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w]
+    s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi]
+    v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0]
+    v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1]
+    v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2]
+    v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3]
+    v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0]
+    v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1]
+    v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2]
+    v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3]
+    s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_acc_yx_x_end_2
+    s_mov_b32 s[s_move_slice_k_ix], 0
+    v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0]
+    v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1]
+    v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2]
+    v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3]
+igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_acc_yx_x_end_2:
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0]
+    v_cndmask_b32 v[v_in_flag+0], 0, 1
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, 1
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2]
+    v_cndmask_b32 v[v_in_flag+2], 0, 1
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3]
+    v_cndmask_b32 v[v_in_flag+3], 0, 1
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0]
+    v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2]
+    v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3]
+    v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3]
+    ;--- end move slice window
+
+    ;s_waitcnt vmcnt(0)
+    .v_clear_nc v_ax, 8
+    v_cmpx_le_u32 1, v[v_in_flag+0]
+    global_load_dwordx4 v[v_ax +0:v_ax +3], v[v_in_os+0], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_in_flag+1]
+    global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+    .v_clear_nc v_ax+8, 8
+    v_cmpx_le_u32 1, v[v_in_flag+2]
+    global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_in_flag+3]
+    global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    s_waitcnt vmcnt(4) lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0
+    .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0
+    .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0
+    .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8
+    .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8
+    .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8
+
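+    ; annotation sketch (names as in this file): the loop is double-buffered on the
+    ; input side -- the "a buffer x" pass issues global loads into v_ay while FMAs
+    ; drain v_ax, and this "a buffer y" pass does the reverse, so memory latency
+    ; overlaps the math. s_waitcnt vmcnt(4) above leaves exactly the four just-issued
+    ; loads in flight (i.e. the buffer being consumed has landed), and lgkmcnt(4)
+    ; admits the oldest four ds_read_b128, which carry v_b+0..15. in rough C terms the
+    ; move-slice-window block does:
+    ;   ix += 1;
+    ;   if (ix >= x) { ix = 0; iwi += dilation_w_x; os += in_diff_hi; ihi += dilation_h; }
+    ;   else         {         iwi += dilation_w;   os += in_diff_wi;                    }
+    ;   flag = (iwi < wi) && (ihi < hi);  // unsigned compares also reject negative ihi/iwi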
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16
+    .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16
+    .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16
+    .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16
+    .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24
+    .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24
+    .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24
+    .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24
+
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4
+
+    s_sub_i32 s[s_kitr], s[s_kitr], 8
+    v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os
+    s_cmp_gt_i32 s[s_kitr], 0
+    s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_body
+
+L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end:
+    s_waitcnt vmcnt(0)
+
+    v_mov_b32 v[v_ay + 0], v[v_ax + 0]
+    v_mov_b32 v[v_ay + 1], v[v_ax + 1]
+    v_mov_b32 v[v_ay + 2], v[v_ax + 2]
+    v_mov_b32 v[v_ay + 3], v[v_ax + 3]
+    v_mov_b32 v[v_ay + 4], v[v_ax + 4]
+    v_mov_b32 v[v_ay + 5], v[v_ax + 5]
+    v_mov_b32 v[v_ay + 6], v[v_ax + 6]
+    v_mov_b32 v[v_ay + 7], v[v_ax + 7]
+    v_mov_b32 v[v_ay + 8], v[v_ax + 8]
+    v_mov_b32 v[v_ay + 9], v[v_ax + 9]
+    v_mov_b32 v[v_ay +10], v[v_ax +10]
+    v_mov_b32 v[v_ay +11], v[v_ax +11]
+    v_mov_b32 v[v_ay +12], v[v_ax +12]
+    v_mov_b32 v[v_ay +13], v[v_ax +13]
+    v_mov_b32 v[v_ay +14], v[v_ax +14]
+    v_mov_b32 v[v_ay +15], v[v_ax +15]
+
+L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end_1:
+    s_waitcnt vmcnt(0)
+
+    s_sub_i32 s[s_batch_m], s[s_batch_m], 1
+    v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib]
+
+    s_cmp_gt_i32 s[s_batch_m], 0
+    s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end_not_load_next
+    ; --- start move slice for batch m
+    ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
+    ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
+    ; we will update v_in_os below, so use this as v_tmp
+    .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os
+    v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi]
+    .v_clear_nc v_ax, 4
+    v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib]
+    v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h]
+    v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi]
+    .v_clear_nc v_ax+4, 4
+    v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w]
+
+    .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1
+
+    v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi]
+    v_cndmask_b32 v[v_in_flag], 0, 1
+    v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi]
+    v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag]
+    v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os]
+
+    v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1]
+    .v_clear_nc v_ax+8, 4
+    v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h]
+    v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1]
+    .v_clear_nc v_ax+12, 4
+    v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w]
+
+    v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1]
+
+    v_cmpx_le_u32 1, v[v_in_flag]
+    global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, 1
+    v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1]
+    v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1]
+
+    v_cmpx_le_u32 1, v[v_in_flag+1]
+    global_load_dwordx4 v[v_ax+ 4:v_ax+7], v[v_in_os+1], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2
+    v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2]
+    v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2]
+    v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h]
+    v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2]
+    v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w]
+
+    .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3
+    v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3]
+    v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h]
+    v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3]
+    v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w]
+
+    v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2]
+    v_cndmask_b32 v[v_in_flag+2], 0, 1
+    v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2]
+    v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2]
+    v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2]
+
+    v_cmpx_le_u32 1, v[v_in_flag+2]
+    global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3]
+    v_cndmask_b32 v[v_in_flag+3], 0, 1
+    v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3]
+    v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3]
+    v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3]
+
+    v_cmpx_le_u32 1, v[v_in_flag+3]
+    global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1]
+    s_mov_b64 exec, -1
+
+    s_mov_b32 s[s_move_slice_k_ix], 0
+
+L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end_not_load_next:
+    ; --- end move slice for batch m
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0
+    .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0
+    .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0
+    .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8
+    .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8
+    .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8
+
+    s_waitcnt lgkmcnt(0)
+    .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16
+    .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16
+    .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16
+    .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16
+    .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24
+    .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24
+    .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24
+    .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24
+
+
+    v_mov_b32 v[v_sld_b_os], 0 ; reset to start
+    v_cvt_f16_f32 v[v_c + 0], v[v_c + 0]
+    v_cvt_f16_f32 v[v_c + 1], v[v_c + 1]
+    v_cvt_f16_f32 v[v_c + 2], v[v_c + 2]
+    v_cvt_f16_f32 v[v_c + 3], v[v_c + 3]
+    v_cvt_f16_f32 v[v_c + 4], v[v_c + 4]
+    v_cvt_f16_f32 v[v_c + 5], v[v_c + 5]
+    v_cvt_f16_f32 v[v_c + 6], v[v_c + 6]
+    v_cvt_f16_f32 v[v_c + 7], v[v_c + 7]
+
+    v_cvt_f16_f32 v[v_c + 8], v[v_c + 8]
+    v_cvt_f16_f32 v[v_c + 9], v[v_c + 9]
+    v_cvt_f16_f32 v[v_c +10], v[v_c +10]
+    v_cvt_f16_f32 v[v_c +11], v[v_c +11]
+    v_cvt_f16_f32 v[v_c +12], v[v_c +12]
+    v_cvt_f16_f32 v[v_c +13], v[v_c +13]
+    v_cvt_f16_f32 v[v_c +14], v[v_c +14]
+    v_cvt_f16_f32 v[v_c +15], v[v_c +15]
+
+
+    v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1]
+    v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3]
+    v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5]
+    v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7]
+
+    v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9]
+    v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11]
+    v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13]
+    v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15]
+
+    v_cmpx_le_u32 1, v[v_out_flag]
+    global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1]
+    s_mov_b64 exec, -1
+
+    v_cmpx_le_u32 1, v[v_out_flag+1]
+    global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1]
+    s_mov_b64 exec, -1
+
+    v_cvt_f16_f32 v[v_c +16], v[v_c +16]
+    v_cvt_f16_f32 v[v_c +17], v[v_c +17]
+    v_cvt_f16_f32 v[v_c +18], v[v_c +18]
+    v_cvt_f16_f32 v[v_c +19], v[v_c +19]
+    v_cvt_f16_f32 v[v_c +20], v[v_c +20]
+    v_cvt_f16_f32 v[v_c +21], v[v_c +21]
+    v_cvt_f16_f32 v[v_c +22], v[v_c +22]
+    v_cvt_f16_f32 v[v_c +23], v[v_c +23]
+
+    v_cvt_f16_f32 v[v_c +24], v[v_c +24]
+    v_cvt_f16_f32 v[v_c +25], v[v_c +25]
+    v_cvt_f16_f32 v[v_c +26], v[v_c +26]
+    v_cvt_f16_f32 v[v_c +27], v[v_c +27]
+    v_cvt_f16_f32 v[v_c +28], v[v_c +28]
+    v_cvt_f16_f32 v[v_c +29], v[v_c +29]
+    v_cvt_f16_f32 v[v_c +30], v[v_c +30]
+    v_cvt_f16_f32 v[v_c +31], v[v_c +31]
+
+
+    v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17]
+    v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19]
+    v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21]
+    v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23]
+
+    v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25]
+    v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27]
+    v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29]
+    v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31]
+
+    v_cmpx_le_u32 1, v[v_out_flag+2]
+    global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1]
+    s_mov_b64 exec, -1
+
+    v_cmpx_le_u32 1, v[v_out_flag+3]
+    global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1]
+    s_mov_b64 exec, -1
+
+    s_cmp_le_i32 s[s_batch_m], 0
+
+    s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_end
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4
+
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4
+
+    .v_clear_nc v_c, 32
+    v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os
+
+    v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os]
+    s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8
+    v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1]
+    v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2]
+    v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3]
+
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os]
+    v_cndmask_b32 v[v_out_flag], 0, 1
+    s_cmp_gt_i32 s[s_kitr], 0
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1]
+    v_cndmask_b32 v[v_out_flag+1], 0, 1
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2]
+    v_cndmask_b32 v[v_out_flag+2], 0, 1
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3]
+    v_cndmask_b32 v[v_out_flag+3], 0, 1
+
+    s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end
+    s_branch L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_body
+L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_end:
+    s_endpgm
+
+; LDS: 1 * 4 * 4 * 128
+; r1 4dword 4 threads
+.rodata
+.p2align 6
+.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_256x8x8_r2
+    .amdhsa_group_segment_fixed_size 2048
+    .amdhsa_user_sgpr_kernarg_segment_ptr 1
+    .amdhsa_system_sgpr_workgroup_id_x 1
+    .amdhsa_system_sgpr_workgroup_id_y 1
+    .amdhsa_system_sgpr_workgroup_id_z 1
+    .amdhsa_system_vgpr_workitem_id 0
+    .amdhsa_next_free_vgpr 124
+    .amdhsa_next_free_sgpr 58
+    .amdhsa_ieee_mode 0
+    .amdhsa_dx10_clamp 0
+    .amdhsa_wavefront_size32 1
+    .amdhsa_workgroup_processor_mode 0
+.end_amdhsa_kernel
diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x016.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x016.asm
new file mode 100644
index 00000000..2725dc4a
--- /dev/null
+++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x016.asm
@@ -0,0 +1,734 @@
+.set k_p_in, 0
+.set k_p_wei, 8
+.set k_p_out, 16
+.set k_hi, 24
+.set k_wi, 28
+.set k_n, 32
+.set k_k, 36
+.set k_c, 40
+.set k_ho, 44
+.set k_wo, 48
+.set k_stride_h, 52
+.set k_stride_w, 56
+.set k_dilation_h, 60
+.set k_dilation_w, 64
+.set k_pad_h, 68
+.set k_pad_w, 72
+.set k_y, 76
+.set k_x, 80
+.set k_group, 84
+.set k_batch_m, 88
+.set k_stride_m, 92
+.set k_magic_0, 96
+.set k_magic_1, 100
+.set k_magic_2, 104
+.set k_shift_pack_0, 108
+
+.set s_block_ib, 2 ; bx, ho*wo
+.set s_ka, 0
+.set s_block_ig, 3 ; by, group
+.set s_block_in, 4 ; bz, batch
+.set s_p_in, 6
+.set s_p_wei, 8
+.set s_p_out, 10
+.set s_hi, 16
+.set s_wi, 17
+.set s_n, 18
+.set s_k, 19
+.set s_c, 20
+.set s_ho, 21
+.set s_wo, 22
+.set s_stride_h, 23
+.set s_stride_w, 24
+.set s_dilation_h, 25
+.set s_dilation_w, 26
+.set s_pad_h, 27
+.set s_pad_w, 28
+.set s_y, 29
+.set s_x, 30
+.set s_group, 31
+.set s_batch_m, 32
+.set s_stride_m, 33
+.set s_magic_0, 34
+.set s_magic_1, 35
+.set s_magic_2, 36
+.set s_shift_pack_0, 37
+.set s_shift_m0, 38
+.set s_shift_m1, s_shift_pack_0
+.set s_in_stride_wi, 12
+.set s_in_stride_n, 13
+.set s_wei_stride_k, 14
+.set s_out_stride_wo, 15
+.set s_out_stride_n, 39
+.set s_in_diff_hi, 40
+.set s_in_diff_wi, 41
+.set s_dilation_w_x, 42
+.set s_move_slice_k_ix, 43
+
+.set s_kitr, 1
+.set s_wei_offset, 44
+.set s_out_stride, s_wei_offset
+.set s_sld_b_stride, 45
+.set s_br, 46
+.set s_ib_stride, 47
+
+.set s_tmp, 48
+.set s_end, 54
+
+; magic_0: x
+; magic_1: wo
+
+.set v_c, 0
+.set v_c_buf, v_c
+.set v_sld_b_os, 32
+.set v_a, 33
+.set v_ib, 49
+.set v_b, 50
+.set v_gld_a, 82
+.set v_gld_b, v_b
+.set v_wei_iy_list, v_b+12
+.set v_wei_ix_list, v_b+15
+.set v_wei_flag, v_b+18
+.set v_wei_os, v_b+21
+.set v_tmp, v_b+24
+.set v_wei_ik, v_a
+.set v_wei_ic, v_a+1
+.set v_wei_ie, v_a+2
+.set v_wei_flag_ik, v_a+3
+.set v_sst_b_os, v_a+4
+.set v_in_os, 98
+.set v_in_ihi, 100
+.set v_in_iwi, 102
+.set v_in_flag, 104
+.set v_out_os, 106
+.set v_out_flag, 108
+.set v_tid, 100
+.set v_end, 112
+
+; short wide igemv
+.text
+.globl igemm_fwd_btm_nhwc_fp16_256x16x16_r3
+.p2align 8
+
+.type igemm_fwd_btm_nhwc_fp16_256x16x16_r3,@function
+igemm_fwd_btm_nhwc_fp16_256x16x16_r3:
+    s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in
+    s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei
+    s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
+    s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m
+    s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2
+    v_mov_b32 v[v_tid], v0
+    s_mov_b32 s[s_ib_stride], 128
+
+    ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 4 for yx, 2 for c
+    v_lshrrev_b32 v[v_wei_ik], 3, v0
+    s_mov_b32 s[s_tmp], 17*4 * 4 ; 17dword per row, 4 row
+    v_and_b32 v[v_tmp+5], 7, v0
+    s_lshl_b32 s[s_block_ig], s[s_block_ig], 1
+    v_and_b32 v[v_wei_ic], 1, v0
+    s_lshl_b32 s[s_block_in], s[s_block_in], 1
+    v_lshrrev_b32 v[v_tmp+4], 1, v0
+    s_lshl_b32 s[s_block_ib], s[s_block_ib], 8 ; 256 half
+    v_mov_b32 v[v_ib], v0
+    v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_tmp+5]
+    v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to 17
+    v_mov_b32 v[v_sld_b_os], 0 ; load
+    v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword
+    v_and_b32 v[v_wei_ie], 3, v[v_tmp+4] ; yx
+    v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad
+
+    s_waitcnt lgkmcnt(0)
+
+    s_mul_i32 s[s_tmp], s[s_x], s[s_c]
+    s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8
+    v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic]
+    s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y]
+    s_lshl_b32 s[s_wei_offset], s[s_c], 2+1 ; 4x s_c, half
+    s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k]
+    v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1]
+    s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5]
+    v_cmp_gt_u32 s[s_k], v[v_wei_ik]
+    s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2]
+    v_cndmask_b32 v[v_wei_flag_ik], 0, 1
+    s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0
+    v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os]
+
+    ; divide x
+    .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp
+    v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0]
+    v_add_nc_u32 v[v_wei_ie], 4, v[v_wei_ie]
+    v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0]
+    v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik]
+    v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0]
+    v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0]
+
+    .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp
+    v_add_nc_u32 v[v_wei_os+2], s[s_wei_offset], v[v_wei_os+1]
+    v_add_nc_u32 v[v_wei_ie], 4, v[v_wei_ie]
+    v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1]
+    v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik]
+    v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1]
+    v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1]
+
+    .mdiv_u32_rem_vs v_wei_ix_list+2,v_wei_iy_list+2,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp
+    v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+2]
+    v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag_ik]
+    v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+2]
+    v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag+2]
+
+    v_cmpx_le_u32 1, v[v_wei_flag+0]
+    global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1]
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_wei_flag+1]
+    global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1]
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_wei_flag+2]
+    global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2], s[s_p_wei:s_p_wei+1]
+    s_mov_b64 exec, -1
+
+    s_mov_b32 s[s_tmp+5], 32*17*4 ; stride for wei sst offset. 8 thread for k, each thread store 4 c, hence 8*4=32
+                                  ; 17 dword per row
+
+    ; calculate in offset
+    s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group]
+    s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8
+    s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi]
+    s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c]
+    s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2]
+    v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib]
+    s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n]
+    s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1
+    s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3]
+    v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0]
+
+    .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp
+    s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0]
+    v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib]
+    s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0
+    v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi]
+    .v_clear_nc v_gld_a, 4
+    v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h]
+    v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi]
+    .v_clear_nc v_gld_a+4, 4
+    v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w]
+    v_add_nc_u32 v[v_sst_b_os+2], s[s_tmp+5], v[v_sst_b_os+1]
+
+    .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp
+
+    v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi]
+    v_cndmask_b32 v[v_in_flag], 0, 1
+    v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi]
+    v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag]
+    v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp]
+
+    v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1]
+    .v_clear_nc v_gld_a+8, 4
+    v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h]
+    v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1]
+    .v_clear_nc v_gld_a+12, 4
+    v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w]
+
+    v_cmpx_le_u32 1, v[v_in_flag]
+    global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1]
+    global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16
+    s_mov_b64 exec, -1
+
+    v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, 1
+    v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1]
+    v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp]
+
+    v_cmpx_le_u32 1, v[v_in_flag+1]
+    global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1], s[s_p_in:s_p_in+1]
+    global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16
+    s_mov_b64 exec, -1
+
+
+    s_mul_i32 s[s_br], s[s_wo], s[s_ho]
+
+    s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group]
+    s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi]
+    s_mov_b32 s[s_move_slice_k_ix], 0
+
+    s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo]
+    s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k]
+    s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n]
+    s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4]
+    s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1]
+    s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0
+
+    ; calculate diffs, for y, x
+    s_sub_i32 s[s_tmp+3], s[s_x], 1
+    s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3]
+    s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi]
+    s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h]
+    s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp]
+    s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3]
+    s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1
+
+
+    v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib]
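+    ; output addressing sketch (names as in this file): the two per-thread output
+    ; rows sit s_ib_stride (=128) gemm_m positions apart; v_out_os below is ib * k
+    ; elements, shifted left by 1 for fp16 bytes, and v_out_flag drops the final store
+    ; once the offset runs past s_out_stride_n, the byte extent of one image's output.
+    ; roughly: out_os = ib * k * 2; out_flag = (out_os < out_stride_n);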
+    s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo]
+
+    s_lshl_b32 s[s_out_stride], s[s_out_stride], 1
+    s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1
+
+    ; output offset
+    v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib]
+    v_lshlrev_b32 v[v_out_os], 1, v[v_out_os]
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os]
+    v_cndmask_b32 v[v_out_flag], 0, 1
+
+    v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5]
+    v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1]
+    v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1]
+    v_cndmask_b32 v[v_out_flag+1], 0, 1
+
+    s_mov_b32 s[s_sld_b_stride], 17*8*4
+
+    s_waitcnt vmcnt(2)
+
+    v_cmpx_le_u32 1, v[v_wei_flag+0]
+    ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:17*0 offset1:17*1
+    ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:17*2 offset1:17*3
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_wei_flag+1]
+    ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:17*0 offset1:17*1
+    ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:17*2 offset1:17*3
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_wei_flag+2]
+    ds_write2_b32 v[v_sst_b_os+2], v[v_gld_b+8], v[v_gld_b+9], offset0:17*0 offset1:17*1
+    ds_write2_b32 v[v_sst_b_os+2], v[v_gld_b+10], v[v_gld_b+11], offset0:17*2 offset1:17*3
+    s_mov_b64 exec, -1
+
+    .v_clear_nc v_c, 32
+
+    s_waitcnt lgkmcnt(0)
+    s_barrier
+
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4
+    s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4
+    s_cmp_gt_i32 s[s_kitr], 0
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4
+    s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end
+
+L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_body:
+    ; accumulate im
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4
+
+    ;--- start move slice window
+    s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix]
+    s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix]
+    s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w]
+    s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi]
+    v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0]
+    v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1]
+    v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0]
+    v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1]
+    s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_x_end_1
+    s_mov_b32 s[s_move_slice_k_ix], 0
+    v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0]
+    v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1]
+igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_x_end_1:
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0]
+    v_cndmask_b32 v[v_in_flag+0], 0, 1
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, 1
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0]
+    v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1]
+    ;--- end move slice window
+
+    s_waitcnt vmcnt(0)
+    v_mov_b32 v[v_a + 0], v[v_gld_a + 0]
+    v_mov_b32 v[v_a + 1], v[v_gld_a + 1]
+    v_mov_b32 v[v_a + 2], v[v_gld_a + 2]
+    v_mov_b32 v[v_a + 3], v[v_gld_a + 3]
+    v_mov_b32 v[v_a + 4], v[v_gld_a + 4]
+    v_mov_b32 v[v_a + 5], v[v_gld_a + 5]
+    v_mov_b32 v[v_a + 6], v[v_gld_a + 6]
+    v_mov_b32 v[v_a + 7], v[v_gld_a + 7]
+    v_mov_b32 v[v_a + 8], v[v_gld_a + 8]
+    v_mov_b32 v[v_a + 9], v[v_gld_a + 9]
+    v_mov_b32 v[v_a +10], v[v_gld_a +10]
+    v_mov_b32 v[v_a +11], v[v_gld_a +11]
+    v_mov_b32 v[v_a +12], v[v_gld_a +12]
+    v_mov_b32 v[v_a +13], v[v_gld_a +13]
+    v_mov_b32 v[v_a +14], v[v_gld_a +14]
+    v_mov_b32 v[v_a +15], v[v_gld_a +15]
+    .v_clear_nc v_gld_a, 16
+    v_cmpx_le_u32 1, v[v_in_flag+0]
+    global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1]
+    global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16
+    s_mov_b64 exec, -1
+    v_cmpx_le_u32 1, v[v_in_flag+1]
+    global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1], s[s_p_in:s_p_in+1]
+    global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16
+    s_mov_b64 exec, -1
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_a + 8, v_b + 0
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_a + 8, v_b + 8
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16
+    .fma_1x8_fp16 v_c+16, v_a + 9, v_b +16
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24
+    .fma_1x8_fp16 v_c+24, v_a + 9, v_b +24
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_a +10, v_b + 0
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_a +10, v_b + 8
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16
+    .fma_1x8_fp16 v_c+16, v_a +11, v_b +16
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24
+    .fma_1x8_fp16 v_c+24, v_a +11, v_b +24
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_a +12, v_b + 0
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_a +12, v_b + 8
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16
+    .fma_1x8_fp16 v_c+16, v_a +13, v_b +16
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24
+    .fma_1x8_fp16 v_c+24, v_a +13, v_b +24
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_a +14, v_b + 0
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4
+    s_waitcnt lgkmcnt(4)
+    v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os
+    .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_a +14, v_b + 8
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16
+    .fma_1x8_fp16 v_c+16, v_a +15, v_b +16
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24
+    .fma_1x8_fp16 v_c+24, v_a +15, v_b +24
+
+    s_sub_i32 s[s_kitr], s[s_kitr], 16
+    s_cmp_gt_i32 s[s_kitr], 0
+    s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_body
+
+L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end:
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4
+    s_waitcnt vmcnt(0)
+
+    v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib]
+    s_sub_i32 s[s_batch_m], s[s_batch_m], 1
+    v_mov_b32 v[v_a + 0], v[v_gld_a + 0]
+    v_mov_b32 v[v_a + 1], v[v_gld_a + 1]
+    v_mov_b32 v[v_a + 2], v[v_gld_a + 2]
+    v_mov_b32 v[v_a + 3], v[v_gld_a + 3]
+    v_mov_b32 v[v_a + 4], v[v_gld_a + 4]
+    v_mov_b32 v[v_a + 5], v[v_gld_a + 5]
+    v_mov_b32 v[v_a + 6], v[v_gld_a + 6]
+    v_mov_b32 v[v_a + 7], v[v_gld_a + 7]
+    v_mov_b32 v[v_a + 8], v[v_gld_a + 8]
+    v_mov_b32 v[v_a + 9], v[v_gld_a + 9]
+    v_mov_b32 v[v_a +10], v[v_gld_a +10]
+    v_mov_b32 v[v_a +11], v[v_gld_a +11]
+    v_mov_b32 v[v_a +12], v[v_gld_a +12]
+    v_mov_b32 v[v_a +13], v[v_gld_a +13]
+    v_mov_b32 v[v_a +14], v[v_gld_a +14]
+    v_mov_b32 v[v_a +15], v[v_gld_a +15]
+
+    s_cmp_gt_i32 s[s_batch_m], 0
+    s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next
+    ; --- start move slice for batch m
+    ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
+    ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
+    ; we will update v_in_os below, so use this as v_tmp
+    .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os
+    v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi]
+    .v_clear_nc v_gld_a, 4
+    v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib]
+    v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h]
+    v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi]
+    .v_clear_nc v_gld_a+4, 4
+    v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w]
+
+    .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1
+
+    v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi]
+    v_cndmask_b32 v[v_in_flag], 0, 1
+    v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi]
+    v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag]
+    v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os]
+
+    v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1]
+    .v_clear_nc v_gld_a+8, 4
+    v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h]
+    v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1]
+    .v_clear_nc v_gld_a+12, 4
+    v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w]
+
+    v_cmpx_le_u32 1, v[v_in_flag]
+    global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1]
+    global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16
+    s_mov_b64 exec, -1
+
+    v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1]
+    v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, 1
+    v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1]
+    v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1]
+    v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1]
+    v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1]
+
+    s_mov_b32 s[s_move_slice_k_ix], 0
+
+    v_cmpx_le_u32 1, v[v_in_flag+1]
+    global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1], s[s_p_in:s_p_in+1]
+    global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16
+    s_mov_b64 exec, -1
+L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next:
+    ; --- end move slice for batch m
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_a + 8, v_b + 0
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_a + 8, v_b + 8
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16
+    .fma_1x8_fp16 v_c+16, v_a + 9, v_b +16
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24
+    .fma_1x8_fp16 v_c+24, v_a + 9, v_b +24
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_a +10, v_b + 0
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_a +10, v_b + 8
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16
+    .fma_1x8_fp16 v_c+16, v_a +11, v_b +16
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24
+    .fma_1x8_fp16 v_c+24, v_a +11, v_b +24
+    ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4
+    ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0
+    .fma_1x8_fp16 v_c+16, v_a +12, v_b + 0
+    ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4
+    ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8
+    .fma_1x8_fp16 v_c+24, v_a +12, v_b + 8
+    ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4
+    ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4
+
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16
+    .fma_1x8_fp16 v_c+16, v_a +13, v_b +16
+    ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4
+    ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4
+    s_waitcnt lgkmcnt(4)
+    .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24
+    .fma_1x8_fp16 v_c+24, v_a +13, v_b +24
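+    ; tail pass: the same ds_read/FMA interleave as the loop body -- ds_read_b128 are
+    ; issued in pairs, and each s_waitcnt lgkmcnt(4) leaves four reads in flight while
+    ; guaranteeing the oldest pair has landed just before its v_b registers feed the
+    ; next .fma_1x8_fp16, so LDS latency stays hidden without an s_barrier.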
+ ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0 + .fma_1x8_fp16 v_c+16, v_a +14, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 + .fma_1x8_fp16 v_c+24, v_a +14, v_b + 8 + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + s_waitcnt lgkmcnt(2) + .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16 + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + .fma_1x8_fp16 v_c+16, v_a +15, v_b +16 + v_cvt_f16_f32 v[v_c +16], v[v_c +16] + v_cvt_f16_f32 v[v_c +17], v[v_c +17] + v_cvt_f16_f32 v[v_c +18], v[v_c +18] + v_cvt_f16_f32 v[v_c +19], v[v_c +19] + v_cvt_f16_f32 v[v_c +20], v[v_c +20] + v_cvt_f16_f32 v[v_c +21], v[v_c +21] + v_cvt_f16_f32 v[v_c +22], v[v_c +22] + v_cvt_f16_f32 v[v_c +23], v[v_c +23] + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + .fma_1x8_fp16 v_c+24, v_a +15, v_b +24 + v_cvt_f16_f32 v[v_c +24], v[v_c +24] + v_cvt_f16_f32 v[v_c +25], v[v_c +25] + v_cvt_f16_f32 v[v_c +26], v[v_c +26] + v_cvt_f16_f32 v[v_c +27], v[v_c +27] + v_cvt_f16_f32 v[v_c +28], v[v_c +28] + v_cvt_f16_f32 v[v_c +29], v[v_c +29] + v_cvt_f16_f32 v[v_c +30], v[v_c +30] + v_cvt_f16_f32 v[v_c +31], v[v_c +31] + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + global_store_dwordx4 v[v_out_os], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 + s_mov_b64 exec, -1 + + v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] + v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19] + v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21] + v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23] + + v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25] + v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27] + v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29] + v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1], offset:16 + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + ds_read_b128 
v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + + .v_clear_nc v_c, 32 + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_body +L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_end: + s_endpgm + +; LDS: (16+1) * (3 * 4) * 16 * 4 = 13056 +; k pad r3 e:4 c dword +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_256x16x16_r3 + .amdhsa_group_segment_fixed_size 13056 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 112 + .amdhsa_next_free_sgpr 54 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_384x004.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_384x004.asm new file mode 100644 index 00000000..09afaa7b --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_384x004.asm @@ -0,0 +1,779 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 12 +.set v_ax, 13 +.set v_ay, 37 +.set v_ib, 61 +.set v_b, 62 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 94 +.set v_in_ihi, 97 +.set v_in_iwi, 100 +.set v_in_flag, 
103 +.set v_out_os, 106 +.set v_out_flag, 109 +.set v_tid, 112 +.set v_end, 114 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_384x4x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_384x4x16_r1,@function +igemm_fwd_btm_nhwc_fp16_384x4x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 4x32, 4 for k, 32 for yxc, 16 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 5, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 31, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 15, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + s_mov_b32 s[s_block_ib], 384 + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 2 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_mul_i32 s[s_block_ib], s[s_tmp+5], s[s_block_ib] ; 384 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 2 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ; s_mov_b32 s[s_tmp+5], 128*k_n_dword*4 ; stride for wei sst offset. 
32 thread for gemm_k, each thread store 4 c, hence 32*4=128 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], 
s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(6) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 12 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], 
v[v_in_os+2] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] +igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay+ 0, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ay+ 8, 8 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ay+16, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+16:v_ay+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+20:v_ay+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + + s_waitcnt vmcnt(6) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ax + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ax +16, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ax + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ax +17, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ax +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ax +18, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ax +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ax +19, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ax +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ax +20, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ax + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ax +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ax +21, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ax + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ax +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ax +22, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ax + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ax +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ax +23, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, 
s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] +igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + ;--- end move slice window + + .v_clear_nc v_ax+ 0, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ax+ 8, 8 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ax+16, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(6) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ay +16, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ay +17, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ay +18, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ay +19, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ay +20, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ay +21, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ay +22, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ay +23, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], 
v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + + v_mov_b32 v[v_ay +16], v[v_ax +16] + v_mov_b32 v[v_ay +17], v[v_ax +17] + v_mov_b32 v[v_ay +18], v[v_ax +18] + v_mov_b32 v[v_ay +19], v[v_ax +19] + v_mov_b32 v[v_ay +20], v[v_ax +20] + v_mov_b32 v[v_ay +21], v[v_ax +21] + v_mov_b32 v[v_ay +22], v[v_ax +22] + v_mov_b32 v[v_ay +23], v[v_ax +23] + +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] 
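+ ; predication note: the no-sdst form of v_cmpx_le_u32 used here writes the
+ ; per-lane result straight into exec, so the two global_load_dwordx4 below
+ ; issue only for lanes with v_in_flag+1 == 1; masked-off lanes keep the zeros
+ ; from the earlier .v_clear_nc, and s_mov_b64 exec, -1 restores the full mask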
+ global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + ; v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ay +16, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ay +17, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ay +18, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ay +19, v_b +12 + + s_waitcnt lgkmcnt(0) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ay +20, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ay +21, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ay +22, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ay +23, v_b +28 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+2:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+2] + 
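; store note: v_cvt_f16_f32 narrows each f32 accumulator and v_pack_b32_f16
+ ; packs two halves per dword, so the 4 output channels of one pixel go out
+ ; as a single predicated dwordx2 store +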
global_store_dwordx2 v[v_out_os+2], v[v_c_buf+4:v_c_buf+5], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + .v_clear_nc v_c, 12 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_384x4x16_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 114 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x004.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x004.asm new file mode 100644 index 00000000..1b2cc4b7 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x004.asm @@ -0,0 +1,886 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 
+.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 16 +.set v_ax, 17 +.set v_ay, 49 +.set v_ib, 81 +.set v_b, 82 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 114 +.set v_in_ihi, 118 +.set v_in_iwi, 122 +.set v_in_flag, 126 +.set v_out_os, 130 +.set v_out_flag, 134 +.set v_tid, 138 +.set v_end, 140 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_512x4x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_512x4x16_r1,@function +igemm_fwd_btm_nhwc_fp16_512x4x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 4x32, 4 for k, 32 for yxc, 16 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 5, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 31, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 15, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 2 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 2 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + ;v_add_nc_u32 v[v_wei_ie], 8, 
v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ; s_mov_b32 s[s_tmp+5], 128*k_n_dword*4 ; stride for wei sst offset. 32 thread for gemm_k, each thread store 4 c, hence 32*4=128 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+3], 
v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(8) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], 
offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ay+16, 16 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+16:v_ay+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+20:v_ay+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+24:v_ay+27], 
v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+28:v_ay+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ax + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ax +16, v_b + 0 + .fma_1x4_fp16 v_c+12, v_ax +24, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ax + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ax +17, v_b + 4 + .fma_1x4_fp16 v_c+12, v_ax +25, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ax +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ax +18, v_b + 8 + .fma_1x4_fp16 v_c+12, v_ax +26, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ax +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ax +19, v_b +12 + .fma_1x4_fp16 v_c+12, v_ax +27, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ax +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ax +20, v_b +16 + .fma_1x4_fp16 v_c+12, v_ax +28, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ax + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ax +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ax +21, v_b +20 + .fma_1x4_fp16 v_c+12, v_ax +29, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ax + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ax +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ax +22, v_b +24 + .fma_1x4_fp16 v_c+12, v_ax +30, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ax + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ax +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ax +23, v_b +28 + .fma_1x4_fp16 v_c+12, v_ax +31, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + 
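; flag note: these unsigned compares build v_in_flag+n = (iwi+n < wi) && (ihi+n < hi);
+ ; negative pad coordinates wrap to large u32 values, so one unsigned compare
+ ; per axis rejects them as well +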
v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ax+16, 16 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ay +16, v_b + 0 + .fma_1x4_fp16 v_c+12, v_ay +24, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ay +17, v_b + 4 + .fma_1x4_fp16 v_c+12, v_ay +25, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ay +18, v_b + 8 + .fma_1x4_fp16 v_c+12, v_ay +26, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ay +19, v_b +12 + .fma_1x4_fp16 v_c+12, v_ay +27, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ay +20, v_b +16 + .fma_1x4_fp16 v_c+12, v_ay +28, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ay +21, v_b +20 + .fma_1x4_fp16 v_c+12, v_ay +29, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ay +22, v_b +24 + .fma_1x4_fp16 v_c+12, v_ay +30, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ay +23, v_b +28 + .fma_1x4_fp16 v_c+12, v_ay +31, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], 
v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + + v_mov_b32 v[v_ay +16], v[v_ax +16] + v_mov_b32 v[v_ay +17], v[v_ax +17] + v_mov_b32 v[v_ay +18], v[v_ax +18] + v_mov_b32 v[v_ay +19], v[v_ax +19] + v_mov_b32 v[v_ay +20], v[v_ax +20] + v_mov_b32 v[v_ay +21], v[v_ax +21] + v_mov_b32 v[v_ay +22], v[v_ax +22] + v_mov_b32 v[v_ay +23], v[v_ax +23] + v_mov_b32 v[v_ay +24], v[v_ax +24] + v_mov_b32 v[v_ay +25], v[v_ax +25] + v_mov_b32 v[v_ay +26], v[v_ax +26] + v_mov_b32 v[v_ay +27], v[v_ax +27] + v_mov_b32 v[v_ay +28], v[v_ax +28] + v_mov_b32 v[v_ay +29], v[v_ax +29] + v_mov_b32 v[v_ay +30], v[v_ax +30] + v_mov_b32 v[v_ay +31], v[v_ax +31] +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + 
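; offset note: v_in_os+1 = (v_in_ihi+1 * s_wi + v_in_iwi+1) * s_in_stride_wi,
+ ; with s_in_stride_wi = c * group shifted left by 1 earlier to fold in the
+ ; 2-byte fp16 element size; the ihi/iwi pair itself comes from
+ ; .mdiv_u32_rem_vs, a magic-number division by s_wo using the precomputed
+ ; s_magic_1/s_shift_m1 pair +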
v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ay +16, v_b + 0 + .fma_1x4_fp16 v_c+12, v_ay +24, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ay +17, v_b + 4 + .fma_1x4_fp16 v_c+12, v_ay +25, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ay +18, v_b + 8 + .fma_1x4_fp16 v_c+12, v_ay +26, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ay +19, v_b +12 + .fma_1x4_fp16 v_c+12, v_ay +27, v_b +12 + + s_waitcnt lgkmcnt(0) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ay +20, v_b +16 + .fma_1x4_fp16 v_c+12, v_ay +28, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ay +21, v_b +20 + .fma_1x4_fp16 v_c+12, v_ay +29, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ay +22, v_b +24 + .fma_1x4_fp16 v_c+12, v_ay +30, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + 
.fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ay +23, v_b +28 + .fma_1x4_fp16 v_c+12, v_ay +31, v_b +28 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+2:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx2 v[v_out_os+2], v[v_c_buf+4:v_c_buf+5], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx2 v[v_out_os+3], v[v_c_buf+6:v_c_buf+7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + .v_clear_nc v_c, 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_512x4x16_r1 + .amdhsa_group_segment_fixed_size 2048 + 
.amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 140 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x008.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x008.asm new file mode 100644 index 00000000..ce44c39c --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x008.asm @@ -0,0 +1,1756 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 32 +.set v_ax, 33 +.set v_ay, 65 +.set v_ib, 97 +.set v_b, 98 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 162 +.set v_in_ihi, 166 +.set v_in_iwi, 170 +.set v_in_flag, 174 +.set v_out_os, 178 +.set v_out_flag, 182 +.set v_tid, 186 +.set v_end, 188 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_512x8x16_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_512x8x16_r2,@function +igemm_fwd_btm_nhwc_fp16_512x8x16_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 8 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 15, v0 + 
s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 7, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 
16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 
v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(8) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 32 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], 
v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end + +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + 
global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ay+16, 16 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+16:v_ay+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+20:v_ay+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+24:v_ay+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+28:v_ay+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ax + 8, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ax +16, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ax +24, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ax + 9, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ax +17, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ax +25, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ax +10, v_b +16 + .fma_1x8_fp16 v_c+16, v_ax +18, v_b +16 + .fma_1x8_fp16 v_c+24, v_ax +26, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ax +11, v_b +24 + .fma_1x8_fp16 v_c+16, v_ax +19, v_b +24 + .fma_1x8_fp16 v_c+24, v_ax +27, v_b +24 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_waitcnt lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ax + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ax +12, v_b +32 + .fma_1x8_fp16 v_c+16, v_ax +20, v_b +32 + .fma_1x8_fp16 v_c+24, v_ax +28, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ax + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ax +13, v_b +40 + .fma_1x8_fp16 v_c+16, v_ax +21, v_b +40 + .fma_1x8_fp16 v_c+24, v_ax +29, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ax + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ax +14, v_b +48 + .fma_1x8_fp16 v_c+16, v_ax +22, v_b +48 + .fma_1x8_fp16 v_c+24, v_ax +30, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ax + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ax +15, v_b +56 + .fma_1x8_fp16 v_c+16, v_ax +23, v_b +56 + .fma_1x8_fp16 v_c+24, v_ax +31, v_b +56 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, 
s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ax+16, 16 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay +16, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +24, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay +17, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +25, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +18, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +26, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +19, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +27, v_b +24 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], 
offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_waitcnt lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ay + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ay +12, v_b +32 + .fma_1x8_fp16 v_c+16, v_ay +20, v_b +32 + .fma_1x8_fp16 v_c+24, v_ay +28, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ay + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ay +13, v_b +40 + .fma_1x8_fp16 v_c+16, v_ay +21, v_b +40 + .fma_1x8_fp16 v_c+24, v_ay +29, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ay + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ay +14, v_b +48 + .fma_1x8_fp16 v_c+16, v_ay +22, v_b +48 + .fma_1x8_fp16 v_c+24, v_ay +30, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ay + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ay +15, v_b +56 + .fma_1x8_fp16 v_c+16, v_ay +23, v_b +56 + .fma_1x8_fp16 v_c+24, v_ay +31, v_b +56 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_body + +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + + v_mov_b32 v[v_ay +16], v[v_ax +16] + v_mov_b32 v[v_ay +17], v[v_ax +17] + v_mov_b32 v[v_ay +18], v[v_ax +18] + v_mov_b32 v[v_ay +19], v[v_ax +19] + v_mov_b32 v[v_ay +20], v[v_ax +20] + v_mov_b32 v[v_ay +21], v[v_ax +21] + v_mov_b32 v[v_ay +22], v[v_ax +22] + v_mov_b32 v[v_ay +23], v[v_ax +23] + v_mov_b32 v[v_ay +24], v[v_ax +24] + v_mov_b32 v[v_ay +25], v[v_ax +25] + v_mov_b32 v[v_ay +26], v[v_ax +26] + v_mov_b32 v[v_ay +27], v[v_ax +27] + v_mov_b32 v[v_ay +28], v[v_ax +28] + v_mov_b32 v[v_ay +29], v[v_ax +29] + v_mov_b32 v[v_ay +30], v[v_ax +30] + v_mov_b32 v[v_ay +31], v[v_ax +31] +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = 
iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + 
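+    ; input byte offset being formed here: in_os = (ihi * wi + iwi) * s_in_stride_wi,
+    ; where s_in_stride_wi = c * group * 2 bytes per pixel (fp16 nhwc).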
v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(8) + + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay +16, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +24, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay +17, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +25, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +18, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +26, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +19, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +27, v_b +24 + + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 0, v_ay + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ay +12, v_b +32 + .fma_1x8_fp16 v_c+16, v_ay +20, v_b +32 + .fma_1x8_fp16 v_c+24, v_ay +28, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ay + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ay +13, v_b +40 + .fma_1x8_fp16 v_c+16, v_ay +21, v_b +40 + .fma_1x8_fp16 v_c+24, v_ay +29, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ay + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ay +14, v_b +48 + .fma_1x8_fp16 v_c+16, v_ay +22, v_b +48 + .fma_1x8_fp16 v_c+24, v_ay +30, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ay + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ay +15, v_b +56 + .fma_1x8_fp16 v_c+16, v_ay +23, v_b +56 + .fma_1x8_fp16 v_c+24, v_ay +31, v_b +56 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cvt_f16_f32 v[v_c +16], v[v_c +16] + v_cvt_f16_f32 v[v_c +17], v[v_c +17] + v_cvt_f16_f32 v[v_c +18], v[v_c +18] + v_cvt_f16_f32 v[v_c +19], v[v_c +19] + v_cvt_f16_f32 v[v_c +20], v[v_c +20] + v_cvt_f16_f32 v[v_c +21], v[v_c +21] + v_cvt_f16_f32 v[v_c +22], v[v_c +22] + 
v_cvt_f16_f32 v[v_c +23], v[v_c +23] + + v_cvt_f16_f32 v[v_c +24], v[v_c +24] + v_cvt_f16_f32 v[v_c +25], v[v_c +25] + v_cvt_f16_f32 v[v_c +26], v[v_c +26] + v_cvt_f16_f32 v[v_c +27], v[v_c +27] + v_cvt_f16_f32 v[v_c +28], v[v_c +28] + v_cvt_f16_f32 v[v_c +29], v[v_c +29] + v_cvt_f16_f32 v[v_c +30], v[v_c +30] + v_cvt_f16_f32 v[v_c +31], v[v_c +31] + + + v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] + v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19] + v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21] + v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23] + + v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25] + v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27] + v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29] + v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + .v_clear_nc v_c, 32 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_body +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_512x8x16_r2 + .amdhsa_group_segment_fixed_size 4096 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + 
.amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 188 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel + + +;---------------------------------------------------------------- +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 32 +.set v_ax, 33 +.set v_ay, 49 +.set v_ib, 65 +.set v_b, 66 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 98 +.set v_in_ihi, 102 +.set v_in_iwi, 106 +.set v_in_flag, 110 +.set v_out_os, 114 +.set v_out_flag, 118 +.set v_tid, 122 +.set v_end, 124 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_512x8x8_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_512x8x8_r1,@function +igemm_fwd_btm_nhwc_fp16_512x8x8_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 16 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 15, v0 ; yx + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 
v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + ; s_lshl_b32 s[s_wei_offset], s[s_c], 4+1 ; 16x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ; v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ;s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 
16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + ;v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + 
v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 32 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate 
sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ax + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ax + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ax +12, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ax + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ax + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ax +13, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ax + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ax +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ax +14, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ax + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ax +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ax 
+15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax +0:v_ax +3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], 
v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] 
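+    ; note on the flag idiom used here and below: ihi/iwi may have gone negative
+    ; after the pad_h/pad_w subtraction, but as u32 a negative value wraps to a
+    ; huge number, so a single v_cmp_gt_u32 s[s_hi|s_wi], v[...] rejects the
+    ; negative and the >= hi/wi cases at once; the paired v_cndmask_b32 then
+    ; folds the two compares into one 0/1 in_flag, which later gates the
+    ; global_load through v_cmpx_le_u32 / exec.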
+ v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], 
v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cvt_f16_f32 v[v_c +16], v[v_c +16] + v_cvt_f16_f32 v[v_c +17], v[v_c +17] + v_cvt_f16_f32 v[v_c +18], v[v_c +18] + v_cvt_f16_f32 v[v_c +19], v[v_c +19] + v_cvt_f16_f32 v[v_c +20], v[v_c +20] + v_cvt_f16_f32 v[v_c +21], v[v_c +21] + v_cvt_f16_f32 v[v_c +22], v[v_c +22] + v_cvt_f16_f32 v[v_c +23], v[v_c +23] + + v_cvt_f16_f32 v[v_c +24], v[v_c +24] + v_cvt_f16_f32 v[v_c +25], v[v_c +25] + v_cvt_f16_f32 v[v_c +26], v[v_c +26] + v_cvt_f16_f32 v[v_c +27], v[v_c +27] + v_cvt_f16_f32 v[v_c +28], v[v_c +28] + v_cvt_f16_f32 v[v_c +29], v[v_c +29] + v_cvt_f16_f32 v[v_c +30], v[v_c +30] + v_cvt_f16_f32 v[v_c +31], v[v_c +31] + + + v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] + v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19] + v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21] + v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23] + + v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25] + v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27] + v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29] + v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + .v_clear_nc v_c, 32 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_end: + s_endpgm + +; LDS: 1 
* 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_512x8x8_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 124 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8.asm new file mode 100644 index 00000000..d8549814 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8.asm @@ -0,0 +1,424 @@ +; pay attention to register bank of v_c, v_b +.macro .fma_1x16_int8x4 v_c, v_a, v_b + v_dot4c_i32_i8 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot4c_i32_i8 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot4c_i32_i8 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot4c_i32_i8 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] + v_dot4c_i32_i8 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] + v_dot4c_i32_i8 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] + v_dot4c_i32_i8 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] + v_dot4c_i32_i8 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] + v_dot4c_i32_i8 v[\v_c+8 ], v[\v_a], v[\v_b+8 ] + v_dot4c_i32_i8 v[\v_c+9 ], v[\v_a], v[\v_b+9 ] + v_dot4c_i32_i8 v[\v_c+10], v[\v_a], v[\v_b+10] + v_dot4c_i32_i8 v[\v_c+11], v[\v_a], v[\v_b+11] + v_dot4c_i32_i8 v[\v_c+12], v[\v_a], v[\v_b+12] + v_dot4c_i32_i8 v[\v_c+13], v[\v_a], v[\v_b+13] + v_dot4c_i32_i8 v[\v_c+14], v[\v_a], v[\v_b+14] + v_dot4c_i32_i8 v[\v_c+15], v[\v_a], v[\v_b+15] +.endm + +.macro .fma_1x8_int8x4 v_c, v_a, v_b + v_dot4c_i32_i8 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot4c_i32_i8 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot4c_i32_i8 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot4c_i32_i8 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] + v_dot4c_i32_i8 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] + v_dot4c_i32_i8 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] + v_dot4c_i32_i8 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] + v_dot4c_i32_i8 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] +.endm + +.macro .fma_1x4_int8x4 v_c, v_a, v_b + v_dot4c_i32_i8 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot4c_i32_i8 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot4c_i32_i8 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot4c_i32_i8 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] +.endm + +.macro .mdiv_u32_ss s_quot s_numer s_magic s_shift s_tmp + s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer] + s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer] + s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift] +.endm + +.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp + .mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp + s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot] + s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp] +.endm + +.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp + v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer] + v_add_nc_u32 v[\v_tmp], v[\v_tmp], v[\v_numer] + v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp] +.endm + +.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp + .mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp + v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot] + v_sub_nc_u32 v[\v_rem], v[\v_numer], v[\v_tmp] +.endm + +.macro .pack_i8x4_i32_r1 v_d, v_src, s_0xff + v_and_b32 v[\v_src+ 0], s[\s_0xff], v[\v_src+ 0] + v_and_b32 v[\v_src+ 1], s[\s_0xff], v[\v_src+ 1] + v_and_b32 v[\v_src+ 2], s[\s_0xff], v[\v_src+ 2] + v_lshlrev_b32 v[\v_src+ 3], 24, v[\v_src+ 3] + v_lshlrev_b32 v[\v_src+ 1], 8, v[\v_src+ 1] + 
v_lshlrev_b32 v[\v_src+ 2], 16, v[\v_src+ 2] + v_or_b32 v[\v_d], v[\v_src+ 0], v[\v_src+ 3] + v_or3_b32 v[\v_d], v[\v_d], v[\v_src+ 1], v[\v_src+ 2] +.endm + +.macro .pack_i8x4_i32_r2 v_d, v_src, s_0xff + v_and_b32 v[\v_src+ 0], s[\s_0xff], v[\v_src+ 0] + v_lshlrev_b32 v[\v_src+ 3], 24, v[\v_src+ 3] + + v_and_b32 v[\v_src+ 4], s[\s_0xff], v[\v_src+ 4] + v_lshlrev_b32 v[\v_src+ 7], 24, v[\v_src+ 7] + + v_and_b32 v[\v_src+ 1], s[\s_0xff], v[\v_src+ 1] + v_and_b32 v[\v_src+ 2], s[\s_0xff], v[\v_src+ 2] + + v_or_b32 v[\v_d+ 0], v[\v_src+ 0], v[\v_src+ 3] + + v_and_b32 v[\v_src+ 5], s[\s_0xff], v[\v_src+ 5] + + v_and_b32 v[\v_src+ 6], s[\s_0xff], v[\v_src+ 6] + v_or_b32 v[\v_d+ 1], v[\v_src+ 4], v[\v_src+ 7] + + v_lshlrev_b32 v[\v_src+ 1], 8, v[\v_src+ 1] + v_lshlrev_b32 v[\v_src+ 2], 16, v[\v_src+ 2] + v_lshlrev_b32 v[\v_src+ 5], 8, v[\v_src+ 5] + v_lshlrev_b32 v[\v_src+ 6], 16, v[\v_src+ 6] + + v_or3_b32 v[\v_d+ 0], v[\v_d+ 0], v[\v_src+ 1], v[\v_src+ 2] + v_or3_b32 v[\v_d+ 1], v[\v_d+ 1], v[\v_src+ 5], v[\v_src+ 6] +.endm + +;.macro .pack_i8x4_i32_r4 v_d, v_src, s_0xff +; v_and_b32 v[\v_src+ 0], s[\s_0xff], v[\v_src+ 0] +; v_and_b32 v[\v_src+ 1], s[\s_0xff], v[\v_src+ 1] +; v_and_b32 v[\v_src+ 2], s[\s_0xff], v[\v_src+ 2] +; v_lshlrev_b32 v[\v_src+ 3], 24, v[\v_src+ 3] +; v_lshlrev_b32 v[\v_src+ 1], 8, v[\v_src+ 1] +; v_lshlrev_b32 v[\v_src+ 2], 16, v[\v_src+ 2] +; v_or_b32 v[\v_d+ 0], v[\v_src+ 0], v[\v_src+ 3] +; v_or3_b32 v[\v_d+ 0], v[\v_d+ 0], v[\v_src+ 1], v[\v_src+ 2] +; +; v_and_b32 v[\v_src+ 4], s[\s_0xff], v[\v_src+ 4] +; v_and_b32 v[\v_src+ 5], s[\s_0xff], v[\v_src+ 5] +; v_and_b32 v[\v_src+ 6], s[\s_0xff], v[\v_src+ 6] +; v_lshlrev_b32 v[\v_src+ 7], 24, v[\v_src+ 7] +; v_lshlrev_b32 v[\v_src+ 5], 8, v[\v_src+ 5] +; v_lshlrev_b32 v[\v_src+ 6], 16, v[\v_src+ 6] +; v_or_b32 v[\v_d+ 1], v[\v_src+ 4], v[\v_src+ 7] +; v_or3_b32 v[\v_d+ 1], v[\v_d+ 1], v[\v_src+ 5], v[\v_src+ 6] +; +; v_and_b32 v[\v_src+ 8], s[\s_0xff], v[\v_src+ 8] +; v_and_b32 v[\v_src+ 9], s[\s_0xff], v[\v_src+ 9] +; v_and_b32 v[\v_src+10], s[\s_0xff], v[\v_src+10] +; v_lshlrev_b32 v[\v_src+11], 24, v[\v_src+11] +; v_lshlrev_b32 v[\v_src+ 9], 8, v[\v_src+ 9] +; v_lshlrev_b32 v[\v_src+10], 16, v[\v_src+10] +; v_or_b32 v[\v_d+ 2], v[\v_src+ 8], v[\v_src+11] +; v_or3_b32 v[\v_d+ 2], v[\v_d+ 2], v[\v_src+ 9], v[\v_src+10] +; +; v_and_b32 v[\v_src+12], s[\s_0xff], v[\v_src+12] +; v_and_b32 v[\v_src+13], s[\s_0xff], v[\v_src+13] +; v_and_b32 v[\v_src+14], s[\s_0xff], v[\v_src+14] +; v_lshlrev_b32 v[\v_src+15], 24, v[\v_src+15] +; v_lshlrev_b32 v[\v_src+13], 8, v[\v_src+13] +; v_lshlrev_b32 v[\v_src+14], 16, v[\v_src+14] +; v_or_b32 v[\v_d+ 3], v[\v_src+12], v[\v_src+15] +; v_or3_b32 v[\v_d+ 3], v[\v_d+ 3], v[\v_src+13], v[\v_src+14] +;.endm + +.macro .pack_i8x4_i32_r4 v_d, v_src, s_0xff + v_and_b32 v[\v_src+ 0], s[\s_0xff], v[\v_src+ 0] + v_lshlrev_b32 v[\v_src+ 3], 24, v[\v_src+ 3] + v_and_b32 v[\v_src+ 4], s[\s_0xff], v[\v_src+ 4] + v_lshlrev_b32 v[\v_src+ 7], 24, v[\v_src+ 7] + + v_and_b32 v[\v_src+ 8], s[\s_0xff], v[\v_src+ 8] + v_lshlrev_b32 v[\v_src+11], 24, v[\v_src+11] + v_and_b32 v[\v_src+12], s[\s_0xff], v[\v_src+12] + v_lshlrev_b32 v[\v_src+15], 24, v[\v_src+15] + + v_or_b32 v[\v_d+ 0], v[\v_src+ 0], v[\v_src+ 3] + v_or_b32 v[\v_d+ 1], v[\v_src+ 4], v[\v_src+ 7] + v_or_b32 v[\v_d+ 2], v[\v_src+ 8], v[\v_src+11] + + v_and_b32 v[\v_src+ 1], s[\s_0xff], v[\v_src+ 1] + v_or_b32 v[\v_d+ 3], v[\v_src+12], v[\v_src+15] + + v_and_b32 v[\v_src+ 2], s[\s_0xff], v[\v_src+ 2] + v_and_b32 v[\v_src+ 5], s[\s_0xff], v[\v_src+ 
5] + v_and_b32 v[\v_src+ 6], s[\s_0xff], v[\v_src+ 6] + v_and_b32 v[\v_src+ 9], s[\s_0xff], v[\v_src+ 9] + v_and_b32 v[\v_src+10], s[\s_0xff], v[\v_src+10] + v_and_b32 v[\v_src+13], s[\s_0xff], v[\v_src+13] + v_and_b32 v[\v_src+14], s[\s_0xff], v[\v_src+14] + + v_lshlrev_b32 v[\v_src+ 1], 8, v[\v_src+ 1] + v_lshlrev_b32 v[\v_src+ 2], 16, v[\v_src+ 2] + + v_lshlrev_b32 v[\v_src+ 5], 8, v[\v_src+ 5] + v_lshlrev_b32 v[\v_src+ 6], 16, v[\v_src+ 6] + + v_lshlrev_b32 v[\v_src+ 9], 8, v[\v_src+ 9] + v_lshlrev_b32 v[\v_src+10], 16, v[\v_src+10] + + v_lshlrev_b32 v[\v_src+13], 8, v[\v_src+13] + v_lshlrev_b32 v[\v_src+14], 16, v[\v_src+14] + + v_or3_b32 v[\v_d+ 0], v[\v_d+ 0], v[\v_src+ 1], v[\v_src+ 2] + v_or3_b32 v[\v_d+ 1], v[\v_d+ 1], v[\v_src+ 5], v[\v_src+ 6] + v_or3_b32 v[\v_d+ 2], v[\v_d+ 2], v[\v_src+ 9], v[\v_src+10] + v_or3_b32 v[\v_d+ 3], v[\v_d+ 3], v[\v_src+13], v[\v_src+14] +.endm + + +.macro .v_clear_nc vid, num + _v = \vid + .rept \num + v_mov_b32 v[_v], 0 + _v = _v + 1 + .endr +.endm + +.include "igemm_fwd_btm_nhwc_int8_256x004.asm" +.include "igemm_fwd_btm_nhwc_int8_256x008.asm" +.include "igemm_fwd_btm_nhwc_int8_512x008.asm" +.include "igemm_fwd_btm_nhwc_int8_512x016.asm" +.include "igemm_fwd_btm_nhwc_int8_1024x016.asm" + +.amdgpu_metadata +--- +amdhsa.version: [ 1, 0 ] +amdhsa.kernels: + - .name: igemm_fwd_btm_nhwc_int8_256x4x16_r1 + .symbol: igemm_fwd_btm_nhwc_int8_256x4x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 108 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 1024 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [64, 1, 1] + .max_flat_workgroup_size: 64 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: 
by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_256x8x16_r1 + .symbol: igemm_fwd_btm_nhwc_int8_256x8x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 80 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_512x8x16_r1 + .symbol: igemm_fwd_btm_nhwc_int8_512x8x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 124 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: 
global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_512x16x8_r2 + .symbol: igemm_fwd_btm_nhwc_int8_512x16x8_r2.kd + .sgpr_count: 64 + .vgpr_count: 140 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: 
by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_512x16x16_r2 + .symbol: igemm_fwd_btm_nhwc_int8_512x16x16_r2.kd + .sgpr_count: 64 + .vgpr_count: 188 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, 
.value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_1024x16x8_r2 + .symbol: igemm_fwd_btm_nhwc_int8_1024x16x8_r2.kd + .sgpr_count: 64 + .vgpr_count: 244 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} +... 
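+# note: all six kernels above share the same 112-byte kernarg layout (three
+# global pointers followed by 22 by-value i32 args); shift_pack_0 carries the
+# three magic-division shift amounts packed one per byte (shift_m0 in bits
+# 0..7, shift_m1 in bits 8..15, shift_m2 in bits 16..23), unpacked in-kernel
+# by s_bfe_u32.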
+.end_amdgpu_metadata diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_1024x016.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_1024x016.asm new file mode 100644 index 00000000..c5bf9282 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_1024x016.asm @@ -0,0 +1,1081 @@ +;---------------------------------------------------------------------------------- +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 16 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 128 +.set v_ax, 129 +.set v_ay, 145 +.set v_ib, 161 +.set v_b, 162 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 194 +.set v_in_ihi, 202 +.set v_in_iwi, 210 +.set v_in_flag, 218 +.set v_out_os, 226 +.set v_out_flag, 234 +.set v_tid, 242 +.set v_end, 244 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_1024x16x8_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_1024x16x8_r2,@function +igemm_fwd_btm_nhwc_int8_1024x16x8_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 8 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 2 + v_and_b32 v[v_wei_ie], 7, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 
v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 4 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 10 ; 1024 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 4 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+0 ; 8x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx2 v[v_gld_b+0:v_gld_b+1], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx2 v[v_gld_b+2:v_gld_b+3], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 16*k_n_dword*4 ; stride for wei sst offset. 
8 thread for gemm_k, each thread store 2 c, hence 8*2=16 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + ;.v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + ;.v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + 
v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + + + .mdiv_u32_rem_vs v_in_iwi+4,v_in_ihi+4,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+4], s[s_stride_h], v[v_in_ihi+4] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+4], v[v_in_ihi+4], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+4], s[s_stride_w], v[v_in_iwi+4] + v_sub_nc_i32 v[v_in_iwi+4], v[v_in_iwi+4], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+5,v_in_ihi+5,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+5], s[s_stride_h], v[v_in_ihi+5] + v_sub_nc_i32 v[v_in_ihi+5], v[v_in_ihi+5], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+5], s[s_stride_w], v[v_in_iwi+5] + v_sub_nc_i32 v[v_in_iwi+5], v[v_in_iwi+5], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+4], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_mul_lo_u32 v[v_in_os+4], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx2 v[v_ax+ 8:v_ax+ 9], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+5], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_mul_lo_u32 v[v_in_os+5], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx2 v[v_ax+10:v_ax+11], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + + .mdiv_u32_rem_vs v_in_iwi+6,v_in_ihi+6,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+6], s[s_stride_h], v[v_in_ihi+6] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_ihi+6], v[v_in_ihi+6], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+6], s[s_stride_w], v[v_in_iwi+6] + v_sub_nc_i32 v[v_in_iwi+6], v[v_in_iwi+6], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+7,v_in_ihi+7,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+7], s[s_stride_h], v[v_in_ihi+7] + v_sub_nc_i32 v[v_in_ihi+7], v[v_in_ihi+7], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+7], s[s_stride_w], v[v_in_iwi+7] + v_sub_nc_i32 v[v_in_iwi+7], v[v_in_iwi+7], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+6], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_mul_lo_u32 v[v_in_os+6], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx2 v[v_ax+12:v_ax+13], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+7] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+7], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + v_mul_lo_u32 
v[v_in_os+7], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx2 v[v_ax+14:v_ax+15], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + + + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+4], s[s_k], v[v_tmp+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+4] + v_cndmask_b32 v[v_out_flag+4], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+5], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+5] + v_cndmask_b32 v[v_out_flag+5], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+6], s[s_k], v[v_tmp+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+6] + v_cndmask_b32 v[v_out_flag+6], 0, 1 + + v_mul_lo_u32 v[v_out_os+7], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+7] + v_cndmask_b32 v[v_out_flag+7], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*2 + + s_waitcnt vmcnt(8) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*0 offset1:k_n_dword*1 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 128 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 
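+    ; note: k_n_dword = 16 dwords per gemm_k slice of B in LDS; each slice is
+    ; fetched with four ds_read_b128 (4x4 dwords), and s_sld_b_stride =
+    ; k_n_dword*4*2 bytes steps over the two slices consumed per unroll, so the
+    ; fma body can overlap v_dot4c math on one slice (s_waitcnt lgkmcnt(4))
+    ; with the LDS reads of the next.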
+ ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end + +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_iwi+4], s[s_tmp], v[v_in_iwi+4] + v_add_nc_u32 v[v_in_iwi+5], s[s_tmp], v[v_in_iwi+5] + v_add_nc_u32 v[v_in_iwi+6], s[s_tmp], v[v_in_iwi+6] + v_add_nc_u32 v[v_in_iwi+7], s[s_tmp], v[v_in_iwi+7] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + v_add_nc_u32 v[v_in_os+4], s[s_tmp+1], v[v_in_os+4] + v_add_nc_u32 v[v_in_os+5], s[s_tmp+1], v[v_in_os+5] + v_add_nc_u32 v[v_in_os+6], s[s_tmp+1], v[v_in_os+6] + v_add_nc_u32 v[v_in_os+7], s[s_tmp+1], v[v_in_os+7] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] + v_add_nc_i32 v[v_in_ihi+4], s[s_dilation_h], v[v_in_ihi+4] + v_add_nc_i32 v[v_in_ihi+5], s[s_dilation_h], v[v_in_ihi+5] + v_add_nc_i32 v[v_in_ihi+6], s[s_dilation_h], v[v_in_ihi+6] + v_add_nc_i32 v[v_in_ihi+7], s[s_dilation_h], v[v_in_ihi+7] +igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_cmp_gt_u32 s[s_hi], 
v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 4 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ay+ 0:v_ay+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ay+ 2:v_ay+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+4, 4 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ay+ 4:v_ay+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ay+ 6:v_ay+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 4 + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx2 v[v_ay+ 8:v_ay+ 9], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx2 v[v_ay+10:v_ay+11], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+12, 4 + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx2 v[v_ay+12:v_ay+13], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx2 v[v_ay+14:v_ay+15], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x16_int8x4 v_c+ 16, v_ax + 2, v_b + 0 + .fma_1x16_int8x4 v_c+ 32, v_ax + 4, v_b + 0 + .fma_1x16_int8x4 v_c+ 48, v_ax + 6, v_b + 0 + .fma_1x16_int8x4 v_c+ 64, v_ax + 8, v_b + 0 + .fma_1x16_int8x4 v_c+ 80, v_ax +10, v_b + 0 + .fma_1x16_int8x4 v_c+ 96, v_ax +12, v_b + 0 + .fma_1x16_int8x4 v_c+112, v_ax +14, v_b + 0 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ax + 1, v_b +16 + .fma_1x16_int8x4 v_c+ 16, v_ax + 3, v_b +16 + .fma_1x16_int8x4 v_c+ 32, v_ax + 5, v_b +16 + .fma_1x16_int8x4 v_c+ 48, v_ax + 7, v_b +16 + .fma_1x16_int8x4 v_c+ 64, v_ax + 9, v_b +16 + .fma_1x16_int8x4 v_c+ 80, v_ax +11, v_b +16 + .fma_1x16_int8x4 v_c+ 96, v_ax +13, v_b +16 + .fma_1x16_int8x4 v_c+112, v_ax +15, v_b +16 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_iwi+4], s[s_tmp], v[v_in_iwi+4] + v_add_nc_u32 v[v_in_iwi+5], s[s_tmp], v[v_in_iwi+5] + 
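+ ; note: the two s_cselect above pick the window step from the s_cmp_le result: while ix stays inside the filter row the step is
+ ; s_dilation_w / s_in_diff_wi, and on a wrap of ix it is s_dilation_w_x = -dilation_w*(x-1) / s_in_diff_hi, which rewinds x and
+ ; moves down one dilated row; the remaining iwi/in_os lanes below get the same update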
v_add_nc_u32 v[v_in_iwi+6], s[s_tmp], v[v_in_iwi+6] + v_add_nc_u32 v[v_in_iwi+7], s[s_tmp], v[v_in_iwi+7] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + v_add_nc_u32 v[v_in_os+4], s[s_tmp+1], v[v_in_os+4] + v_add_nc_u32 v[v_in_os+5], s[s_tmp+1], v[v_in_os+5] + v_add_nc_u32 v[v_in_os+6], s[s_tmp+1], v[v_in_os+6] + v_add_nc_u32 v[v_in_os+7], s[s_tmp+1], v[v_in_os+7] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] + v_add_nc_i32 v[v_in_ihi+4], s[s_dilation_h], v[v_in_ihi+4] + v_add_nc_i32 v[v_in_ihi+5], s[s_dilation_h], v[v_in_ihi+5] + v_add_nc_i32 v[v_in_ihi+6], s[s_dilation_h], v[v_in_ihi+6] + v_add_nc_i32 v[v_in_ihi+7], s[s_dilation_h], v[v_in_ihi+7] +igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + ;--- end move slice window + + .v_clear_nc v_ax, 4 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+4, 4 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 4 + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx2 v[v_ax+ 8:v_ax+ 9], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx2 v[v_ax+10:v_ax+11], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+12, 4 + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx2 v[v_ax+12:v_ax+13], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, 
v[v_in_flag+7] + global_load_dwordx2 v[v_ax+14:v_ax+15], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+ 16, v_ay + 2, v_b + 0 + .fma_1x16_int8x4 v_c+ 32, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+ 48, v_ay + 6, v_b + 0 + .fma_1x16_int8x4 v_c+ 64, v_ay + 8, v_b + 0 + .fma_1x16_int8x4 v_c+ 80, v_ay +10, v_b + 0 + .fma_1x16_int8x4 v_c+ 96, v_ay +12, v_b + 0 + .fma_1x16_int8x4 v_c+112, v_ay +14, v_b + 0 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+ 16, v_ay + 3, v_b +16 + .fma_1x16_int8x4 v_c+ 32, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+ 48, v_ay + 7, v_b +16 + .fma_1x16_int8x4 v_c+ 64, v_ay + 9, v_b +16 + .fma_1x16_int8x4 v_c+ 80, v_ay +11, v_b +16 + .fma_1x16_int8x4 v_c+ 96, v_ay +13, v_b +16 + .fma_1x16_int8x4 v_c+112, v_ay +15, v_b +16 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_body + +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 2 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+2, 2 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], 
v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+4, 2 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+6, 2 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_add_nc_u32 v[v_in_flag+4], s[s_ib_stride], v[v_in_flag+3] + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + .mdiv_u32_rem_vs v_in_iwi+4,v_in_ihi+4,v_in_flag+4,s_magic_1,s_shift_m1,s_wo,v_in_os+4 + v_add_nc_u32 v[v_in_flag+5], s[s_ib_stride], v[v_in_flag+4] + v_mul_lo_u32 v[v_in_ihi+4], s[s_stride_h], v[v_in_ihi+4] + .v_clear_nc v_ax+8, 2 + v_sub_nc_i32 v[v_in_ihi+4], v[v_in_ihi+4], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+4], s[s_stride_w], v[v_in_iwi+4] + .v_clear_nc v_ax+10, 2 + v_sub_nc_i32 v[v_in_iwi+4], v[v_in_iwi+4], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+5,v_in_ihi+5,v_in_flag+5,s_magic_1,s_shift_m1,s_wo,v_in_os+5 + v_add_nc_u32 v[v_in_flag+6], s[s_ib_stride], v[v_in_flag+5] + v_mul_lo_u32 v[v_in_ihi+5], s[s_stride_h], v[v_in_ihi+5] + v_sub_nc_i32 v[v_in_ihi+5], v[v_in_ihi+5], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+5], s[s_stride_w], v[v_in_iwi+5] + v_sub_nc_i32 v[v_in_iwi+5], 
v[v_in_iwi+5], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+4], s[s_wi], v[v_in_ihi+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_add_nc_u32 v[v_in_os+4], v[v_in_iwi+4], v[v_in_os+4] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_mul_lo_u32 v[v_in_os+4], s[s_in_stride_wi], v[v_in_os+4] + + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx2 v[v_ax+ 8:v_ax+ 9], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+5], s[s_wi], v[v_in_ihi+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_add_nc_u32 v[v_in_os+5], v[v_in_iwi+5], v[v_in_os+5] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_mul_lo_u32 v[v_in_os+5], s[s_in_stride_wi], v[v_in_os+5] + + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx2 v[v_ax+10:v_ax+11], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + .mdiv_u32_rem_vs v_in_iwi+6,v_in_ihi+6,v_in_flag+6,s_magic_1,s_shift_m1,s_wo,v_in_os+6 + v_add_nc_u32 v[v_in_flag+7], s[s_ib_stride], v[v_in_flag+6] + v_mul_lo_u32 v[v_in_ihi+6], s[s_stride_h], v[v_in_ihi+6] + .v_clear_nc v_ax+12, 2 + v_sub_nc_i32 v[v_in_ihi+6], v[v_in_ihi+6], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+6], s[s_stride_w], v[v_in_iwi+6] + .v_clear_nc v_ax+14, 2 + v_sub_nc_i32 v[v_in_iwi+6], v[v_in_iwi+6], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+7,v_in_ihi+7,v_in_flag+7,s_magic_1,s_shift_m1,s_wo,v_in_os+7 + v_mul_lo_u32 v[v_in_ihi+7], s[s_stride_h], v[v_in_ihi+7] + v_sub_nc_i32 v[v_in_ihi+7], v[v_in_ihi+7], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+7], s[s_stride_w], v[v_in_iwi+7] + v_sub_nc_i32 v[v_in_iwi+7], v[v_in_iwi+7], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+6], s[s_wi], v[v_in_ihi+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_add_nc_u32 v[v_in_os+6], v[v_in_iwi+6], v[v_in_os+6] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_mul_lo_u32 v[v_in_os+6], s[s_in_stride_wi], v[v_in_os+6] + + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx2 v[v_ax+12:v_ax+13], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+7], s[s_wi], v[v_in_ihi+7] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + v_add_nc_u32 v[v_in_os+7], v[v_in_iwi+7], v[v_in_os+7] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + v_mul_lo_u32 v[v_in_os+7], s[s_in_stride_wi], v[v_in_os+7] + + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx2 v[v_ax+14:v_ax+15], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+ 16, v_ay + 2, v_b + 0 + .fma_1x16_int8x4 v_c+ 32, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+ 48, v_ay + 6, v_b + 0 + .fma_1x16_int8x4 v_c+ 64, v_ay + 8, v_b + 0 + .fma_1x16_int8x4 v_c+ 80, v_ay +10, v_b + 0 + .fma_1x16_int8x4 v_c+ 96, v_ay +12, v_b + 0 + .fma_1x16_int8x4 v_c+112, v_ay +14, v_b + 0 + + s_waitcnt lgkmcnt(0) + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+ 16, v_ay + 3, v_b +16 + .fma_1x16_int8x4 v_c+ 32, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+ 48, v_ay + 7, v_b +16 + .fma_1x16_int8x4 v_c+ 64, v_ay + 9, v_b +16 + .fma_1x16_int8x4 v_c+ 80, v_ay +11, v_b +16 + .fma_1x16_int8x4 v_c+ 96, v_ay +13, v_b +16 + .fma_1x16_int8x4 
v_c+112, v_ay +15, v_b +16 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + .pack_i8x4_i32_r4 v_c_buf+ 4, v_c+16, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+ 8, v_c+32, s_0xff + .pack_i8x4_i32_r4 v_c_buf+12, v_c+48, s_0xff + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + .pack_i8x4_i32_r4 v_c_buf+16, v_c+64, s_0xff + .pack_i8x4_i32_r4 v_c_buf+20, v_c+80, s_0xff + v_cmpx_le_u32 1, v[v_out_flag+4] + global_store_dwordx4 v[v_out_os+4], v[v_c_buf+16:v_c_buf+19], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+5] + global_store_dwordx4 v[v_out_os+5], v[v_c_buf+20:v_c_buf+23], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+24, v_c+96, s_0xff + .pack_i8x4_i32_r4 v_c_buf+28, v_c+112, s_0xff + v_cmpx_le_u32 1, v[v_out_flag+6] + global_store_dwordx4 v[v_out_os+6], v[v_c_buf+24:v_c_buf+27], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+7] + global_store_dwordx4 v[v_out_os+7], v[v_c_buf+28:v_c_buf+31], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + .v_clear_nc v_c, 128 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + v_add_nc_u32 v[v_out_os+4], s[s_out_stride], v[v_out_os+4] + v_add_nc_u32 v[v_out_os+5], s[s_out_stride], v[v_out_os+5] + v_add_nc_u32 v[v_out_os+6], s[s_out_stride], v[v_out_os+6] + v_add_nc_u32 v[v_out_os+7], s[s_out_stride], v[v_out_os+7] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+4] + v_cndmask_b32 v[v_out_flag+4], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+5] + v_cndmask_b32 v[v_out_flag+5], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+6] + v_cndmask_b32 
v[v_out_flag+6], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+7] + v_cndmask_b32 v[v_out_flag+7], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_body +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_1024x16x8_r2 + .amdhsa_group_segment_fixed_size 4096 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 244 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x004.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x004.asm new file mode 100644 index 00000000..01566d15 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x004.asm @@ -0,0 +1,738 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 32 +.set v_ax, 33 +.set v_ay, 49 +.set v_ib, 65 +.set v_b, 66 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 82 +.set v_in_ihi, 86 +.set v_in_iwi, 90 +.set v_in_flag, 94 +.set v_out_os, 98 +.set v_out_flag, 102 +.set v_tid, 106 +.set v_end, 108 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_256x4x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_256x4x16_r1,@function +igemm_fwd_btm_nhwc_int8_256x4x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + 
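+ ; note: the s_load_dwordx16 below bulk-loads the kernarg block starting at offset k_hi (hi, wi, n, k, c, ho, wo, strides,
+ ; dilations, pads, y, x, group) straight into s[s_hi]..s[s_group], matching the .set table above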
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
+ s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m
+ s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2
+ v_mov_b32 v[v_tid], v0
+ s_mov_b32 s[s_ib_stride], 64
+ s_mov_b32 s[s_0xff], 0xff
+
+ ; calculate wei offset, 4x16, 4 for k, 16 for yxc, 16 for yx, 1 for c
+ v_lshrrev_b32 v[v_wei_ik], 4, v0
+ s_mov_b32 s[s_tmp], k_n_dword*4 * 4
+ v_and_b32 v[v_wei_ie], 15, v0 ; yx
+ ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1
+ v_mov_b32 v[v_wei_ic], 0
+ ;s_lshl_b32 s[s_block_in], s[s_block_in], 1
+ ;v_lshrrev_b32 v[v_tmp+4], 1, v0
+ v_mov_b32 v[v_ib], v0
+ v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_wei_ie]
+ v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x
+ v_mov_b32 v[v_sld_b_os], 0 ; load
+ v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword
+ v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note: do not use v_or_b32 here, due to the pad
+
+ s_waitcnt lgkmcnt(0)
+ s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8
+ s_lshr_b32 s[s_tmp+3], s[s_k], 2
+ s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8
+ .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp
+ s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8 ; 256
+ s_lshl_b32 s[s_block_ik], s[s_tmp+4], 2
+ v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib]
+ s_mul_i32 s[s_tmp], s[s_x], s[s_c]
+ v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik]
+
+ v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic]
+ s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y]
+ ;s_lshl_b32 s[s_wei_offset], s[s_c], 4+0 ; 16x s_c, int8
+ s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k]
+ v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1]
+ s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5]
+ v_cmp_gt_u32 s[s_k], v[v_wei_ik]
+ s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2]
+ v_cndmask_b32 v[v_wei_flag_ik], 0, 1
+ s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0
+ ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os]
+
+ ; divide x
+ .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp
+ ;v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0]
+ ;v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie]
+ v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0]
+ v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik]
+ v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0]
+ v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0]
+
+ v_cmpx_le_u32 1, v[v_wei_flag+0]
+ global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1]
+ s_mov_b64 exec, -1
+
+ ;s_mov_b32 s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset.
8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 
s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 
v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_int8x4 v_c+ 4, v_ax + 4, v_b + 0 + .fma_1x4_int8x4 v_c+ 8, v_ax + 8, v_b + 0 + .fma_1x4_int8x4 v_c+12, v_ax +12, v_b + 0 + + .fma_1x4_int8x4 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_int8x4 v_c+ 4, v_ax + 5, v_b + 4 + .fma_1x4_int8x4 v_c+ 8, v_ax + 9, v_b + 4 + .fma_1x4_int8x4 v_c+12, v_ax +13, v_b + 4 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + + s_waitcnt lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_int8x4 v_c+ 4, v_ax + 6, v_b + 8 + .fma_1x4_int8x4 v_c+ 8, v_ax +10, v_b + 8 + .fma_1x4_int8x4 v_c+12, v_ax +14, v_b + 8 + + .fma_1x4_int8x4 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_int8x4 v_c+ 4, v_ax + 7, v_b +12 + .fma_1x4_int8x4 v_c+ 8, v_ax +11, v_b +12 + .fma_1x4_int8x4 v_c+12, v_ax +15, v_b +12 + + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], 
s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_int8x4 v_c+ 4, v_ay + 4, v_b + 0 + .fma_1x4_int8x4 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x4_int8x4 v_c+12, v_ay +12, v_b + 0 + + .fma_1x4_int8x4 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_int8x4 v_c+ 4, v_ay + 5, v_b + 4 + .fma_1x4_int8x4 v_c+ 8, v_ay + 9, v_b + 4 + .fma_1x4_int8x4 v_c+12, v_ay +13, v_b + 4 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + + s_waitcnt lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_int8x4 v_c+ 4, v_ay + 6, v_b + 8 + .fma_1x4_int8x4 v_c+ 8, v_ay +10, v_b + 8 + .fma_1x4_int8x4 v_c+12, v_ay +14, v_b + 8 + + .fma_1x4_int8x4 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_int8x4 v_c+ 4, v_ay + 7, v_b +12 + .fma_1x4_int8x4 v_c+ 8, v_ay +11, v_b +12 + .fma_1x4_int8x4 v_c+12, v_ay +15, v_b +12 + + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 
2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + 
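+ ; note: per pixel, in_os = (ihi*wi + iwi)*in_stride_wi and in_flag stays 1 only when ihi < hi and iwi < wi as unsigned compares,
+ ; so negative padded coordinates (wrapped to huge unsigned values) are rejected; each global_load_dwordx4 below is then
+ ; predicated on its flag via the exec mask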
v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt vmcnt(4) lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_int8x4 v_c+ 4, v_ay + 4, v_b + 0 + .fma_1x4_int8x4 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x4_int8x4 v_c+12, v_ay +12, v_b + 0 + + .fma_1x4_int8x4 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_int8x4 v_c+ 4, v_ay + 5, v_b + 4 + .fma_1x4_int8x4 v_c+ 8, v_ay + 9, v_b + 4 + .fma_1x4_int8x4 v_c+12, v_ay +13, v_b + 4 + + s_waitcnt lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_int8x4 v_c+ 4, v_ay + 6, v_b + 8 + .fma_1x4_int8x4 v_c+ 8, v_ay +10, v_b + 8 + .fma_1x4_int8x4 v_c+12, v_ay +14, v_b + 8 + + .fma_1x4_int8x4 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_int8x4 v_c+ 4, v_ay + 7, v_b +12 + .fma_1x4_int8x4 v_c+ 8, v_ay +11, v_b +12 + .fma_1x4_int8x4 v_c+12, v_ay +15, v_b +12 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dword v[v_out_os], v[v_c_buf+0], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dword v[v_out_os+1], v[v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dword v[v_out_os+2], v[v_c_buf+2], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dword v[v_out_os+3], v[v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + .v_clear_nc v_c, 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end + s_branch 
L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_body +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 64 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_256x4x16_r1 + .amdhsa_group_segment_fixed_size 1024 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 108 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x008.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x008.asm new file mode 100644 index 00000000..b34ac185 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x008.asm @@ -0,0 +1,585 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 16 +.set v_ax, 17 +.set v_ay, 25 +.set v_ib, 33 +.set v_b, 34 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 66 +.set v_in_ihi, 68 +.set v_in_iwi, 70 +.set v_in_flag, 72 +.set v_out_os, 74 +.set v_out_flag, 76 +.set v_tid, 78 +.set v_end, 80 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_256x8x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_256x8x16_r1,@function +igemm_fwd_btm_nhwc_int8_256x8x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 
0+k_magic_2
+ v_mov_b32 v[v_tid], v0
+ s_mov_b32 s[s_ib_stride], 128
+ s_mov_b32 s[s_0xff], 0xff
+
+ ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 16 for yx, 1 for c
+ v_lshrrev_b32 v[v_wei_ik], 4, v0
+ s_mov_b32 s[s_tmp], k_n_dword*4 * 4
+ v_and_b32 v[v_wei_ie], 15, v0 ; yx
+ ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1
+ v_mov_b32 v[v_wei_ic], 0
+ ;s_lshl_b32 s[s_block_in], s[s_block_in], 1
+ ;v_lshrrev_b32 v[v_tmp+4], 1, v0
+ v_mov_b32 v[v_ib], v0
+ v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_wei_ie]
+ v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x
+ v_mov_b32 v[v_sld_b_os], 0 ; load
+ v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword
+ v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note: do not use v_or_b32 here, due to the pad
+
+ s_waitcnt lgkmcnt(0)
+ s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8
+ s_lshr_b32 s[s_tmp+3], s[s_k], 3
+ s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8
+ .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp
+ s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8 ; 256
+ s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3
+ v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib]
+ s_mul_i32 s[s_tmp], s[s_x], s[s_c]
+ v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik]
+
+ v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic]
+ s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y]
+ ;s_lshl_b32 s[s_wei_offset], s[s_c], 4+0 ; 16x s_c, int8
+ s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k]
+ v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1]
+ s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5]
+ v_cmp_gt_u32 s[s_k], v[v_wei_ik]
+ s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2]
+ v_cndmask_b32 v[v_wei_flag_ik], 0, 1
+ s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0
+ ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os]
+
+ ; divide x
+ .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp
+ ;v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0]
+ ;v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie]
+ v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0]
+ v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik]
+ v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0]
+ v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0]
+
+ v_cmpx_le_u32 1, v[v_wei_flag+0]
+ global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1]
+ s_mov_b64 exec, -1
+
+ ;s_mov_b32 s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset.
8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], 
s[s_k], v[v_ib] + ;v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + ;v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + ;v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(2) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(2) lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ax + 4, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ax + 5, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 
8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ax + 6, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ax + 7, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + + ;--- end move slice window + + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(2) lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ay + 4, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ay + 5, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ay + 6, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ay + 7, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end: + 
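+; Note on the loop structure: the fma_body above is a two-stage software
+; pipeline over the register buffers v_ax and v_ay. While the dot-product
+; macros consume one buffer, the predicated global loads refill the other,
+; 16 gemm_k per round. A C-like sketch of the control flow (a minimal
+; sketch; load/fma/copy stand for the macro groups, not literal calls):
+;
+;   for (;;) {
+;       load(ay); fma(c, ax);                      /* "a buffer x" round */
+;       if ((kitr -= 16) <= 0) goto tail;          /* last data is in ay */
+;       load(ax); fma(c, ay);                      /* "a buffer y" round */
+;       if ((kitr -= 16) <= 0) { ay = ax; goto tail; } /* copy below    */
+;   }
+;   tail: fma(c, ay);                              /* ..._fma_end_1      */
+;
+; This epilogue is the "ay = ax" case: the freshly loaded v_ax registers
+; are copied into v_ay so the shared tail at _fma_end_1 always reads v_ay.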
s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + ;v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ay + 4, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ay + 5, v_b + 8 + + s_waitcnt lgkmcnt(0) + .fma_1x8_int8x4 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ay + 6, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ay + 7, v_b +24 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+ 2:v_c_buf+ 3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], 
offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + .v_clear_nc v_c, 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_body +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_256x8x16_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 80 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x008.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x008.asm new file mode 100644 index 00000000..4866b14c --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x008.asm @@ -0,0 +1,757 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 
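+; s_0xff keeps the 0xff byte mask used by .pack_i8x4_i32_r4 when the
+; int32 accumulators are packed back to int8 for the output store.
+; Assuming the macros expand to the usual dot4/pack sequences (a hedged
+; sketch of the intent, not the literal macro bodies):
+;
+;   /* .fma_1x8_int8x4: c += dot(a, b) over four packed int8 pairs */
+;   int32_t dot4_i8(int32_t c, uint32_t a, uint32_t b) {
+;       for (int i = 0; i < 4; ++i)
+;           c += (int8_t)(a >> 8 * i) * (int8_t)(b >> 8 * i);
+;       return c;
+;   }
+;   /* .pack_i8x4_i32_r4: low byte of each accumulator, four per dword */
+;   uint32_t pack4(uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3) {
+;       return (c0 & 0xff) | ((c1 & 0xff) << 8) |
+;              ((c2 & 0xff) << 16) | ((c3 & 0xff) << 24);
+;   }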
+.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 32 +.set v_ax, 33 +.set v_ay, 49 +.set v_ib, 65 +.set v_b, 66 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 98 +.set v_in_ihi, 102 +.set v_in_iwi, 106 +.set v_in_flag, 110 +.set v_out_os, 114 +.set v_out_flag, 118 +.set v_tid, 122 +.set v_end, 124 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_512x8x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_512x8x16_r1,@function +igemm_fwd_btm_nhwc_int8_512x8x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 16 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 15, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + ;s_lshl_b32 s[s_wei_offset], s[s_c], 4+0 ; 16x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ;v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + ;v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ;s_mov_b32 
s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset. 8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + 
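+; The loads above use exec-mask predication instead of branches:
+; v_cmpx_le_u32 writes each lane's validity flag into EXEC, the
+; global_load_dwordx4 then runs only for in-bounds pixels, and
+; s_mov_b64 exec, -1 re-enables all lanes. Since the destination
+; registers were pre-cleared with .v_clear_nc, padded or out-of-range
+; taps contribute zero to the accumulation. Per lane this is simply:
+;
+;   if (in_flag)                          /* ihi < hi && iwi < wi */
+;       memcpy(ax, p_in + in_os, 16);     /* global_load_dwordx4  */
+;   /* else: ax keeps the zeros written by .v_clear_nc */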
v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + ;v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + ;v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 32 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 
v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ax + 4, v_b + 0 + .fma_1x8_int8x4 v_c+16, v_ax + 8, v_b + 0 + .fma_1x8_int8x4 v_c+24, v_ax +12, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ax + 5, v_b + 8 + .fma_1x8_int8x4 v_c+16, v_ax + 9, v_b + 8 + .fma_1x8_int8x4 v_c+24, v_ax +13, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ax + 6, v_b +16 + .fma_1x8_int8x4 v_c+16, v_ax +10, v_b +16 + .fma_1x8_int8x4 v_c+24, v_ax +14, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ax + 3, v_b +24 
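+; The move-slice-window block at the head of this loop advances the input
+; window one filter tap per 16-wide gemm_k step without branches:
+; s_cselect picks between a one-tap step to the right and a wrap that
+; rewinds the row and drops one dilated line. Scalar equivalent of the
+; per-pixel update (a sketch; all strides in int8 elements):
+;
+;   if (++ix >= x) {                      /* wrap to next filter row  */
+;       ix   = 0;
+;       iwi += -(x - 1) * dilation_w;     /* s_dilation_w_x           */
+;       ihi += dilation_h;
+;       os  += in_diff_hi;  /* dilation_h*wi*stride - (x-1)*in_diff_wi */
+;   } else {
+;       iwi += dilation_w;                /* s_dilation_w             */
+;       os  += in_diff_wi;  /* = dilation_w * stride, stride = c*group */
+;   }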
+ .fma_1x8_int8x4 v_c+ 8, v_ax + 7, v_b +24 + .fma_1x8_int8x4 v_c+16, v_ax +11, v_b +24 + .fma_1x8_int8x4 v_c+24, v_ax +15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_int8x4 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_int8x4 v_c+24, v_ay +12, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_int8x4 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_int8x4 v_c+24, v_ay +13, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + 
ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_int8x4 v_c+16, v_ay +10, v_b +16 + .fma_1x8_int8x4 v_c+24, v_ay +14, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_int8x4 v_c+16, v_ay +11, v_b +24 + .fma_1x8_int8x4 v_c+24, v_ay +15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + 
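+; The unsigned v_cmp_gt_u32 bounds checks above also catch the negative
+; ihi/iwi produced by padding: a negative int32 wraps to a huge unsigned
+; value and fails the compare. Per pixel, the recomputed offset and flag:
+;
+;   ihi   = iho * stride_h - pad_h;            /* iy = 0 at window reset */
+;   iwi   = iwo * stride_w - pad_w;
+;   valid = (uint32_t)ihi < hi && (uint32_t)iwi < wi;
+;   in_os = ((uint32_t)ihi * wi + iwi) * c * group;  /* int8 NHWC        */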
v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_int8x4 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_int8x4 v_c+24, v_ay +12, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_int8x4 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_int8x4 v_c+24, v_ay +13, v_b + 8 + + s_waitcnt lgkmcnt(0) + .fma_1x8_int8x4 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_int8x4 v_c+16, v_ay +10, v_b +16 + .fma_1x8_int8x4 v_c+24, v_ay +14, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_int8x4 v_c+16, v_ay +11, v_b +24 + .fma_1x8_int8x4 v_c+24, v_ay +15, v_b +24 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+ 2:v_c_buf+ 3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+ 4, v_c+16, s_0xff + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx2 v[v_out_os+2], v[v_c_buf+ 4:v_c_buf+ 5], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx2 
v[v_out_os+3], v[v_c_buf+ 6:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + .v_clear_nc v_c, 32 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_body +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_512x8x16_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 124 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x016.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x016.asm new file mode 100644 index 00000000..79371df7 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x016.asm @@ -0,0 +1,1545 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 16 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 
36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 64 +.set v_ax, 65 +.set v_ay, 81 +.set v_ib, 97 +.set v_b, 98 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 162 +.set v_in_ihi, 166 +.set v_in_iwi, 170 +.set v_in_flag, 174 +.set v_out_os, 178 +.set v_out_flag, 182 +.set v_tid, 186 +.set v_end, 188 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_512x16x16_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_512x16x16_r2,@function +igemm_fwd_btm_nhwc_int8_512x16x16_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 8 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 7, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 4 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 4 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+0 ; 8x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs 
v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset. 8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], 
v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + ;v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + ;v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 
v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 64 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*2 + 8*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*2 +12*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*3 + 8*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*3 +12*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end + +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], 
v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ax + 4, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ax + 8, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ax +12, v_b + 0 + + .fma_1x16_int8x4 v_c+ 0, v_ax + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ax + 5, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ax + 9, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ax +13, v_b +16 + + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_waitcnt lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ax + 2, v_b +32 + .fma_1x16_int8x4 v_c+16, v_ax + 6, v_b +32 + .fma_1x16_int8x4 v_c+32, v_ax +10, v_b +32 + .fma_1x16_int8x4 v_c+48, v_ax +14, v_b +32 + + .fma_1x16_int8x4 v_c+ 0, v_ax + 3, v_b +48 + .fma_1x16_int8x4 v_c+16, v_ax + 7, v_b +48 + .fma_1x16_int8x4 v_c+32, v_ax +11, v_b +48 + .fma_1x16_int8x4 v_c+48, v_ax +15, v_b +48 + + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*2 + 8*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*2 +12*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*3 + 8*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*3 +12*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], 
v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ay + 8, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ay +12, v_b + 0 + + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ay + 9, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ay +13, v_b +16 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_waitcnt lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ay + 2, v_b +32 + .fma_1x16_int8x4 v_c+16, v_ay + 6, v_b +32 + .fma_1x16_int8x4 v_c+32, v_ay +10, v_b +32 + .fma_1x16_int8x4 v_c+48, v_ay +14, v_b +32 + + .fma_1x16_int8x4 v_c+ 0, v_ay + 3, v_b +48 + .fma_1x16_int8x4 v_c+16, v_ay + 7, v_b +48 + .fma_1x16_int8x4 v_c+32, v_ay +11, v_b +48 + .fma_1x16_int8x4 v_c+48, v_ay +15, v_b +48 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*2 + 8*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], 
offset:k_n_dword*4*2 +12*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*3 + 8*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*3 +12*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_body + +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 
v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ay + 8, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ay +12, v_b + 0 + + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ay + 9, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ay +13, v_b +16 + + s_waitcnt lgkmcnt(0) + .fma_1x16_int8x4 v_c+ 0, v_ay + 2, v_b +32 + .fma_1x16_int8x4 v_c+16, v_ay + 6, v_b +32 + .fma_1x16_int8x4 v_c+32, v_ay +10, v_b +32 + .fma_1x16_int8x4 v_c+48, v_ay +14, v_b +32 + + .fma_1x16_int8x4 v_c+ 0, v_ay + 3, v_b +48 + .fma_1x16_int8x4 v_c+16, v_ay + 7, v_b +48 + .fma_1x16_int8x4 v_c+32, v_ay +11, v_b +48 + .fma_1x16_int8x4 v_c+48, v_ay +15, v_b +48 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + .pack_i8x4_i32_r4 v_c_buf+ 4, v_c+16, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+ 8, v_c+32, s_0xff + .pack_i8x4_i32_r4 v_c_buf+12, v_c+48, s_0xff + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 
+ ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*2 + 8*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*2 +12*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*3 + 8*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*3 +12*4 + + .v_clear_nc v_c, 64 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_body +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_512x16x16_r2 + .amdhsa_group_segment_fixed_size 4096 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 188 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel + +;---------------------------------------------------------------------------------- +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 16 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set 
s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 64 +.set v_ax, 65 +.set v_ay, 73 +.set v_ib, 81 +.set v_b, 82 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 114 +.set v_in_ihi, 118 +.set v_in_iwi, 122 +.set v_in_flag, 126 +.set v_out_os, 130 +.set v_out_flag, 134 +.set v_tid, 138 +.set v_end, 140 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_512x16x8_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_512x16x8_r2,@function +igemm_fwd_btm_nhwc_int8_512x16x8_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 8 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 2 + v_and_b32 v[v_wei_ie], 7, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 4 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 4 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+0 ; 8x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + 
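; added explanatory note (not generated output): .mdiv_u32_rem_vs recovers
+;   iy = ie / x, ix = ie % x
+; without a hardware divide, using the (magic, shift) pair produced on the host by
+; magic_div_u32_gen for divisor x (s_magic_0/s_shift_m0; see "magic_0: x" above):
+;   q = (ie + ((ie * magic) >> 32)) >> shift, r = ie - q * x
+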
.mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx2 v[v_gld_b+0:v_gld_b+1], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx2 v[v_gld_b+2:v_gld_b+3], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 16*k_n_dword*4 ; stride for wei sst offset. 8 thread for gemm_k, each thread store 2 c, hence 8*2=16 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + ;.v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+4, 4 + 
v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + ;.v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + ;v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + ;v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*2 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, 
v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*0 offset1:k_n_dword*1 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 64 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end + +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 4 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ay+ 0:v_ay+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ay+ 2:v_ay+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+4, 4 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ay+ 4:v_ay+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + 
global_load_dwordx2 v[v_ay+ 6:v_ay+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ax + 2, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ax + 4, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ax + 6, v_b + 0 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ax + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ax + 3, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ax + 5, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ax + 7, v_b +16 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ax, 4 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+4, 4 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + 
s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ay + 2, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ay + 6, v_b + 0 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ay + 3, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ay + 7, v_b +16 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_body + +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 2 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+2, 2 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] 
+ v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+4, 2 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+6, 2 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ay + 2, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ay + 6, v_b + 0 + + s_waitcnt lgkmcnt(0) + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ay + 3, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ay + 7, v_b +16 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + .pack_i8x4_i32_r4 v_c_buf+ 4, v_c+16, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+ 8, v_c+32, s_0xff + .pack_i8x4_i32_r4 v_c_buf+12, v_c+48, s_0xff + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], 
offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + .v_clear_nc v_c, 64 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_body +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_512x16x8_r2 + .amdhsa_group_segment_fixed_size 4096 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 140 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/test_inference.cpp b/test/inference/test_inference.cpp new file mode 100644 index 00000000..17a5b067 --- /dev/null +++ b/test/inference/test_inference.cpp @@ -0,0 +1,630 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "args.h" + + +#include "half.hpp" +using float16 = half_float::half; + +std::string parse_base_arg(int argc, char* argv[]) +{ + if(argc < 2) + { + printf("Invalid Number of Input Arguments\n"); + exit(0); + } + + std::string arg = argv[1]; + + if(arg != "conv" && arg != "convfp16" && arg != "convint8" && arg != "--version") + { + printf("Invalid Base Input Argument\n"); + exit(0); + } + else if(arg == "-h" || arg == "--help" || arg == "-?") + exit(0); + else + return arg; +} + +static inline size_t conv_out_size(size_t in_size, size_t pad, size_t dilation, + size_t ksize, size_t stride) { + return (in_size + 2 * pad - dilation * (ksize - 1) - 1) / stride + 1; +} +typedef struct { + uint32_t magic; + uint8_t shift; +} magic_div_u32_t; +static inline magic_div_u32_t magic_div_u32_gen(uint32_t d) { + assert(d >= 1 && d <= INT32_MAX); + uint8_t shift; + for (shift = 0; shift < 32; shift++) + if ((1U << shift) >= d) + break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; + assert(magic <= 0xffffffffUL); + + magic_div_u32_t result; + result.magic = magic; + result.shift = shift; + return 
result;
+}
+static inline uint32_t magic_div_u32_pack_shift(uint8_t s0, uint8_t s1, uint8_t s2, uint8_t s3)
+{
+    uint32_t shift_0 = static_cast<uint32_t>(s0);
+    uint32_t shift_1 = static_cast<uint32_t>(s1);
+    uint32_t shift_2 = static_cast<uint32_t>(s2);
+    uint32_t shift_3 = static_cast<uint32_t>(s3);
+    return (shift_3 << 24) | (shift_2 << 16) | (shift_1 << 8) | shift_0;
+}
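+// Illustrative sketch (not part of the original file): the run-side counterpart of
+// magic_div_u32_gen, showing the identity the kernels' .mdiv_u32_rem_* macros rely
+// on. The helper name is hypothetical; only the gen/pack functions are original.
+static inline uint32_t magic_div_u32_run(uint32_t magic, uint8_t shift, uint32_t n)
+{
+    uint32_t tmp = (uint32_t)(((uint64_t)n * magic) >> 32); // high 32 bits of n * magic
+    return (uint32_t)(((uint64_t)tmp + n) >> shift);        // add-then-shift gives n / d
+}
+// e.g. d = 3: magic_div_u32_gen(3) -> {0x55555556, 2}; for n = 7,
+// tmp = 2 and (2 + 7) >> 2 = 2 == 7 / 3; the remainder follows as n - q * d.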
+typedef struct {
+    int return_code;
+    float duration_ms;
+    float gflops;
+    float efficiency;
+    std::string kernel_name;
+} result_t;
+
+
+typedef enum {
+    driverHalf  = 0, /*!< 16-bit floating point (Fully supported) */
+    driverFloat = 1, /*!< 32-bit floating point (Fully supported) */
+    driverInt8  = 3,
+    driverBFloat16 = 5, /*!< 16-bit binary floating point (8-bit exponent, 7-bit fraction)
+                           (Partially supported) */
+} driverDataType_t;
+
+static inline size_t get_data_byte(driverDataType_t dtype)
+{
+    if(dtype == driverHalf)
+        return 2;
+    if(dtype == driverFloat)
+        return 4;
+    if(dtype == driverInt8)
+        return 1;
+    if(dtype == driverBFloat16)
+        return 2;
+    assert(0);
+    return 0;
+}
+
+static inline int env_get_int(const char *var_name, int default_int) {
+    char *v = getenv(var_name);
+    int r = default_int;
+    if (v)
+        r = atoi(v);
+    return r;
+}
+
+#define NAIVE_CONV_THREADED
+#include "naive_conv.h"
+#include "gpu_naive_conv.h"
+#include "igemm_fwd_btm_nhwc.h"
+
+
+#define HIP_CALL(call)                                                     \
+    do {                                                                   \
+        hipError_t err = call;                                             \
+        if (err != hipSuccess) {                                           \
+            printf("[hiperror](%d) fail to call %s,(%s)", (int)err, #call, \
+                   hipGetErrorString(err));                                \
+            exit(1);                                                       \
+        }                                                                  \
+    } while (0)
+
+static int gen_rand_integer()
+{
+    static int inited = 0;
+    if(inited == 0)
+    {
+        std::srand(std::time(nullptr));
+        inited = 1;
+    }
+    return std::rand();
+}
+
+
+static inline char *env_get_str(const char *var_name, char *default_str) {
+    char *v = getenv(var_name);
+    if (v)
+        return v;
+    return default_str;
+}
+
+template <typename Dst_T, typename Src_T>
+void block_wise_tensor_copy(Dst_T *p_dst, Src_T *p_src, int tid, int block_size, int total_size)
+{
+    for (int i = tid; i < total_size; i += block_size) {
+        p_dst[i] = static_cast<Dst_T>(p_src[i]);
+    }
+}
+
+template <typename Dst_T, typename Src_T>
+void tensor_copy(Dst_T *p_dst, Src_T *p_src, size_t tensor_size) {
+    int num_threads = std::thread::hardware_concurrency();
+    if (num_threads < 4)
+        num_threads = 4;
+
+    std::vector<std::thread> threads;
+    for (int t = 0; t < num_threads; t++) {
+        threads.push_back(std::thread(block_wise_tensor_copy<Dst_T, Src_T>,
+                                      p_dst, p_src, t, num_threads, tensor_size));
+    }
+    for (auto &th : threads)
+        th.join();
+}
+
+template <typename T>
+struct distribution_t{
+};
+
+template <>
+struct distribution_t<int>{
+    distribution_t(int min, int max) : distribution(min, max) {}
+    template <typename URNG>
+    int operator()(URNG & rng){ return distribution(rng);}
+    std::uniform_int_distribution<int> distribution;
+};
+template <>
+struct distribution_t<float>{
+    distribution_t(float min, float max) : distribution(min, max) {}
+    template <typename URNG>
+    float operator()(URNG & rng){ return distribution(rng);}
+    std::uniform_real_distribution<float> distribution;
+};
+
+template <typename Dst_T, typename Src_T>
+void block_wise_rand_generator(Dst_T *p, int tid, int block_size, int total_size, Src_T min, Src_T max, Src_T scale)
+{
+    std::mt19937 rng(std::chrono::system_clock::now()
+                         .time_since_epoch()
+                         .count() +
+                     std::hash<std::thread::id>()(std::this_thread::get_id()));
+    distribution_t<Src_T> distribution(min, max);
+    for (int i = tid; i < total_size; i += block_size) {
+        p[i] = static_cast<Dst_T>(scale * distribution(rng));
+    }
+}
+
+template <typename Dst_T, typename Src_T>
+void gen_rand_vector(Dst_T *vec, size_t vec_size, Src_T fmin, Src_T fmax, Src_T scale = 1) {
+    int num_threads = std::thread::hardware_concurrency();
+    if (num_threads < 4)
+        num_threads = 4;
+    // printf("total threads:%d\n",num_threads);
+    std::vector<std::thread> threads;
+    for (int t = 0; t < num_threads; t++) {
+        threads.push_back(std::thread(block_wise_rand_generator<Dst_T, Src_T>,
+                                      vec, t, num_threads, vec_size, fmin, fmax, scale));
+    }
+    for (auto &th : threads)
+        th.join();
+}
+
+template <typename T>
+bool valid_float(T p)
+{
+    return !(std::isnan(p) || std::isinf(p));
+}
+
+template<>
+bool valid_float<int8_t>(int8_t p)
+{
+    // validity check is meaningless for integer data
+    return true;
+}
+
+#ifndef ABS
+#define ABS(b) ((b) > 0 ? (b) : -1 * (b))
+#endif
+template <typename T>
+bool valid_vector(const float *ref, const T *pred, size_t n,
+                  double nrms = 1e-6) {
+    double s0 = 0.0;
+    double s1 = 0.0;
+    int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0);
+    int igemm_per_pixel_check_print = env_get_int("PER_PIXEL_CHECK_PRINT", 1);
+    size_t pp_err = 0;
+
+    for (size_t i = 0; i < n; ++i) {
+        double ri = (double)ref[i];
+        double pi = (double)pred[i];
+        if(!(valid_float(ref[i]) && valid_float(pred[i]))){
+            printf(" invalid float at %4zu, ref:%f, pred:%f\n", i, ri, pi);
+            return false;
+        }
+        double d = ri - pi;
+        double dd = d * d;
+        double rr = 2.0 * ri * ri;
+        s0 += dd;
+        s1 += rr;
+        if(igemm_per_pixel_check){
+            double delta = ABS(ABS(ri - pi) / ri);
+            printf("[%zu] ref:%lf, pred:%lf(0x%08x) [%s]\n", i, ri, pi, *(uint32_t*)(&pred[i]), delta > 3e-5 ? "N" : "Y");
+            if (delta > 3e-5) {
+                if(igemm_per_pixel_check_print){
+                    if (pp_err < 100)
+                        printf("diff at %zu, ref:%lf, pred:%lf(0x%08x), d:%lf\n", i, ri,
+                               pi, *(uint32_t*)(&pred[i]), delta);
+                }
+                pp_err++;
+            }
+        }
+    }
+    // printf("\nnrms:%lf, s0:%lf, s1:%lf, expected_nrms is %1f\n",sqrt(s0/s1),s0,s1,nrms);
+    return (sqrt(s0 / s1) < nrms)
+#ifdef PER_PIXEL_CHECK
+           && (pp_err == 0)
+#endif
+        ;
+}
+
+template<>
+bool valid_vector<int8_t>(const float *ref, const int8_t *pred, size_t n,
+                          double nrms) {
+    // int8 valid, we prefer a per pixel match
+    int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0);
+    int igemm_per_pixel_check_print = env_get_int("PER_PIXEL_CHECK_PRINT", 1);
+    size_t pp_err = 0;
+
+    for (size_t i = 0; i < n; ++i) {
+        if(!(valid_float(ref[i]))){
+            printf(" invalid float at %4zu, ref:%f\n", i, ref[i]);
+            return false;
+        }
+        int8_t pi = pred[i];
+        int32_t ri = static_cast<int32_t>(ref[i]);
+        int8_t ri_clamp;
+        memcpy(&ri_clamp, &ri, 1);
+
+        if(igemm_per_pixel_check){
+            printf("[%zu] ref:%d(%d), pred:%d(0x%08x) [%s]\n", i, ri, ri_clamp, pi,
+                   *(uint32_t*)(&pred[i]), pi != ri_clamp ? "N" : "Y");
+        }
+
+        if(pi != ri_clamp){
+            pp_err++;
+        }
+    }
+    return pp_err == 0;
+}
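+// Note (added commentary): despite the "clamp" name above, ri_clamp is the reference
+// value *truncated* to its low byte, which matches the kernels' .pack_i8x4_i32_r4
+// path masking each accumulator dword with s_0xff before packing. For example, a
+// reference of 300 (0x12C) becomes 0x2C = 44 on a little-endian host, not a
+// saturated 127.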
"N":"Y"); + } + + if(pi != ri_clamp){ + pp_err++; + } + } + return pp_err == 0; +} + +static inline void dump_output_dword(const float *out, size_t n) +{ + for (size_t i = 0; i < n; ++i) { + double pi = (double)out[i]; + printf("[%zu] pred:%lf(0x%08x)\n", i, pi, ((uint32_t *)out)[i]); + } +} + +static inline double theoritical_gflops(double sclk_ghz, size_t cu, + size_t simd) { + return 2 * sclk_ghz * cu * simd; +} + +static inline double +theoritical_conv_flop(size_t n, size_t c, size_t hi, size_t wi, size_t k, + size_t y, size_t x, size_t stride_h, size_t stride_w, + size_t dilation_h, size_t dilation_w, size_t pad_h, + size_t pad_w, size_t ngroups) { + size_t ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); + size_t wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); + + double flop = (double)n * c * ho * wo * k * y * x * 2 / ngroups; + return flop; +} +static inline double +measured_conv_gflops(double time_ms, size_t n, size_t c, size_t hi, + size_t wi, size_t k, size_t y, size_t x, + size_t stride_h, size_t stride_w, size_t dilation_h, + size_t dilation_w, size_t pad_h, size_t pad_w, size_t ngroups) { + double flop = + theoritical_conv_flop(n, c, hi, wi, k, y, x, stride_h, stride_w, + dilation_h, dilation_w, pad_h, pad_w, ngroups); + return flop / (time_ms * 1e6); +} + +static inline double get_nrms(int forw, driverDataType_t driver_data_type){ + auto basic_tolerance = [=]() -> double{ + if (driver_data_type == driverFloat){ +#ifdef USE_XDNN + return 5e-5; +#else + return 1.5e-6; +#endif + } + else if (driver_data_type == driverHalf){ +#ifdef USE_XDNN + return 5*8.2e-3; +#else + return 8.2e-3; +#endif + } + }; + double nrms = basic_tolerance(); + // wrw has a high tolerance + if (forw == 4){ + nrms *= 2; + if(driver_data_type == driverFloat){ + nrms = 0.01; + } + else if(driver_data_type == driverHalf){ + nrms *= 5; + } + } + return nrms; +} + +#define GPU_NAIVE_CONV_HSACO "naive_conv.hsaco" +#define SCLK_MHZ 2200 +#define WARMUP 3 +#define REPEAT 8 + +#ifndef HSACO +#define HSACO "igemm_fwd_btm_nhwc_fp16.hsaco" +#endif +int main(int argc, char **argv){ + int warmup = env_get_int("WARMUP", WARMUP); + int repeat = env_get_int("REPEAT", REPEAT); + int sclk_mhz = env_get_int("SCLK_MHZ", SCLK_MHZ); + int dump_out = env_get_int("DUMP_OUT", 0); + + char *gpu_naive_conv_hsaco = env_get_str("GPU_NAIVE_CONV_HSACO", GPU_NAIVE_CONV_HSACO); + gpu_naive_conv_init(gpu_naive_conv_hsaco); + + std::string base_arg = parse_base_arg(argc, argv); + std::string default_hsaco = "igemm_fwd_btm_nhwc_"; + + driverDataType_t driver_data_type; + int fp_factor = 1; + if(base_arg == "conv"){ + driver_data_type = driverFloat; + default_hsaco += "fp32.hsaco"; + } + else if(base_arg == "convfp16"){ + driver_data_type = driverHalf; + default_hsaco += "fp16.hsaco"; + fp_factor = 2; + } + else if(base_arg == "convbf16") { + driver_data_type = driverBFloat16; + exit(0); + } + else if(base_arg == "convint8") { + driver_data_type = driverInt8; + default_hsaco += "int8.hsaco"; + fp_factor = 4; + } + else + exit(0); + + size_t data_byte = get_data_byte(driver_data_type); + char *hsaco = env_get_str("HSACO", const_cast(default_hsaco.c_str())); + + hipModule_t module; + HIP_CALL(hipModuleLoad(&module, hsaco)); + + args_t conv_args = create_conv_args(argc, argv); + // dump_arg(&conv_args); + + int hi = conv_args.get_int("in_h"); + int wi = conv_args.get_int("in_w"); + int n = conv_args.get_int("batchsize"); + int k = conv_args.get_int("out_channels"); + int c = conv_args.get_int("in_channels"); + + int stride_h = 
+int main(int argc, char **argv){
+    int warmup = env_get_int("WARMUP", WARMUP);
+    int repeat = env_get_int("REPEAT", REPEAT);
+    int sclk_mhz = env_get_int("SCLK_MHZ", SCLK_MHZ);
+    int dump_out = env_get_int("DUMP_OUT", 0);
+
+    char *gpu_naive_conv_hsaco = env_get_str("GPU_NAIVE_CONV_HSACO", GPU_NAIVE_CONV_HSACO);
+    gpu_naive_conv_init(gpu_naive_conv_hsaco);
+
+    std::string base_arg = parse_base_arg(argc, argv);
+    std::string default_hsaco = "igemm_fwd_btm_nhwc_";
+
+    driverDataType_t driver_data_type;
+    int fp_factor = 1;
+    if(base_arg == "conv"){
+        driver_data_type = driverFloat;
+        default_hsaco += "fp32.hsaco";
+    }
+    else if(base_arg == "convfp16"){
+        driver_data_type = driverHalf;
+        default_hsaco += "fp16.hsaco";
+        fp_factor = 2;
+    }
+    else if(base_arg == "convbf16") {
+        driver_data_type = driverBFloat16;
+        exit(0);
+    }
+    else if(base_arg == "convint8") {
+        driver_data_type = driverInt8;
+        default_hsaco += "int8.hsaco";
+        fp_factor = 4;
+    }
+    else
+        exit(0);
+
+    size_t data_byte = get_data_byte(driver_data_type);
+    char *hsaco = env_get_str("HSACO", const_cast<char *>(default_hsaco.c_str()));
+
+    hipModule_t module;
+    HIP_CALL(hipModuleLoad(&module, hsaco));
+
+    args_t conv_args = create_conv_args(argc, argv);
+    // dump_arg(&conv_args);
+
+    int hi = conv_args.get_int("in_h");
+    int wi = conv_args.get_int("in_w");
+    int n = conv_args.get_int("batchsize");
+    int k = conv_args.get_int("out_channels");
+    int c = conv_args.get_int("in_channels");
+
+    int stride_h = conv_args.get_int("conv_stride_h");
+    int stride_w = conv_args.get_int("conv_stride_w");
+    int dilation_h = conv_args.get_int("dilation_h");
+    int dilation_w = conv_args.get_int("dilation_w");
+    int pad_h = conv_args.get_int("pad_h");
+    int pad_w = conv_args.get_int("pad_w");
+    int y = conv_args.get_int("fil_h");
+    int x = conv_args.get_int("fil_w");
+    int ngroups = conv_args.get_int("group_count");
+    int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
+    int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
+    int forw = conv_args.get_int("forw");
+
+    int need_fwd = (forw == 0 ? 1 : (forw & 1 ? 1 : 0));
+    int need_bwd = (forw == 0 ? 1 : (forw & 2 ? 1 : 0));
+    int need_wrw = (forw == 0 ? 1 : (forw & 4 ? 1 : 0));
+
+    // init host side
+    float *host_input = (float *)malloc(static_cast<size_t>(n) * c * hi * wi * sizeof(float));
+    float *host_weight = (float *)malloc(static_cast<size_t>(k) * c * y * x * sizeof(float));
+    float *host_output = (float *)malloc(static_cast<size_t>(n) * k * ho * wo * sizeof(float));
+
+    float *device_input;
+    float *device_weight;
+    float *device_output;
+
+    HIP_CALL(hipMalloc(&device_input, static_cast<size_t>(n) * c * hi * wi * sizeof(float)));
+    HIP_CALL(hipMalloc(&device_weight, static_cast<size_t>(k) * c * y * x * sizeof(float)));
+    HIP_CALL(hipMalloc(&device_output, static_cast<size_t>(n) * k * ho * wo * sizeof(float)));
+
+
+    void *host_input_dtype = malloc(n * c * hi * wi * data_byte);
+    void *host_weight_dtype = malloc(k * c * y * x * data_byte);
+    void *host_output_dtype = malloc(n * k * ho * wo * data_byte);
+
+    void *device_input_dtype;
+    void *device_weight_dtype;
+    void *device_output_dtype;
+
+    HIP_CALL(hipMalloc(&device_input_dtype, n * c * hi * wi * data_byte));
+    HIP_CALL(hipMalloc(&device_weight_dtype, k * c * y * x * data_byte));
+    HIP_CALL(hipMalloc(&device_output_dtype, n * k * ho * wo * data_byte));
+
+    int need_verify = conv_args.get_int("verify");
+
+    int num_cu;
+    int num_simd = 64; // hard coded
+    int gcn_arch = 0;
+
+    {
+        hipDeviceProp_t dev_prop;
+        hipDevice_t dev;
+        HIP_CALL(hipGetDevice(&dev));
+        HIP_CALL(hipGetDeviceProperties(&dev_prop, dev));
+        num_cu = dev_prop.multiProcessorCount;
+        gcn_arch = dev_prop.gcnArch;
+        if(gcn_arch >= 1000)
+            num_cu *= 2;
+#if 0
+#define P_DEVICE_PROP_INT(prop) \
+    printf(#prop":%d\n", dev_prop.prop)
+
+
+        P_DEVICE_PROP_INT(clockRate);
+        P_DEVICE_PROP_INT(memoryClockRate);
+        P_DEVICE_PROP_INT(memoryBusWidth);
+        P_DEVICE_PROP_INT(major);
+        P_DEVICE_PROP_INT(minor);
+        P_DEVICE_PROP_INT(gcnArch);
+#endif
+    }
+
+    double theo_gflops = theoritical_gflops(((double)sclk_mhz) / 1000.0, num_cu, num_simd * fp_factor);
+    double nrms = get_nrms(forw, driver_data_type);
+
+    printf("num_cu:%d, gcn_arch:%d, theo_gflops:%f\n", num_cu, gcn_arch, theo_gflops);
+
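+    // Illustrative numbers (added comment): a 120-CU gfx908 at the default
+    // SCLK_MHZ=2200 yields 2 * 2.2 * 120 * 64 = 33792 GFLOPS for fp32; fp_factor
+    // then scales this by 2 (fp16) or 4 (int8) for the packed-math kernels.
+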
+    if (need_fwd){
+        void *device_output_to_host = NULL;
+        if (need_verify) {
+            // gen rand
+            gen_rand_vector(host_input, static_cast<size_t>(n) * c * hi * wi, -5, 5);
+            gen_rand_vector(host_weight, static_cast<size_t>(k) * c * y * x, -5, 5);
+
+            if(driver_data_type == driverHalf){
+                // convert to fp16
+                tensor_copy(static_cast<float16 *>(host_input_dtype), host_input, static_cast<size_t>(n) * c * hi * wi);
+                tensor_copy(static_cast<float16 *>(host_weight_dtype), host_weight, static_cast<size_t>(k) * c * y * x);
+            }
+            else if(driver_data_type == driverInt8){
+                // convert to int8
+                tensor_copy(static_cast<int8_t *>(host_input_dtype), host_input, static_cast<size_t>(n) * c * hi * wi);
+                tensor_copy(static_cast<int8_t *>(host_weight_dtype), host_weight, static_cast<size_t>(k) * c * y * x);
+            }
+
+            HIP_CALL(hipMemcpy(device_input, host_input,
+                     static_cast<size_t>(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice));
+            HIP_CALL(hipMemcpy(device_weight, host_weight,
+                     static_cast<size_t>(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice));
+
+            gpu_naive_conv_fwd_nhwc_fp32(device_input, device_weight, device_output,
+                                         n, wi, hi, c,
+                                         k, x, y, pad_w, pad_h, stride_w, stride_h,
+                                         dilation_w, dilation_h, ngroups);
+            HIP_CALL(hipDeviceSynchronize());
+            HIP_CALL(hipMemcpy(host_output, device_output,
+                     static_cast<size_t>(n) * k * ho * wo * sizeof(float),
+                     hipMemcpyDeviceToHost));
+
+            if(driver_data_type == driverHalf || driver_data_type == driverInt8){
+                // round the byte count up to a dword multiple so fp16/int8
+                // outputs can be dumped as 32-bit words
+                device_output_to_host = malloc((static_cast<size_t>(n) * k * ho * wo * data_byte + 3) / 4 * 4);
+            }
+            else{
+                device_output_to_host = malloc(static_cast<size_t>(n) * k * ho * wo * sizeof(float));
+            }
+        }
+
+        if(driver_data_type == driverFloat){
+            HIP_CALL(hipMemcpy(device_input, host_input,
+                     static_cast<size_t>(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice));
+            HIP_CALL(hipMemcpy(device_weight, host_weight,
+                     static_cast<size_t>(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice));
+        }else{
+            HIP_CALL(hipMemcpy(device_input_dtype, host_input_dtype,
+                     static_cast<size_t>(n) * c * hi * wi * data_byte, hipMemcpyHostToDevice));
+            HIP_CALL(hipMemcpy(device_weight_dtype, host_weight_dtype,
+                     static_cast<size_t>(k) * c * y * x * data_byte, hipMemcpyHostToDevice));
+        }
+
+        igemm_fwd_btm_t conv_fwd_driver;
+        int valid_index = 0;
+        for (int i = 0; i < (int)(sizeof(igemm_fwd_btm_kernel_list)/sizeof(igemm_fwd_btm_kernel_list[0])); i++) {
+            igemm_fwd_btm_kernel_info_t *kinfo = &igemm_fwd_btm_kernel_list[i];
+            // skip kernels whose data type does not match the requested one
+            if(driver_data_type == driverHalf){
+                if(kinfo->data_type != "fp16")
+                    continue;
+            }
+            else if(driver_data_type == driverInt8){
+                if(kinfo->data_type != "int8")
+                    continue;
+            }
+
+            printf("[fwd:%2d] %s, ", valid_index, conv_fwd_driver.get_kernel_name(kinfo).c_str());
+            fflush(stdout);
+
+            result_t result;
+            result = conv_fwd_driver.run(&conv_args, module, kinfo, device_input_dtype,
+                                         device_weight_dtype, device_output_dtype, warmup, repeat, driver_data_type);
+            valid_index++;
+
+            if (result.return_code != 0){
+                // the kernel cannot handle this problem configuration
+                printf("not applicable\n");
+                continue;
+            }
+
+            double gflops = measured_conv_gflops(
+                result.duration_ms, n, c, hi, wi, k, y, x, stride_h, stride_w,
+                dilation_h, dilation_w, pad_h, pad_w, ngroups);
+            printf("cost:%.3fms, tflops:%.3f(%.2f%%)", result.duration_ms,
+                   gflops / 1000, (gflops / theo_gflops) * 100);
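+            // Validation: copy the kernel output back (in its native type for
+            // fp16/int8) and check it against the fp32 reference under the
+            // nrms threshold from get_nrms().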
+            if (need_verify) {
+                bool is_valid = false;
+                if(driver_data_type == driverFloat) {
+                    HIP_CALL(hipMemcpy(device_output_to_host, device_output,
+                             static_cast<size_t>(n) * k * ho * wo * sizeof(float),
+                             hipMemcpyDeviceToHost));
+                    is_valid = valid_vector(host_output, static_cast<float *>(device_output_to_host),
+                                            static_cast<size_t>(n) * k * ho * wo, nrms);
+                }
+                else if(driver_data_type == driverHalf || driver_data_type == driverInt8) {
+                    HIP_CALL(hipMemcpy(device_output_to_host, device_output_dtype,
+                             static_cast<size_t>(n) * k * ho * wo * data_byte,
+                             hipMemcpyDeviceToHost));
+                    if(dump_out)
+                        dump_output_dword(static_cast<float *>(device_output_to_host),
+                                          static_cast<size_t>(n) * k * ho * wo / fp_factor);
+                    if(driver_data_type == driverHalf)
+                        is_valid = valid_vector(host_output, static_cast<float16 *>(device_output_to_host),
+                                                static_cast<size_t>(n) * k * ho * wo, nrms);
+                    else if(driver_data_type == driverInt8)
+                        is_valid = valid_vector(host_output, static_cast<int8_t *>(device_output_to_host),
+                                                static_cast<size_t>(n) * k * ho * wo, nrms);
+                }
+                printf(", valid:%s", is_valid ? "y" : "n");
+            }
+            printf("\n");
+        }
+
+        if (need_verify){
+            free(device_output_to_host);
+        }
+    }
+}
\ No newline at end of file