From 041ebd1a4e543374655a098badc262786d276505 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 15 Oct 2025 09:14:20 +0000 Subject: [PATCH 1/3] enable storelse for fmha_fwd_trload kernel --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f898d5f7b26..533f7f2f231 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -608,7 +608,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli else: pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and skip == "f": pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": From 52e25aca0d90f3a2d0efc95e28a3359d5091b278 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 17 Oct 2025 09:40:49 +0000 Subject: [PATCH 2/3] fix lse in trload --- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index b2c1b06955a..7f14ffd1f96 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -211,10 +211,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload set_tile(lse_acc, -numeric::infinity()); - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } + store_tile(lse_acc_dram_window_tmp, lse_acc); } // Note: here occ are all cleard, return it @@ -602,10 +599,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload } }); - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } + store_tile(lse_acc_dram_window_tmp, lse_acc); } // finally, O @@ -717,10 +711,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload set_tile(lse_acc, -numeric::infinity()); - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } + store_tile(lse_acc_dram_window_tmp, lse_acc); } // Note: here occ are all cleard, return it @@ -1146,10 +1137,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload } }); - if(get_thread_local_1d_id() < kM0) - { - store_tile(lse_acc_dram_window_tmp, lse_acc); - } + store_tile(lse_acc_dram_window_tmp, lse_acc); } // finally, O From 9fddc0b6eda835ddcdc6bb955aa045ef22bb846d Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 22 Oct 2025 08:03:00 +0000 Subject: [PATCH 3/3] fix the mask related bug --- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 7f14ffd1f96..1d998ba4f66 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -253,8 +253,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp, + {physical_seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); auto k_lds_write_view = make_tensor_view( static_cast(smem_ptr), Policy::template MakeKLdsBlockDescriptor()); @@ -286,8 +288,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload Policy::template MakeSRegTileDistribution()); // V tile in LDS - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp, + {physical_seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); auto v_lds_write_view = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + @@ -390,7 +394,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { if(i_total_loops == (num_total_loop - 1)) { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + const auto k_origin = + make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0); set_tile_if(s_acc, -numeric::infinity(), [&, @@ -407,7 +412,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + const auto k_origin = make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0); bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{}); @@ -756,8 +761,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; - auto k_dram_window = make_tile_window( - k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp, + {physical_seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); auto k_lds_write_view = make_tensor_view( static_cast(smem_ptrk0), @@ -792,8 +799,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload Policy::template MakeSRegTileDistribution()); // V tile in LDS - auto v_dram_window = make_tile_window( - v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp, + {physical_seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); auto v_lds_write_view = make_tensor_view( reinterpret_cast(static_cast(smem_ptrv0)), @@ -892,7 +901,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { if(i_total_loops == (num_total_loop - 1)) { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + const auto k_origin = + make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0); set_tile_if(s_acc, -numeric::infinity(), [&, @@ -909,7 +919,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + const auto k_origin = make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0); bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number{}, number{});