diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e5254034afc..919a7aa8c05 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -724,7 +724,6 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli and logits == "f" and bias == "no" and dropout == "f" - and lse == "f" and skip == "f" ): pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index b2c1b06955a..1d998ba4f66 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -211,10 +211,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload set_tile(lse_acc, -numeric::infinity()); - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } + store_tile(lse_acc_dram_window_tmp, lse_acc); } // Note: here occ are all cleard, return it @@ -256,8 +253,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp, + {physical_seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); auto k_lds_write_view = make_tensor_view( static_cast(smem_ptr), Policy::template MakeKLdsBlockDescriptor()); @@ -289,8 +288,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload Policy::template MakeSRegTileDistribution()); // V tile in LDS - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp, + {physical_seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); auto v_lds_write_view = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + @@ -393,7 +394,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { if(i_total_loops == (num_total_loop - 1)) { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + const auto k_origin = + make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0); set_tile_if(s_acc, -numeric::infinity(), [&, @@ -410,7 +412,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + const auto k_origin = make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0); bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); @@ -602,10 +604,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload } }); - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } + store_tile(lse_acc_dram_window_tmp, lse_acc); } // finally, O @@ -717,10 +716,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload set_tile(lse_acc, -numeric::infinity()); - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } + store_tile(lse_acc_dram_window_tmp, lse_acc); } // Note: here occ are all cleard, return it @@ -765,8 +761,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp, + {physical_seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); auto k_lds_write_view = make_tensor_view( static_cast(smem_ptrk0), @@ -801,8 +799,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload Policy::template MakeSRegTileDistribution()); // V tile in LDS - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp, + {physical_seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); auto v_lds_write_view = make_tensor_view( reinterpret_cast(static_cast(smem_ptrv0)), @@ -901,7 +901,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { if(i_total_loops == (num_total_loop - 1)) { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + const auto k_origin = + make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0); set_tile_if(s_acc, -numeric::infinity(), [&, @@ -918,7 +919,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + const auto k_origin = make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0); bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); @@ -1146,10 +1147,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload } }); - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } + store_tile(lse_acc_dram_window_tmp, lse_acc); } // finally, O