From 57a47b091ce52d00283a6277d9eccd50358a58dc Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sun, 2 Nov 2025 22:37:51 -0600 Subject: [PATCH 01/48] Let fmha_fwd_v3() compatible with fmha_fwd() --- .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 54 ++++---- example/ck_tile/01_fmha/fmha_fwd_v3.cpp | 38 ++---- example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 61 +-------- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 128 +++++++++++------- .../instances/fmha_fwd_v3_d128_bf16_mask.cpp | 3 +- .../instances/fmha_fwd_v3_d128_bf16_nmask.cpp | 3 +- .../instances/fmha_fwd_v3_d128_fp16_mask.cpp | 3 +- .../instances/fmha_fwd_v3_d128_fp16_nmask.cpp | 3 +- .../pipeline/block_fmha_pipeline_enum.hpp | 1 + 9 files changed, 128 insertions(+), 166 deletions(-) diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 7ddb65a2db..c713560045 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -87,12 +87,10 @@ struct Problem { explicit Problem(const ck_tile::ArgParser& args) { - data_type = args.get_str("prec") == "fp16" - ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16 - : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16; - batch = args.get_int("b"); - seqlen_q = args.get_int("s"); - seqlen_k = args.get_int("s_k"); + prec = args.get_str("prec") == "fp16" ? "fp16" : "bf16"; + batch = args.get_int("b"); + seqlen_q = args.get_int("s"); + seqlen_k = args.get_int("s_k"); if(seqlen_k < 0) { seqlen_k = seqlen_q; @@ -172,7 +170,7 @@ struct Problem } } - ck_tile::fmha_fwd_v3_args::data_type_enum data_type; + std::string prec; ck_tile::index_t batch; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -342,17 +340,27 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // Ensure output buffer is zero-initialized so padded regions compare cleanly o_buf.SetZero(); - ck_tile::fmha_fwd_v3_args args{}; - - args.data_type = problem.data_type; - args.batch = problem.batch; - args.seqlen_q = problem.seqlen_q; - args.seqlen_k = problem.seqlen_k; - args.nhead_q = problem.nhead_q; - args.nhead_kv = problem.nhead_kv; - args.hdim_qk = problem.hdim; - args.hdim_v = problem.hdim; - args.softmax_scale = problem.softmax_scale; + fmha_fwd_traits traits{}; + traits.hdim_q = problem.hdim; + traits.hdim_v = problem.hdim; + traits.data_type = problem.prec; + traits.is_v_rowmajor = true; + traits.is_group_mode = false; + traits.has_logits_soft_cap = false; + traits.mask_type = mask_enum::mask_bottom_right; + traits.bias_type = bias_enum::no_bias; + traits.has_lse = false; + traits.do_fp8_static_quant = false; + + fmha_fwd_args args{}; + args.batch = problem.batch; + args.seqlen_q = problem.seqlen_q; + args.seqlen_k = problem.seqlen_k; + args.nhead_q = problem.nhead_q; + args.nhead_k = problem.nhead_kv; + args.hdim_q = problem.hdim; + args.hdim_v = problem.hdim; + args.scale_s = problem.softmax_scale; args.window_size_left = problem.mask.left; args.window_size_right = problem.mask.right; @@ -445,7 +453,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.cu_seqlen_q_ptr = !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) : nullptr; - args.cu_seqlen_kv_ptr = + args.cu_seqlen_k_ptr = !cukv_cum.empty() ? 
reinterpret_cast(cukv_buf.GetDeviceBuffer()) : nullptr; @@ -455,8 +463,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) run_config.kernel_warmup, run_config.kernel_repeat}; - auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config); - if(!result) + float time = ck_tile::fmha_fwd_v3(traits, args, stream_config); + if(time < 0.f) { std::cerr << "faild to run fmha_fwd_v3()" << std::endl; return false; @@ -477,7 +485,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }(); float tflops = static_cast(flop) / 1.e9 / time; - std::cout << "[" << problem.data_type << "|"; + std::cout << "[" << problem.prec << "|"; if(problem.input_layout == problem.output_layout) { std::cout << problem.input_layout; @@ -602,7 +610,7 @@ int main(int argc, char* argv[]) RunConfig run_config(args); const auto run = [&] { - if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16) + if(problem.prec == "fp16") { return run_impl(problem, run_config); } diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp index 30019167fb..041e04328d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp @@ -7,54 +7,40 @@ namespace ck_tile { -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type) +float fmha_fwd_v3(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) { - switch(data_type) - { - case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16"; - case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16"; - default: return stream << "unknown"; - } -} - -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config) -{ - if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16) + if(traits.data_type.compare("fp16") == 0) { if(args.mask_type == static_cast(mask_enum::no_mask)) { - using kernel_traits = - fmha_fwd_v3_kernel_traits; + using kernel_traits = fmha_fwd_v3_kernel_traits; - return fmha_fwd_v3_kernel_dispatch(args, config); + return fmha_fwd_v3_kernel_dispatch(config, args); } else { - using kernel_traits = - fmha_fwd_v3_kernel_traits; + using kernel_traits = fmha_fwd_v3_kernel_traits; - return fmha_fwd_v3_kernel_dispatch(args, config); + return fmha_fwd_v3_kernel_dispatch(config, args); } } - else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16) + else if(traits.data_type.compare("bf16") == 0) { if(args.mask_type == static_cast(mask_enum::no_mask)) { - using kernel_traits = - fmha_fwd_v3_kernel_traits; + using kernel_traits = fmha_fwd_v3_kernel_traits; - return fmha_fwd_v3_kernel_dispatch(args, config); + return fmha_fwd_v3_kernel_dispatch(config, args); } else { - using kernel_traits = - fmha_fwd_v3_kernel_traits; + using kernel_traits = fmha_fwd_v3_kernel_traits; - return fmha_fwd_v3_kernel_dispatch(args, config); + return fmha_fwd_v3_kernel_dispatch(config, args); } } - return std::make_pair(false, -1.f); + return -1.; } } // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 4bd1d1a367..c3a0d0d8f3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -9,65 +9,10 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/host/stream_config.hpp" -namespace ck_tile { - -struct fmha_fwd_v3_args -{ - enum class data_type_enum - { - fp16, - bf16 - }; - - data_type_enum data_type; - // bool is_varlen; - - index_t batch; - index_t 
seqlen_q; - index_t seqlen_k; - index_t nhead_q; - index_t nhead_kv; - index_t hdim_qk; - index_t hdim_v; - - float softmax_scale; - - index_t window_size_left; - index_t window_size_right; - index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and - // window_size_right == 0). +#include "fmha_fwd.hpp" - const void* q_ptr; - index_t stride_q; - index_t nhead_stride_q; - index_t batch_stride_q; - - const void* k_ptr; - index_t stride_k; - index_t nhead_stride_k; - index_t batch_stride_k; - - const void* v_ptr; - index_t stride_v; - index_t nhead_stride_v; - index_t batch_stride_v; - - void* o_ptr; - index_t stride_o; - index_t nhead_stride_o; - index_t batch_stride_o; - - // Optional batch-mode cumulative seqlen overrides (exclude PAD) - // If provided, they override per-batch effective lengths to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] -}; - -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); +namespace ck_tile { -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config); +float fmha_fwd_v3(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); } // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index 194675f962..451ebadd21 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -17,25 +17,26 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" +#include "fmha_fwd.hpp" #include "fmha_fwd_v3.hpp" #include "mask.hpp" -#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ - template <> \ - std::pair fmha_fwd_v3_kernel_dispatch( \ - const fmha_fwd_v3_args& args, const stream_config& config) \ - { \ - return std::make_pair(true, \ - fmha_fwd_v3_kernel_launch(args, config)); \ +#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ + template <> \ + float fmha_fwd_v3_kernel_dispatch( \ + const ck_tile::stream_config& config, fmha_fwd_args args) \ + { \ + return fmha_fwd_v3_kernel_launch::type>(config, \ + args); \ } namespace ck_tile { -template +template struct fmha_fwd_v3_problem_traits; template <> -struct fmha_fwd_v3_problem_traits +struct fmha_fwd_v3_problem_traits { using qkvp_dtype = ck_tile::half_t; using acc_dtype = float; @@ -44,7 +45,7 @@ struct fmha_fwd_v3_problem_traits }; template <> -struct fmha_fwd_v3_problem_traits +struct fmha_fwd_v3_problem_traits { using qkvp_dtype = ck_tile::bf16_t; using acc_dtype = float; @@ -52,15 +53,45 @@ struct fmha_fwd_v3_problem_traits using lse_dtype = float; }; -template -struct fmha_fwd_v3_kernel_traits +template +using fmha_fwd_v3_kernel_traits = + fmha_fwd_traits_<128, + DataType, + kIsGroupMode, + 256, + 32, + 128, + 128, + 32, + 128, + true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3, + false, + ck_tile::GenericAttentionMask, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false, + false, + true, + false>; + +template +struct get_fmha_fwd_v3_kernel { - static constexpr auto date_type = DataType; - static constexpr bool is_variable_seqlen = IsVariableSeqlen; - static constexpr bool is_masking = IsMasking; 
+ using data_type = KernelTraits::DataType; + static constexpr bool kIsGroupMode = KernelTraits::kIsGroupMode; // M0 N0 K0 N1 K1 - using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; + using fmha_block_tile = sequence; using fmha_warp_gemm_shape = sequence<32, 32, 16>; using fmha_block_warps = sequence<8, 1, 1>; @@ -69,49 +100,48 @@ struct fmha_fwd_v3_kernel_traits fmha_warp_gemm_shape, fmha_block_warps, fmha_warp_gemm_shape, - true // IsVLayoutRowMajor - >; - - using fmha_traits = TileFmhaFwdV3Traits; + + using fmha_traits = TileFmhaFwdV3Traits; - using fmha_mask = GenericAttentionMask; + using fmha_mask = KernelTraits::FmhaMask; using fmha_pipeline_problem = - BlockFmhaFwdV3PipelineProblem::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::lse_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, + BlockFmhaFwdV3PipelineProblem::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::acc_dtype, + typename fmha_fwd_v3_problem_traits::acc_dtype, + typename fmha_fwd_v3_problem_traits::lse_dtype, + typename fmha_fwd_v3_problem_traits::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::acc_dtype, + typename fmha_fwd_v3_problem_traits::o_dtype, fmha_shape, - IsVariableSeqlen, + kIsGroupMode, fmha_mask, fmha_traits>; using fmha_pipeline = BlockFmhaFwdV3Pipeline; using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, + Default2DEpilogueProblem::acc_dtype, + typename fmha_fwd_v3_problem_traits::o_dtype, true, // kPadM true, // kPadM true // UseRawStore >>; - using kernel = FmhaFwdV3Kernel; + using type = FmhaFwdV3Kernel; }; template -float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) +float fmha_fwd_v3_kernel_launch(const ck_tile::stream_config& config, fmha_fwd_args args) { /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly /// maximizes the kernel's performance. 
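
For context, a minimal calling sketch for the reworked entry point, mirroring the example-driver changes in patch 01. This is not part of the patch: the function name, the numeric values, and the device pointers/stream_cfg are placeholders, and only the fields that fmha_fwd_v3() actually dispatches on (data_type, mask_type) plus the obvious shape fields are shown.

#include "fmha_fwd.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"

// Sketch: launch the unified v3 entry point for a batch-mode, no-mask bf16 problem.
float launch_v3_nomask_bf16(const void* q, const void* k, const void* v, void* o,
                            const ck_tile::stream_config& stream_cfg)
{
    fmha_fwd_traits traits{};
    traits.data_type     = "bf16";   // fmha_fwd_v3() dispatches on "fp16"/"bf16"
    traits.hdim_q        = 128;
    traits.hdim_v        = 128;
    traits.is_v_rowmajor = true;

    fmha_fwd_args args{};
    args.batch     = 2;
    args.seqlen_q  = 4096;
    args.seqlen_k  = 4096;
    args.nhead_q   = 8;
    args.nhead_k   = 8;
    args.hdim_q    = 128;
    args.hdim_v    = 128;
    args.scale_s   = 0.08838835f;    // 1/sqrt(128), placeholder
    args.mask_type = static_cast<ck_tile::index_t>(mask_enum::no_mask);
    args.q_ptr = q; args.k_ptr = k; args.v_ptr = v; args.o_ptr = o;
    // stride_* / nhead_stride_* / batch_stride_* must also be filled from the real layout.

    // Returns the measured time in ms, or a negative value if no v3 instance matched.
    return ck_tile::fmha_fwd_v3(traits, args, stream_cfg);
}
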
@@ -136,11 +166,11 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.o_ptr, args.seqlen_q, args.seqlen_k, - args.hdim_qk, + args.hdim_q, args.hdim_v, args.nhead_q, - args.nhead_q / args.nhead_kv, - args.softmax_scale, + args.nhead_q / args.nhead_k, + args.scale_s, args.stride_q, args.stride_k, args.stride_v, @@ -159,8 +189,8 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.window_size_right, args.mask_type, remap_opt, - args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); + static_cast(args.cu_seqlen_q_ptr), + static_cast(args.cu_seqlen_k_ptr)); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); @@ -169,11 +199,7 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); } -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -template -std::pair fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args, - const stream_config& config); +template +float fmha_fwd_v3_kernel_dispatch(const ck_tile::stream_config&, fmha_fwd_args); } // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp index 2dbe0b2098..0d199aa33f 100644 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp @@ -6,8 +6,7 @@ namespace ck_tile { -using kernel_traits = - fmha_fwd_v3_kernel_traits; +using kernel_traits = fmha_fwd_v3_kernel_traits; INST_FMHA_FWD_V3_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp index 6f5eca97a1..a371d74a80 100644 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp @@ -6,8 +6,7 @@ namespace ck_tile { -using kernel_traits = - fmha_fwd_v3_kernel_traits; +using kernel_traits = fmha_fwd_v3_kernel_traits; INST_FMHA_FWD_V3_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp index 1c4c798af6..b0fbc88f78 100644 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp @@ -6,8 +6,7 @@ namespace ck_tile { -using kernel_traits = - fmha_fwd_v3_kernel_traits; +using kernel_traits = fmha_fwd_v3_kernel_traits; INST_FMHA_FWD_V3_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp index 077cb7b73c..bd1860fb25 100644 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp @@ -6,8 +6,7 @@ namespace ck_tile { -using kernel_traits = - fmha_fwd_v3_kernel_traits; +using kernel_traits = fmha_fwd_v3_kernel_traits; INST_FMHA_FWD_V3_DISPATCH(kernel_traits) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 45a1c8f4b8..88d6825c55 100644 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum QRKSVS_ASYNC, QSKSVS, QRKSVS_ASYNC_TRLOAD, + QRKSVS_ASYNC_TRLOAD_V3, }; template From b93a0adfb5443ed4d8ad662a620fa09f40889171 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 3 Nov 2025 04:00:20 -0600 Subject: [PATCH 02/48] Decouple get_fwd_blobs() and FmhaFwdKernel --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 ++ .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 30 ++++++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 4098eb67c2..5c11d3b40b 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -113,6 +113,7 @@ def get_mask_check_map(mask: str): "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", + "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline", } PIPELINE_ENUM_MAP = { @@ -122,6 +123,7 @@ def get_mask_check_map(mask: str): "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", + "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 2acc467410..e7ec69c586 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -7,9 +7,10 @@ import itertools import os from collections import OrderedDict +from functools import partial from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional, Tuple +from typing import ClassVar, List, Optional, Tuple from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR @@ -39,7 +40,7 @@ #include "fmha_fwd.hpp" """ -FMHA_FWD_KERNEL_BODY = """ +FMHA_FWD_KERNEL_BODY_TEMPLATE = """ #include #if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) @@ -546,9 +547,11 @@ class FmhaFwdKernel: F_pipeline: FmhaFwdPipeline mask_impl: str - @property - def template(self) -> str: - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER + KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE + + def render(self) -> str: + return type(self).KERNEL_HEADER + type(self).KERNEL_BODY_TEMPLATE.format( F_idx=self.F_idx, F_arch=self.F_arch, F_hdim=self.F_hdim, @@ -634,6 +637,17 @@ def api_trait(self) -> FmhaFwdApiTrait: ) +@dataclass +class FmhaFwdV3Kernel(FmhaFwdKernel): + KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE + + +def create_kernel(pipeline: FmhaFwdPipeline, *args, **kwargs) -> FmhaFwdKernel: + ctor = FmhaFwdV3Kernel if pipeline.tag == "qr_async_trload_v3" else FmhaFwdKernel + builder = partial(ctor, F_pipeline=pipeline) + return builder(*args, **kwargs) + + class KernelComponentFactoryGfx9: arch = ArchTrait( "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" @@ -911,14 +925,14 @@ def get_fwd_blobs( or pipeline.F_logits == "f" ): continue - k = FmhaFwdKernel( + k = create_kernel( + pipeline, F_arch=factory.arch, F_idx=0, 
F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, - F_pipeline=pipeline, mask_impl=mask_impl, ) if kernel_filter != "": @@ -1012,7 +1026,7 @@ def get_fwd_blobs( def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) + update_file(autogen_dir / kernel.filename, kernel.render()) def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: From 6e463661f30b083bdeb534214cd803be28340986 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 4 Nov 2025 04:10:56 -0600 Subject: [PATCH 03/48] Decouple compatibility checks from get_fwd_blobs() --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 208 +++++++++++++----- 1 file changed, 159 insertions(+), 49 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e7ec69c586..8ea0efb36c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -10,7 +10,7 @@ from functools import partial from dataclasses import dataclass, field from pathlib import Path -from typing import ClassVar, List, Optional, Tuple +from typing import Callable, ClassVar, Iterable, List, Optional, Tuple from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR @@ -121,6 +121,9 @@ #endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ +FMHA_FWD_V3_KERNEL_HEADER = "" +FMHA_FWD_V3_KERNEL_BODY_TEMPLATE = "" + FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" FMHA_FWD_API = """ #include @@ -639,16 +642,145 @@ def api_trait(self) -> FmhaFwdApiTrait: @dataclass class FmhaFwdV3Kernel(FmhaFwdKernel): - KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE + KERNEL_HEADER: ClassVar[str] = FMHA_FWD_V3_KERNEL_HEADER + KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_V3_KERNEL_BODY_TEMPLATE + + +@dataclass +class ProblemContext: + dtype: str + mode: str + hdim: int + hdim_v: int + + +@dataclass +class KernelContext: + arch: ArchTrait + tile: FmhaFwdTileSize + pipeline: FmhaFwdPipeline + mask_impl: str + + +CompatibilityRule = Callable[[ProblemContext, KernelContext], bool] + + +def is_compatible( + problem_ctx: ProblemContext, + kernel_ctx: KernelContext, + rules: Iterable[CompatibilityRule], + *, + short_circuit: bool = True, +) -> bool: + if short_circuit: + for rule in rules: + if not rule(problem_ctx, kernel_ctx): + return False + return True + return all(rule(problem_ctx, kernel_ctx) for rule in rules) + + +def create_kernel( + problem_ctx: ProblemContext, kernel_ctx: KernelContext +) -> FmhaFwdKernel: + ctor = ( + FmhaFwdV3Kernel + if kernel_ctx.pipeline.tag == "qr_async_trload_v3" + else FmhaFwdKernel + ) + return ctor( + F_idx=0, + F_dtype=problem_ctx.dtype, + F_mode=problem_ctx.mode, + F_hdim=problem_ctx.hdim, + F_arch=kernel_ctx.arch, + F_tile=kernel_ctx.tile, + F_pipeline=kernel_ctx.pipeline, + mask_impl=kernel_ctx.mask_impl, + ) + + +class CompatibilityRuleFactory: + @staticmethod + def get_rules() -> list[CompatibilityRule]: + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + if problem_ctx.mode == "group": + if ( + kernel_ctx.pipeline.F_spad != "t" + or kernel_ctx.pipeline.F_skpad != "t" + ): + return False + return True + + def check_hdim(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + # NOTE: this is used to speedup 
deepseek prefill case, we don't gen training + if (problem_ctx.hdim, problem_ctx.hdim_v) == (192, 128): + if ( + kernel_ctx.pipeline.F_bias != "no" + or kernel_ctx.pipeline.F_dropout == "t" + ): + False + return True + + def check_feature( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + # logits_soft_cap is only allowed if no bias + if not ( + ( + kernel_ctx.pipeline.F_logits == "t" + and kernel_ctx.pipeline.F_bias == "no" + ) + or kernel_ctx.pipeline.F_logits == "f" + ): + return False + return True + + return [check_mode, check_hdim, check_feature] + +class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): + @staticmethod + def get_rules() -> list[CompatibilityRule]: + rules = CompatibilityRuleFactory.get_rules() + + def check_hdim_tile_for_non_fp32( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if kernel_ctx.arch.name.startswith("gfx9") and problem_ctx.dtype != "fp32": + # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support + if kernel_ctx.pipeline.tag != "qr_async_trload" and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 != 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) + and kernel_ctx.tile.F_bm0 != 128 + ) + ): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + return False + if kernel_ctx.pipeline.tag == "qr_async_trload" and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 == 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) + not in [(64, 64), (128, 128)] + ) + ): + return False + return True -def create_kernel(pipeline: FmhaFwdPipeline, *args, **kwargs) -> FmhaFwdKernel: - ctor = FmhaFwdV3Kernel if pipeline.tag == "qr_async_trload_v3" else FmhaFwdKernel - builder = partial(ctor, F_pipeline=pipeline) - return builder(*args, **kwargs) + rules.append(check_hdim_tile_for_non_fp32) + return rules -class KernelComponentFactoryGfx9: +class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): arch = ArchTrait( "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" ) @@ -761,7 +893,9 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli return pipelines -class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): +class KernelComponentFactoryGfx950( + KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9 +): arch = ArchTrait("gfx950") @staticmethod @@ -791,7 +925,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli return pipelines -class KernelComponentFactoryGfx12: +class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") @staticmethod @@ -847,7 +981,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli return pipelines -class CustomFactory(KernelComponentFactoryGfx9): +class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9): @staticmethod def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) @@ -890,57 +1024,33 @@ def get_fwd_blobs( for ((hdim, hdim_v), tiles), mode in itertools.product( d.items(), MODE_MAP.keys() ): + if optdim_list != [-1]: + if hdim not in optdim_list: + continue for tile, next_tile in zip(tiles, tiles[1:]): assert next_tile.F_bm0 >= tile.F_bm0, ( "Tiles must be ordered by increasing bm0" ) + for tile, pipeline in itertools.product( tiles, 
factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if mode == "group": - if pipeline.F_spad != "t" or pipeline.F_skpad != "t": - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - if (hdim, hdim_v) == (192, 128): - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != "no" or pipeline.F_dropout == "t": - continue - if factory.arch.name.startswith("gfx9") and dtype != "fp32": - # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support - if pipeline.tag != "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) - or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128) - ): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) - or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) - ): - continue - # logits_soft_cap is only allowed if no bias - if not ( - (pipeline.F_logits == "t" and pipeline.F_bias == "no") - or pipeline.F_logits == "f" + problem_ctx = ProblemContext( + dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v + ) + kernel_ctx = KernelContext( + arch=factory.arch, tile=tile, pipeline=pipeline, mask_impl=mask_impl + ) + rules = factory.get_rules() + if not is_compatible( + problem_ctx, kernel_ctx, rules, short_circuit=True ): continue - k = create_kernel( - pipeline, - F_arch=factory.arch, - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - mask_impl=mask_impl, - ) + + k = create_kernel(problem_ctx, kernel_ctx) if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue # 2 - Flash attention integration if receipt in (2, 3): cond = dtype in ["fp16", "bf16"] From 756a1b844cae94648e7cdafe6006217f1a3836fb Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 4 Nov 2025 08:34:44 -0600 Subject: [PATCH 04/48] Extract product feature checks out from get_fwd_blobs() --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 208 +++++++++++------- 1 file changed, 123 insertions(+), 85 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 8ea0efb36c..28fb4b91ec 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -7,7 +7,6 @@ import itertools import os from collections import OrderedDict -from functools import partial from dataclasses import dataclass, field from pathlib import Path from typing import Callable, ClassVar, Iterable, List, Optional, Tuple @@ -656,7 +655,6 @@ class ProblemContext: @dataclass class KernelContext: - arch: ArchTrait tile: FmhaFwdTileSize pipeline: FmhaFwdPipeline mask_impl: str @@ -681,7 +679,7 @@ def is_compatible( def create_kernel( - problem_ctx: ProblemContext, kernel_ctx: KernelContext + arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext ) -> FmhaFwdKernel: ctor = ( FmhaFwdV3Kernel @@ -690,10 +688,10 @@ def create_kernel( ) return ctor( F_idx=0, + F_arch=arch, F_dtype=problem_ctx.dtype, F_mode=problem_ctx.mode, F_hdim=problem_ctx.hdim, - F_arch=kernel_ctx.arch, F_tile=kernel_ctx.tile, F_pipeline=kernel_ctx.pipeline, mask_impl=kernel_ctx.mask_impl, @@ -748,7 +746,7 @@ def get_rules() -> list[CompatibilityRule]: def 
check_hdim_tile_for_non_fp32( problem_ctx: ProblemContext, kernel_ctx: KernelContext ) -> bool: - if kernel_ctx.arch.name.startswith("gfx9") and problem_ctx.dtype != "fp32": + if problem_ctx.dtype != "fp32": # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support if kernel_ctx.pipeline.tag != "qr_async_trload" and ( ( @@ -1008,6 +1006,121 @@ def get_factory(target: str): raise Exception(f"Unsupported device target {target}") +@dataclass(frozen=True) +class Product: + name: str + rule: CompatibilityRule + + def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return self.rule(problem_ctx, kernel_ctx) + + +def get_product(receipt: int) -> Product: + # Flash attention integration + if receipt in (2, 3): + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] + cond &= kernel_ctx.pipeline.F_squant == "f" + cond &= kernel_ctx.pipeline.F_skip == "f" + return cond + + return Product(name="Flash attention integration", rule=fit) + # PyTorch integration + elif receipt == 4: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "bias"] + cond &= kernel_ctx.pipeline.F_squant == "f" + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="PyTorch integration", rule=fit) + # Aiter(mha_fwd) integration + elif receipt == 100: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "group" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 + return cond + + return Product(name="Aiter(mha_fwd) integration", rule=fit) + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "group" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 + return cond + + return Product(name="Aiter(mha_varlen_fwd) integration", rule=fit) + # aiter::mha_fwd C++ api integration + elif receipt == 600: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 + return cond + + return Product(name="aiter::mha_fwd C++ api integration", rule=fit) + elif receipt == 888: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp8", "fp8bf16", "fp8fp32"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= problem_ctx.hdim == 128 + return cond + + return Product(name="receipt = 888", rule=fit) + # fp32 only, all variations + elif receipt == 800: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="fp32 
only, all variations", rule=fit) + # fp32 only, minimal set of parameters + elif receipt == 801: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= problem_ctx.hdim in [48, 128] + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_bias == "no" + cond &= kernel_ctx.pipeline.F_lse == "f" + cond &= kernel_ctx.pipeline.F_dropout == "f" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + cond &= kernel_ctx.pipeline.F_mask == "s_no" + return cond + + return Product(name="fp32 only, minimal set of parameters", rule=fit) + # Don't build fp32 by default + else: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return problem_ctx.dtype != "fp32" + + return Product(name="Default", rule=fit) + + def get_fwd_blobs( targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: @@ -1039,95 +1152,20 @@ def get_fwd_blobs( dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v ) kernel_ctx = KernelContext( - arch=factory.arch, tile=tile, pipeline=pipeline, mask_impl=mask_impl + tile=tile, pipeline=pipeline, mask_impl=mask_impl ) rules = factory.get_rules() + product = get_product(receipt) + if not is_compatible( - problem_ctx, kernel_ctx, rules, short_circuit=True + problem_ctx, kernel_ctx, [*rules, product], short_circuit=True ): continue - k = create_kernel(problem_ctx, kernel_ctx) + k = create_kernel(factory.arch, problem_ctx, kernel_ctx) if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_squant == "f" - cond &= pipeline.F_skip == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_squant == "f" - cond &= mode == "batch" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 - if not cond: - continue - elif receipt == 888: - cond = dtype in ["fp8", "fp8bf16", "fp8fp32"] - cond &= pipeline.F_vlayout == "row" - cond &= hdim == 128 - if not cond: - continue - - # fp32 only, all variations - if receipt == 800: - cond = dtype == "fp32" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # fp32 only, minimal set of parameters - elif receipt == 801: - cond = dtype == "fp32" - cond &= hdim in [48, 128] - cond &= mode == "batch" - cond &= pipeline.F_bias == "no" - cond &= pipeline.F_lse == "f" - cond &= pipeline.F_dropout == "f" - cond &= pipeline.F_skip == "f" - cond &= 
pipeline.F_logits == "f" - cond &= pipeline.F_mask == "s_no" - if not cond: - continue - else: - # Don't build fp32 by default - if dtype == "fp32": - continue api_pool.register_traits(k.api_trait()) gen.append(k) From 4c5a68e59e7e560fe13b4e4563e8f1e6352af9e0 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 4 Nov 2025 09:57:32 -0600 Subject: [PATCH 05/48] Remove duplicated code in factories and redundant checks --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 94 +++++++++++-------- 1 file changed, 56 insertions(+), 38 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 28fb4b91ec..a5236a0cc3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -783,11 +783,22 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" ) + _DT_FP32 = frozenset({"fp32"}) + _DT_FP16_BF16 = frozenset({"fp16", "bf16"}) + _DT_FP8_FP8BF16 = frozenset({"fp8", "fp8bf16"}) + _DT_FP8FP32 = frozenset({"fp8fp32"}) + + @classmethod + def supported_dtypes(cls) -> frozenset[str]: + return frozenset().union( + cls._DT_FP32, cls._DT_FP16_BF16, cls._DT_FP8_FP8BF16, cls._DT_FP8FP32 + ) + # TODO: design a more practical way to do it # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype in ["fp32"]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP32: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -800,7 +811,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: return { ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -816,29 +827,31 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8_FP8BF16: return { ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, 
dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let "t" padding to appear later!! # TODO: how to design this more generic? pipelines = [] - if dtype in ["fp32"]: + if dtype in cls._DT_FP32: squant = "f" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -851,7 +864,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: squant = "f" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -876,18 +889,13 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip - elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8_FP8BF16 | cls._DT_FP8FP32: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product( ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip - elif dtype in ["fp8fp16", "bf8"]: - # TODO - None - else: - assert False return pipelines @@ -896,12 +904,14 @@ class KernelComponentFactoryGfx950( ): arch = ArchTrait("gfx950") - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = KernelComponentFactoryGfx9.get_pipelines( dtype, hdim, hdim_v, receipt, mask_impl ) - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: squant = "f" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -926,9 +936,19 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype in ["fp16", "bf16"]: + _DT_FP16_BF16 = frozenset({"fp16", "bf16"}) + _DT_FP8_FP8BF16 = frozenset({"fp8", "fp8bf16"}) + _DT_FP8FP32 = frozenset({"fp8fp32"}) + + @classmethod + def supported_dtypes(cls) -> frozenset[str]: + return frozenset().union( + cls._DT_FP16_BF16, cls._DT_FP8_FP8BF16, cls._DT_FP8FP32 + ) + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP16_BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 
16, 16, 16, 16, 16, 16, -1)], @@ -937,25 +957,27 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8_FP8BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { # bm0, bn0, bk0, bn1, bk1, (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = [] - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: squant = "f" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -967,23 +989,21 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8_FP8BF16 | cls._DT_FP8FP32: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product( ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip - else: - assert False return pipelines class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9): - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": + if dtype in cls._DT_FP16_BF16: if (128, 128) in result.keys(): result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result @@ -1129,10 +1149,8 @@ def get_fwd_blobs( factories = get_factories_for_targets(targets, get_factory) - for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): + for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()): d = factory.get_hdim_tile_size_dict(dtype) - if d is None: - continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product( d.items(), MODE_MAP.keys() From 
41cd25b959c20cc6b0ebb61277899fac32aaec83 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 4 Nov 2025 23:34:06 -0600 Subject: [PATCH 06/48] Remove FmhaFwdKernel<>::GetName() --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 3 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 47 ------------------- 2 files changed, 2 insertions(+), 48 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index a5236a0cc3..25dbb0f3c5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -110,7 +110,7 @@ {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; + std::cout << ", {F_kname}" << std::flush; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; @@ -555,6 +555,7 @@ class FmhaFwdKernel: def render(self) -> str: return type(self).KERNEL_HEADER + type(self).KERNEL_BODY_TEMPLATE.format( F_idx=self.F_idx, + F_kname=self.name, F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index fe7c8d48c8..d856d296b1 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -71,53 +71,6 @@ struct FmhaFwdKernel #endif static constexpr std::string_view kPipelineName = FmhaPipeline::name; - // clang-format off - template struct t2s; - template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; - template <> struct t2s { static constexpr const char * name = "bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8"; }; - template <> struct t2s { static constexpr const char * name = "bf8"; }; - template <> struct t2s { static constexpr const char * name = "fp8bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8fp32"; }; - // clang-format on - - CK_TILE_HOST static std::string GetName() - { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using g0br = typename bfs::Gemm0BlockWarps; - using g1br = typename bfs::Gemm1BlockWarps; - using g0wt = typename bfs::Gemm0WarpTile; - using g1wt = typename bfs::Gemm1WarpTile; - #define _SS_ std::string - #define _TS_ std::to_string - auto pn = [&] () { - std::string n; - if (kPadSeqLenQ) n += "s"; - if (kPadSeqLenK) n += "sk"; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; - return n.empty() ? n : std::string("p") + n; }(); - return - _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? 
"group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + - "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + - "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload"); - #undef _SS_ - #undef _TS_ - // clang-format on - } - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs From 3e0ad2c011f0d5401474c7d0b253d3cd0c022879 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 5 Nov 2025 01:50:25 -0600 Subject: [PATCH 07/48] Let FmhaFwdApiPool support pipelines with different mask_impl --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 18 ++++++-- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 43 +++++++++++-------- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 5c11d3b40b..7dd0b4ee20 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -31,16 +31,24 @@ } -def get_mask_map(mask: str): - if mask == "generic": +def get_mask_map(mask_impl: str): + if mask_impl == "generic": return _MASK_MAP - elif mask == "simplified": + elif mask_impl == "simplified": return _MASK_SIMPLIFIED_MAP else: assert False return None +def get_mask_impl(mask: str) -> str: + return "simplified" if mask.startswith("s_") else "generic" + + +def get_mask_cpp_type(mask: str) -> str: + return get_mask_map(get_mask_impl(mask))[mask] + + _MASK_CHECK_MAP = { "no": "t.mask_type == mask_enum::no_mask", "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", @@ -63,6 +71,10 @@ def get_mask_check_map(mask: str): return None +def get_mask_cpp_check_expr(mask: str) -> str: + return get_mask_check_map(get_mask_impl(mask))[mask] + + BIAS_MAP = { "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS", "bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 25dbb0f3c5..3b2919ff87 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -24,6 +24,8 @@ FWD_DTYPE_MAP, BIAS_MAP, get_mask_map, + get_mask_cpp_type, + get_mask_cpp_check_expr, ) from codegen.utils import check_duplicates_and_paddings, if_, indent, 
update_file @@ -120,8 +122,8 @@ #endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ -FMHA_FWD_V3_KERNEL_HEADER = "" -FMHA_FWD_V3_KERNEL_BODY_TEMPLATE = "" +FMHA_FWD_V3_KERNEL_HEADER = "// this is fmha fwd v3 kernel header\n" +FMHA_FWD_V3_KERNEL_BODY_TEMPLATE = "// this is fmha fwd v3 kernel body\n" FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" FMHA_FWD_API = """ @@ -255,7 +257,7 @@ def name(self) -> str: def scheck(self) -> str: if self.mode == "group": return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.pipeline_tag in ["qr_async", "qr_async_trload"]: + if self.pipeline_tag in ["qr_async", "qr_async_trload", "qr_async_trload_v3"]: if self.spad == "t": return "true" # always support else: @@ -288,7 +290,7 @@ def skcheck(self) -> str: return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" - elif self.pipeline_tag == "qr_async_trload": + elif self.pipeline_tag in ["qr_async_trload", "qr_async_trload_v3"]: if self.skpad == "t": return "true" else: @@ -304,7 +306,7 @@ def dcheck(self) -> str: return f"a.hdim_q % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == "t": return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -321,7 +323,7 @@ def dvcheck(self) -> str: return f"a.hdim_v % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == "t": return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) @@ -423,9 +425,8 @@ def pad_name() -> str: class FmhaFwdApiPool: - def __init__(self, mask_impl): + def __init__(self): self.pool = OrderedDict() - self.mask_impl = mask_impl def register_traits(self, trait: FmhaFwdApiTrait) -> None: hdim = trait.hdim, trait.bn1 @@ -457,8 +458,8 @@ def api(self) -> str: F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_mask=get_mask_cpp_type(trait.mask), + F_mask_check=get_mask_cpp_check_expr(trait.mask), F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], @@ -547,13 +548,12 @@ class FmhaFwdKernel: F_mode: str # value from MODE_MAP F_tile: FmhaFwdTileSize F_pipeline: FmhaFwdPipeline - mask_impl: str - KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER - KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE + _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER + _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE def render(self) -> str: - return type(self).KERNEL_HEADER + type(self).KERNEL_BODY_TEMPLATE.format( + return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( F_idx=self.F_idx, F_kname=self.name, F_arch=self.F_arch, @@ -590,7 +590,7 @@ def render(self) -> str: F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mask=get_mask_cpp_type(self.F_pipeline.F_mask), F_mode=MODE_MAP[self.F_mode], F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], @@ -642,8 +642,8 @@ def api_trait(self) -> FmhaFwdApiTrait: @dataclass class FmhaFwdV3Kernel(FmhaFwdKernel): - KERNEL_HEADER: ClassVar[str] = FMHA_FWD_V3_KERNEL_HEADER - KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_V3_KERNEL_BODY_TEMPLATE + _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_V3_KERNEL_HEADER + _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_V3_KERNEL_BODY_TEMPLATE @dataclass @@ -695,7 +695,6 @@ def create_kernel( F_hdim=problem_ctx.hdim, F_tile=kernel_ctx.tile, F_pipeline=kernel_ctx.pipeline, - mask_impl=kernel_ctx.mask_impl, ) @@ -931,6 +930,12 @@ def get_pipelines( ): pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip + """ + # qr_async_trload_v3 only supports (generic) causal mask + for mask in ["no", "causal"]: + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "t", "t", + F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_squant="f", F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip + """ return pipelines @@ -1146,7 +1151,7 @@ def get_fwd_blobs( targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() - api_pool = FmhaFwdApiPool(mask_impl) + api_pool = FmhaFwdApiPool() factories = get_factories_for_targets(targets, get_factory) From 4e6153b93b006f28f56c3654fc71a1478e34acb4 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 5 Nov 2025 13:48:47 -0600 Subject: [PATCH 08/48] Add tile setting for fmha fwd v3 pipeline --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 93 
++++++++++++++----- 1 file changed, 71 insertions(+), 22 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 3b2919ff87..b44eb79021 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -739,17 +739,20 @@ def check_feature( class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): - @staticmethod - def get_rules() -> list[CompatibilityRule]: + _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"}) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: rules = CompatibilityRuleFactory.get_rules() - def check_hdim_tile_for_non_fp32( + def check_hdim_tile( problem_ctx: ProblemContext, kernel_ctx: KernelContext ) -> bool: if problem_ctx.dtype != "fp32": # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support - if kernel_ctx.pipeline.tag != "qr_async_trload" and ( - ( + if ( + kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES + and ( (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) and kernel_ctx.tile.F_bn0 != 128 ) @@ -761,20 +764,54 @@ def check_hdim_tile_for_non_fp32( # non qr_async_trload only support km0=128 tile size when hdim is not 128 # non qr_async only support kn0=128 tile size when hdim is 128 return False - if kernel_ctx.pipeline.tag == "qr_async_trload" and ( - ( - (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) - and kernel_ctx.tile.F_bn0 == 128 - ) - or ( - (problem_ctx.hdim, problem_ctx.hdim_v) - not in [(64, 64), (128, 128)] - ) - ): - return False return True - rules.append(check_hdim_tile_for_non_fp32) + rules.append(check_hdim_tile) + return rules + + +class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9): + _AVAILABLE_PIPELINES = ( + CompatibilityRuleFactoryGfx9._AVAILABLE_PIPELINES + | frozenset({"qr_async_trload", "qr_async_trload_v3"}) + ) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: + rules = CompatibilityRuleFactoryGfx9.get_rules() + + def check_tile_pipeline( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if kernel_ctx.pipeline.tag == "qr_async_trload" and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 == 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) not in [(64, 64), (128, 128)] + ) + ): + return False + + # only qr_async_trload_v3 uses km0=256 & 8 warps + is_v3_dedicated_tile = ( + kernel_ctx.tile.F_bm0 == 256 + and (kernel_ctx.tile.F_rm0 * kernel_ctx.tile.F_rn0 * kernel_ctx.tile.F_rk0) == 8 + and (kernel_ctx.tile.F_rm1 * kernel_ctx.tile.F_rn1 * kernel_ctx.tile.F_rk1) == 8 + ) # fmt: skip + is_v3_pipeline = kernel_ctx.pipeline.tag == "qr_async_trload_v3" + return is_v3_dedicated_tile == is_v3_pipeline + + # qr_async_trload_v3 only supports batch mode for now + def check_mode_pipeline( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if kernel_ctx.pipeline.tag == "qr_async_trload_v3": + return problem_ctx.mode == "batch" + return True + + rules.extend([check_tile_pipeline, check_mode_pipeline]) return rules @@ -900,10 +937,20 @@ def get_pipelines( class KernelComponentFactoryGfx950( - KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9 + KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950 ): arch = ArchTrait("gfx950") + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) + if dtype in cls._DT_FP16_BF16: + # add tile for qr_async_trload_v3 + if
(128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + return result + @classmethod def get_pipelines( cls, dtype, hdim, hdim_v, receipt, mask_impl @@ -931,10 +978,12 @@ def get_pipelines( pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip """ - # qr_async_trload_v3 only supports (generic) causal mask - for mask in ["no", "causal"]: - pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "t", "t", - F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_squant="f", F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip + # qr_async_trload_v3 only supports hdim=hdim_v=128 for now + if (hdim, hdim_v) == (128, 128): + # qr_async_trload_v3 only supports (generic) causal mask + for mask in ["no", "causal"]: + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "t", "t", + F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_squant=squant, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip """ return pipelines From 6eaa880217f5dbd66add83e13d3d3272d9b4f0af Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 5 Nov 2025 14:13:35 -0600 Subject: [PATCH 09/48] Add fwd v3 instances to tile_example_fmha_fwd manually --- example/ck_tile/01_fmha/CMakeLists.txt | 17 ++++++++++------- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 5 ++--- example/ck_tile/01_fmha/fmha_fwd_v3.cpp | 8 ++++---- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 17 +++++++---------- .../instances/fmha_fwd_v3_d128_bf16_mask.cpp | 6 +----- .../instances/fmha_fwd_v3_d128_bf16_nmask.cpp | 6 +----- .../instances/fmha_fwd_v3_d128_fp16_mask.cpp | 6 +----- .../instances/fmha_fwd_v3_d128_fp16_nmask.cpp | 6 +----- 8 files changed, 27 insertions(+), 44 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index ce914b92af..143491141b 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -7,7 +7,7 @@ if(NOT INST_TARGETS) endif() # validate user-specified fmha_fwd API list -set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") +set(FMHA_FWD_KNOWN_APIS "fwd") set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") if(BUILD_TESTING) @@ -44,7 +44,7 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --targets ${FMHA_TARGETS_ARG} --api ${FMHA_FWD_APIS} - --optdim 32,64,128,256 + --optdim 128 # --filter fmha_fwd... 
) set(FMHA_BWD_CODE_GEN_COMMON_ARGS @@ -108,11 +108,16 @@ add_custom_command( set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") +file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" +) + message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}") add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS} ${FMHA_FWD_V3_INSTANCES}) set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_source_files_properties(${FMHA_FWD_V3_INSTANCES} PROPERTIES LANGUAGE HIP) set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}") @@ -129,7 +134,7 @@ set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... because they are auto-generated -list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) +list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template --save-temps -Wno-gnu-line-marker) list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) # Allow comparing floating points directly in order to check sentinel values @@ -210,9 +215,6 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" -) target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE fmha_fwd_v3.cpp ${FMHA_FWD_V3_INSTANCES} @@ -223,6 +225,7 @@ list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero -Wno-undefined-func-template --save-temps + -Wno-gnu-line-marker ) set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index b44eb79021..90e93008db 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -977,14 +977,13 @@ def get_pipelines( ): pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip - """ + # qr_async_trload_v3 only supports hdim=hdim_v=128 for now if (hdim, hdim_v) == (128, 128): # qr_async_trload_v3 only supports (generic) causal mask for mask in ["no", "causal"]: - pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "t", "t", + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_squant=squant, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip - """ return pipelines diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp index 041e04328d..13f26d07db 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp @@ -15,13 +15,13 @@ float fmha_fwd_v3(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::str { using 
kernel_traits = fmha_fwd_v3_kernel_traits; - return fmha_fwd_v3_kernel_dispatch(config, args); + return fmha_fwd_(config, args); } else { using kernel_traits = fmha_fwd_v3_kernel_traits; - return fmha_fwd_v3_kernel_dispatch(config, args); + return fmha_fwd_(config, args); } } else if(traits.data_type.compare("bf16") == 0) @@ -30,13 +30,13 @@ float fmha_fwd_v3(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::str { using kernel_traits = fmha_fwd_v3_kernel_traits; - return fmha_fwd_v3_kernel_dispatch(config, args); + return fmha_fwd_(config, args); } else { using kernel_traits = fmha_fwd_v3_kernel_traits; - return fmha_fwd_v3_kernel_dispatch(config, args); + return fmha_fwd_(config, args); } } diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index 451ebadd21..d072efc5fd 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -21,13 +21,13 @@ #include "fmha_fwd_v3.hpp" #include "mask.hpp" -#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ - template <> \ - float fmha_fwd_v3_kernel_dispatch( \ - const ck_tile::stream_config& config, fmha_fwd_args args) \ - { \ - return fmha_fwd_v3_kernel_launch::type>(config, \ - args); \ +#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ + template <> \ + float fmha_fwd_(const ck_tile::stream_config& config, \ + fmha_fwd_args args) \ + { \ + return fmha_fwd_v3_kernel_launch::type>( \ + config, args); \ } namespace ck_tile { @@ -199,7 +199,4 @@ float fmha_fwd_v3_kernel_launch(const ck_tile::stream_config& config, fmha_fwd_a return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); } -template -float fmha_fwd_v3_kernel_dispatch(const ck_tile::stream_config&, fmha_fwd_args); - } // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp index 0d199aa33f..2aa465c385 100644 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp @@ -4,10 +4,6 @@ #include "fmha_fwd_v3.hpp" #include "fmha_fwd_v3_impl.hpp" -namespace ck_tile { - -using kernel_traits = fmha_fwd_v3_kernel_traits; +using kernel_traits = ck_tile::fmha_fwd_v3_kernel_traits; INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp index a371d74a80..a9f9e4d7ef 100644 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp @@ -4,10 +4,6 @@ #include "fmha_fwd_v3.hpp" #include "fmha_fwd_v3_impl.hpp" -namespace ck_tile { - -using kernel_traits = fmha_fwd_v3_kernel_traits; +using kernel_traits = ck_tile::fmha_fwd_v3_kernel_traits; INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp index b0fbc88f78..425ba8f8fd 100644 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp @@ -4,10 +4,6 @@ #include "fmha_fwd_v3.hpp" #include "fmha_fwd_v3_impl.hpp" -namespace ck_tile { - -using kernel_traits = fmha_fwd_v3_kernel_traits; +using kernel_traits = ck_tile::fmha_fwd_v3_kernel_traits; 
INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp index bd1860fb25..bce6d1842b 100644 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp @@ -4,10 +4,6 @@ #include "fmha_fwd_v3.hpp" #include "fmha_fwd_v3_impl.hpp" -namespace ck_tile { - -using kernel_traits = fmha_fwd_v3_kernel_traits; +using kernel_traits = ck_tile::fmha_fwd_v3_kernel_traits; INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile From d6a99c2c7f499ca0c7bff588de1016bb82c98984 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 7 Nov 2025 02:48:18 -0600 Subject: [PATCH 10/48] Remove unused function import --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 90e93008db..f3496227d2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -16,7 +16,6 @@ from codegen.cpp_symbol_map import ( LAYOUT_MAP, BIAS_CHECK_MAP, - get_mask_check_map, BOOL_MAP, PIPELINE_MAP, PIPELINE_ENUM_MAP, From 76b2bc0490e02d8449e4c9cea72ef1a7fcc2b9c5 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 7 Nov 2025 03:55:36 -0600 Subject: [PATCH 11/48] Undo irrelevant changes --- example/ck_tile/01_fmha/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 143491141b..5b860ee359 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -7,7 +7,7 @@ if(NOT INST_TARGETS) endif() # validate user-specified fmha_fwd API list -set(FMHA_FWD_KNOWN_APIS "fwd") +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") if(BUILD_TESTING) @@ -44,7 +44,7 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --targets ${FMHA_TARGETS_ARG} --api ${FMHA_FWD_APIS} - --optdim 128 + --optdim 32,64,128,256 # --filter fmha_fwd... ) set(FMHA_BWD_CODE_GEN_COMMON_ARGS @@ -134,7 +134,7 @@ set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... 
because they are auto-generated -list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template --save-temps -Wno-gnu-line-marker) +list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) # Allow comparing floating points directly in order to check sentinel values From 260908ab79c348bc192c3d780e593c810758cb32 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 7 Nov 2025 04:00:43 -0600 Subject: [PATCH 12/48] Remove fwd v3 instances from tile_example_fmha_fwd --- example/ck_tile/01_fmha/CMakeLists.txt | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 5b860ee359..eea67dfc85 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -108,16 +108,11 @@ add_custom_command( set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") -file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" -) - message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}") add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS} ${FMHA_FWD_V3_INSTANCES}) +target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS}) set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_source_files_properties(${FMHA_FWD_V3_INSTANCES} PROPERTIES LANGUAGE HIP) set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}") @@ -215,6 +210,9 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" +) target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE fmha_fwd_v3.cpp ${FMHA_FWD_V3_INSTANCES} From 286a24bdcb465f8d994aa8074bdae465031c6728 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 7 Nov 2025 04:25:13 -0600 Subject: [PATCH 13/48] Finish fmha fwd v3 kernel instance codegen --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 75 +++++++++++++++++-- 1 file changed, 70 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f3496227d2..d95016b7cb 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -121,8 +121,73 @@ #endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ -FMHA_FWD_V3_KERNEL_HEADER = "// this is fmha fwd v3 kernel header\n" -FMHA_FWD_V3_KERNEL_BODY_TEMPLATE = "// this is fmha fwd v3 kernel body\n" +FMHA_FWD_V3_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd_v3_impl.hpp" +""" +FMHA_FWD_V3_KERNEL_BODY_TEMPLATE = """ +#include + +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) + +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdV3Traits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_lse}, + {F_occupancy}>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdV3PipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdV3Kernel; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + +template<> +float fmha_fwd_(const ck_tile::stream_config& config, fmha_fwd_args args) +{{ + return fmha_fwd_v3_kernel_launch(config, args); +}} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) +""" FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" FMHA_FWD_API = """ @@ -1212,9 +1277,9 @@ def get_fwd_blobs( if hdim not in optdim_list: continue for tile, next_tile in zip(tiles, tiles[1:]): - assert next_tile.F_bm0 >= tile.F_bm0, ( - "Tiles must be ordered by increasing bm0" - ) + assert ( + next_tile.F_bm0 >= tile.F_bm0 + ), "Tiles must be ordered by increasing bm0" for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) From 006692fb9f98b4177e9a4ae4072456f14b573224 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 10 Nov 2025 00:14:41 -0600 Subject: [PATCH 14/48] Fix formatting --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d95016b7cb..f2d694a4ea 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1277,9 +1277,9 @@ def get_fwd_blobs( if hdim not in optdim_list: continue for tile, next_tile in zip(tiles, tiles[1:]): - assert ( - next_tile.F_bm0 >= tile.F_bm0 - ), "Tiles must be ordered by increasing bm0" + assert next_tile.F_bm0 >= tile.F_bm0, ( + "Tiles must be ordered by increasing bm0" + ) for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, 
mask_impl) From 051a6be5f586ad2b6df5fdba6482750e89b7b3f8 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 10 Nov 2025 01:06:14 -0600 Subject: [PATCH 15/48] Remove unused F_idx attribute --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 125 +++++++++--------- 1 file changed, 61 insertions(+), 64 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f2d694a4ea..69d09bf021 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -45,18 +45,18 @@ #if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) -using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_dtype = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, +using fmha_trait = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, @@ -69,47 +69,47 @@ {F_occupancy}, {F_skip}>; -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; - -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, +using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask = {F_mask}; + +using fmha_pipeline_problem = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape, {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, + fmha_variant, + fmha_mask, {F_trload}, - fmha_trait_{F_idx}>; + fmha_trait>; -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; -using fmha_epilogue_{F_idx} = +using fmha_epilogue = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel; +using fmha_kernel = + ck_tile::FmhaFwdKernel; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, 
{F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; +using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ - using k_ = fmha_kernel_{F_idx}; + using k_ = fmha_kernel; if(s.log_level_ > 0) std::cout << ", {F_kname}" << std::flush; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); @@ -131,59 +131,59 @@ #if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) -using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_dtype = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdV3Traits<{F_spad}, +using fmha_trait = ck_tile::TileFmhaFwdV3Traits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_lse}, {F_occupancy}>; -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdV3PipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, +using fmha_mask = {F_mask}; + +using fmha_pipeline_problem = ck_tile::BlockFmhaFwdV3PipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape, {F_mode}, - fmha_mask_{F_idx}, - fmha_trait_{F_idx}>; + fmha_mask, + fmha_trait>; -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; -using fmha_epilogue_{F_idx} = +using fmha_epilogue = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdV3Kernel; +using fmha_kernel = + ck_tile::FmhaFwdV3Kernel; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; +using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, 
{F_dvpad}, {F_trload}, {F_skip}>; template<> -float fmha_fwd_(const ck_tile::stream_config& config, fmha_fwd_args args) +float fmha_fwd_(const ck_tile::stream_config& config, fmha_fwd_args args) {{ - return fmha_fwd_v3_kernel_launch(config, args); + return fmha_fwd_v3_kernel_launch(config, args); }} #endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) @@ -606,7 +606,6 @@ def name(self) -> str: @dataclass class FmhaFwdKernel: F_arch: ArchTrait - F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type F_mode: str # value from MODE_MAP @@ -618,7 +617,6 @@ class FmhaFwdKernel: def render(self) -> str: return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( - F_idx=self.F_idx, F_kname=self.name, F_arch=self.F_arch, F_hdim=self.F_hdim, @@ -752,7 +750,6 @@ def create_kernel( else FmhaFwdKernel ) return ctor( - F_idx=0, F_arch=arch, F_dtype=problem_ctx.dtype, F_mode=problem_ctx.mode, From 0b15146b506d4703e85093cac391e4f8e8e98ffd Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 10 Nov 2025 02:04:16 -0600 Subject: [PATCH 16/48] Add is_generic_attention_mask<> traits --- include/ck_tile/ops/fmha/block/block_masking.hpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 2c45945fac..4d6764e8bd 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -600,6 +600,19 @@ struct SimplifiedRatioAttentionMask mdiv y_ratio_mdiv; }; +template +struct is_generic_attention_mask : std::false_type +{ +}; + +template +struct is_generic_attention_mask> : std::true_type +{ +}; + +template +static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask::value; + // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask From a17699631e4dc1a05848988110757a84c4f542b5 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 10 Nov 2025 01:59:54 -0600 Subject: [PATCH 17/48] Add constraints to the fmha fwd v3 pipeline --- .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 5e2a4e898b..065a13a6c2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -261,12 +262,16 @@ struct BlockFmhaFwdV3Pipeline using OaccDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; + static_assert(is_generic_attention_mask_v); static_assert(std::is_same_v, "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); using BlockFmhaShape = ck_tile::remove_cvref_t; + using VLayout = remove_cvref_t; + static_assert(std::is_same_v); + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; @@ -277,14 +282,21 @@ struct 
BlockFmhaFwdV3Pipeline static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((kHasLogitsSoftCap == false && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && + kStoreLSE == false), + "enable unsupported features"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this From 10ecccc6753ac0b954af550cd447507233705ad9 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 10 Nov 2025 03:31:35 -0600 Subject: [PATCH 18/48] Unify traits & problem used for fmha fwd v3 --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 32 +++++--- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 78 ++++++++----------- .../pipeline/block_fmha_pipeline_problem.hpp | 43 ---------- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 16 ---- 4 files changed, 56 insertions(+), 113 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 69d09bf021..9093e040d4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -56,7 +56,7 @@ ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; -using fmha_trait = ck_tile::TileFmhaTraits<{F_spad}, +using fmha_traits = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, @@ -90,14 +90,14 @@ fmha_variant, fmha_mask, {F_trload}, - fmha_trait>; + fmha_traits>; using fmha_pipeline = {F_pipeline}< fmha_pipeline_problem>; using fmha_epilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, {F_spad}, {F_dvpad}>>; using fmha_kernel = @@ -142,36 +142,48 @@ ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; -using fmha_trait = ck_tile::TileFmhaFwdV3Traits<{F_spad}, +using fmha_traits = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, + {F_bias}, + false, {F_lse}, - {F_occupancy}>; + {F_dropout}, + {F_squant}, + {F_occupancy}, + {F_skip}>; + +using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask = {F_mask}; -using fmha_pipeline_problem = ck_tile::BlockFmhaFwdV3PipelineProblem< +using 
fmha_pipeline_problem = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::QDataType, typename FmhaFwdTypeConfig::KDataType, typename FmhaFwdTypeConfig::VDataType, typename FmhaFwdTypeConfig::SaccDataType, typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, typename FmhaFwdTypeConfig::LSEDataType, typename FmhaFwdTypeConfig::PDataType, typename FmhaFwdTypeConfig::OaccDataType, typename FmhaFwdTypeConfig::ODataType, fmha_shape, {F_mode}, + fmha_variant, fmha_mask, - fmha_trait>; + {F_trload}, + fmha_traits>; using fmha_pipeline = {F_pipeline}< fmha_pipeline_problem>; using fmha_epilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, {F_spad}, {F_dvpad}>>; using fmha_kernel = diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index d072efc5fd..7be368d1e8 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -32,27 +32,6 @@ namespace ck_tile { -template -struct fmha_fwd_v3_problem_traits; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::half_t; - using acc_dtype = float; - using o_dtype = ck_tile::half_t; - using lse_dtype = float; -}; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::bf16_t; - using acc_dtype = float; - using o_dtype = ck_tile::bf16_t; - using lse_dtype = float; -}; - template using fmha_fwd_v3_kernel_traits = fmha_fwd_traits_<128, @@ -82,7 +61,7 @@ using fmha_fwd_v3_kernel_traits = template struct get_fmha_fwd_v3_kernel { - using data_type = KernelTraits::DataType; + using fmha_dtype = KernelTraits::DataType; static constexpr bool kIsGroupMode = KernelTraits::kIsGroupMode; // M0 N0 K0 N1 K1 @@ -102,36 +81,47 @@ struct get_fmha_fwd_v3_kernel fmha_warp_gemm_shape, KernelTraits::kIsVLayoutRowMajor>; - using fmha_traits = TileFmhaFwdV3Traits; + using fmha_traits = ck_tile::TileFmhaTraits; + + using fmha_variant = ck_tile::ComposedAttention; using fmha_mask = KernelTraits::FmhaMask; using fmha_pipeline_problem = - BlockFmhaFwdV3PipelineProblem::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::lse_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, - fmha_shape, - kIsGroupMode, - fmha_mask, - fmha_traits>; + BlockFmhaPipelineProblem::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape, + kIsGroupMode, + fmha_variant, + fmha_mask, + true, + fmha_traits>; using fmha_pipeline = BlockFmhaFwdV3Pipeline; using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, + 
Default2DEpilogueProblem::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, true, // kPadM true, // kPadM true // UseRawStore diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index cc0851efb3..8978e1b420 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -264,47 +264,4 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -template -struct BlockFmhaFwdV3PipelineProblem -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; - - static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; - static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); - - static constexpr bool kIsGroupMode = kIsGroupMode_; - - // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kStoreLSE = Traits::kStoreLSE; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; -}; - } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 59267fa3b1..183f0535aa 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -165,20 +165,4 @@ struct TileFmhaBwdConvertQGradTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -template -struct TileFmhaFwdV3Traits -{ - static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadSeqLenK = kPadSeqLenK_; - static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; - static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kStoreLSE = kStoreLSE_; - static constexpr index_t kBlockPerCu = kBlockPerCu_; -}; - } // namespace ck_tile From 16d4573d60bfaab5e60218e1c1e08c8347622915 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 10 Nov 2025 04:16:54 -0600 Subject: [PATCH 19/48] Unify kernel launch code for fmha fwd v2 & v3 --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 57 ++++++++++++++ example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 78 +++---------------- 3 files changed, 76 insertions(+), 67 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 9093e040d4..afa967078f 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -195,7 +195,13 @@ template<> float fmha_fwd_(const ck_tile::stream_config& config, fmha_fwd_args args) {{ - return fmha_fwd_v3_kernel_launch(config, args); + using k_ = fmha_kernel; + if(config.log_level_ > 0) + std::cout << ", {F_kname}" << std::flush; + auto [kargs, grids] = 
fmha_fwd_v3_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(config, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} #endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index a952800806..137e5173e3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -683,6 +683,63 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } } +template +auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) +{ + /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly + /// maximizes the kernel's performance. + int remap_opt = 2; + if(args.mask_type != static_cast(mask_enum::no_mask) && + ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) + { + if(65536 <= args.seqlen_q) + { + remap_opt = 0; + } + else + { + remap_opt = 1; + } + } + + auto kargs = FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + 0, // batch_stride_lse + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + static_cast(args.cu_seqlen_q_ptr), + static_cast(args.cu_seqlen_k_ptr)); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + template auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) { diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index 7be368d1e8..ca54c9eca8 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -21,13 +21,18 @@ #include "fmha_fwd_v3.hpp" #include "mask.hpp" -#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ - template <> \ - float fmha_fwd_(const ck_tile::stream_config& config, \ - fmha_fwd_args args) \ - { \ - return fmha_fwd_v3_kernel_launch::type>( \ - config, args); \ +#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ + template <> \ + float fmha_fwd_(const ck_tile::stream_config& config, \ + fmha_fwd_args args) \ + { \ + using kernel = typename ck_tile::get_fmha_fwd_v3_kernel::type; \ + auto [kargs, grids] = fmha_fwd_v3_create_kargs_and_grids(args); \ + const dim3 blocks = kernel::BlockSize(); \ + constexpr ck_tile::index_t kBlockPerCu = kernel::kBlockPerCu; \ + return ck_tile::launch_kernel(config, \ + ck_tile::make_kernel( \ + kernel{}, grids, blocks, 0, kargs)); \ } namespace ck_tile { @@ -130,63 +135,4 @@ struct get_fmha_fwd_v3_kernel using type = FmhaFwdV3Kernel; }; -template -float fmha_fwd_v3_kernel_launch(const ck_tile::stream_config& config, fmha_fwd_args args) -{ - /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly - /// maximizes the kernel's performance. 
- int remap_opt = 2; - if(args.mask_type != static_cast(mask_enum::no_mask) && - ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) - { - if(65536 <= args.seqlen_q) - { - remap_opt = 0; - } - else - { - remap_opt = 1; - } - } - - auto kargs = Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - nullptr, // lse_ptr - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - 0, // nhead_stride_lse - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - 0, // batch_stride_lse - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - remap_opt, - static_cast(args.cu_seqlen_q_ptr), - static_cast(args.cu_seqlen_k_ptr)); - - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); - constexpr dim3 blocks = Kernel::BlockSize(); - constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; - - return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); -} - } // namespace ck_tile From 1810d6f07db6c9bf4727608fa7ceccdd07a3b7b6 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 10 Nov 2025 21:46:18 -0600 Subject: [PATCH 20/48] Unify kernel template selection logic --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index afa967078f..28771686e2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -100,8 +100,11 @@ typename FmhaFwdTypeConfig::ODataType, {F_spad}, {F_dvpad}>>; -using fmha_kernel = - ck_tile::FmhaFwdKernel; +using fmha_kernel = std::conditional_t< + {F_pipeline_enum} == ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3, + ck_tile::FmhaFwdV3Kernel, + ck_tile::FmhaFwdKernel +>; using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; @@ -186,8 +189,11 @@ typename FmhaFwdTypeConfig::ODataType, {F_spad}, {F_dvpad}>>; -using fmha_kernel = - ck_tile::FmhaFwdV3Kernel; +using fmha_kernel = std::conditional_t< + {F_pipeline_enum} == ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3, + ck_tile::FmhaFwdV3Kernel, + ck_tile::FmhaFwdKernel +>; using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; From 05ffeace60120d5743dabca37b9e209b337aa83b Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 00:13:55 -0600 Subject: [PATCH 21/48] Use same kernel codegen template for both v2 & v3 --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 128 +++--------------- 1 file changed, 20 insertions(+), 108 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 28771686e2..6ea451f76a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py 
@@ -100,11 +100,7 @@ typename FmhaFwdTypeConfig::ODataType, {F_spad}, {F_dvpad}>>; -using fmha_kernel = std::conditional_t< - {F_pipeline_enum} == ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3, - ck_tile::FmhaFwdV3Kernel, - ck_tile::FmhaFwdKernel ->; +using fmha_kernel = {F_kernel}; using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; @@ -115,7 +111,7 @@ using k_ = fmha_kernel; if(s.log_level_ > 0) std::cout << ", {F_kname}" << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + auto [kargs, grids] = {F_kargs_creator}(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); @@ -124,95 +120,6 @@ #endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) """ -FMHA_FWD_V3_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "fmha_fwd_v3_impl.hpp" -""" -FMHA_FWD_V3_KERNEL_BODY_TEMPLATE = """ -#include - -#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) - -using fmha_dtype = {F_dtype}; - -using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -using fmha_traits = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - false, - {F_lse}, - {F_dropout}, - {F_squant}, - {F_occupancy}, - {F_skip}>; - -using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; - -using fmha_mask = {F_mask}; - -using fmha_pipeline_problem = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape, - {F_mode}, - fmha_variant, - fmha_mask, - {F_trload}, - fmha_traits>; - -using fmha_pipeline = {F_pipeline}< - fmha_pipeline_problem>; - -using fmha_epilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel = std::conditional_t< - {F_pipeline_enum} == ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3, - ck_tile::FmhaFwdV3Kernel, - ck_tile::FmhaFwdKernel ->; - -using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; - -template<> -float fmha_fwd_(const ck_tile::stream_config& config, fmha_fwd_args args) -{{ - using k_ = fmha_kernel; - if(config.log_level_ > 0) 
- std::cout << ", {F_kname}" << std::flush; - auto [kargs, grids] = fmha_fwd_v3_create_kargs_and_grids(args); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(config, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) -""" - FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" FMHA_FWD_API = """ #include @@ -639,6 +546,20 @@ class FmhaFwdKernel: _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE + @classmethod + def _get_cpp_kernel_class_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "ck_tile::FmhaFwdV3Kernel" + else: + return "ck_tile::FmhaFwdKernel" + + @classmethod + def _get_cpp_kargs_creator_func_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "fmha_fwd_v3_create_kargs_and_grids" + else: + return "fmha_fwd_create_kargs_and_grids" + def render(self) -> str: return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( F_kname=self.name, @@ -678,8 +599,10 @@ def render(self) -> str: F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_cpp_type(self.F_pipeline.F_mask), F_mode=MODE_MAP[self.F_mode], - F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), + F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), ) @property @@ -726,12 +649,6 @@ def api_trait(self) -> FmhaFwdApiTrait: ) -@dataclass -class FmhaFwdV3Kernel(FmhaFwdKernel): - _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_V3_KERNEL_HEADER - _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_V3_KERNEL_BODY_TEMPLATE - - @dataclass class ProblemContext: dtype: str @@ -768,12 +685,7 @@ def is_compatible( def create_kernel( arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext ) -> FmhaFwdKernel: - ctor = ( - FmhaFwdV3Kernel - if kernel_ctx.pipeline.tag == "qr_async_trload_v3" - else FmhaFwdKernel - ) - return ctor( + return FmhaFwdKernel( F_arch=arch, F_dtype=problem_ctx.dtype, F_mode=problem_ctx.mode, From 7b9b7ee2dc8eb12e63f7905ec5a4c1db0b1d97d7 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 00:20:19 -0600 Subject: [PATCH 22/48] Rename api() property as render() method --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 6ea451f76a..cb51d16089 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -433,8 +433,7 @@ def register_traits(self, trait: FmhaFwdApiTrait) -> None: check_duplicates_and_paddings(ts, trait) ts.append(copy.copy(trait)) - @property - def api(self) -> str: + def render(self) -> str: per_arch = str() for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): per_dtypes = str() @@ -1247,7 +1246,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.render()) def write_blobs( From 923a97a44715301c0084183dc852a2e9e45a5d2d Mon Sep 17 
00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 01:01:40 -0600 Subject: [PATCH 23/48] Allow specifying filter for fmha fwd api pool --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 52 +++++++++++-------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index cb51d16089..ab1c00d8de 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -121,40 +121,46 @@ """ FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" -FMHA_FWD_API = """ +FMHA_FWD_API_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py #include #include -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ +#include "fmha_fwd.hpp" + +namespace { +bool get_num_cus(unsigned& num_cus) { int device; auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device"); return false; - }} + } - hipDeviceProp_t props{{}}; + hipDeviceProp_t props{}; status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device properties"); return false; - }} + } num_cus = props.multiProcessorCount; return true; -}} +} -unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) { const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s) {{ +} +} // namespace +""" +FMHA_FWD_API_BODY_TEMPLATE = """ +float fmha_fwd([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ float r = -1; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate @@ -433,7 +439,11 @@ def register_traits(self, trait: FmhaFwdApiTrait) -> None: check_duplicates_and_paddings(ts, trait) ts.append(copy.copy(trait)) - def render(self) -> str: + def render(self, filter: Optional[Callable[[FmhaFwdApiTrait], bool]] = None) -> str: + accept_all = lambda _trait: True + if filter is None: + filter = accept_all + per_arch = str() for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): per_dtypes = str() @@ -444,7 +454,9 @@ def render(self) -> str: ): max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) inners = str() - for i_trait, trait in enumerate(pool_by_hdim): + for i_trait, trait in enumerate( + [trait for trait in pool_by_hdim if filter(trait)] + ): inners += FMHA_FWD_API_INNER_DISPATCH.format( F_if=if_(i_trait), F_arch=arch, @@ -494,10 +506,7 @@ def render(self) -> str: F_arch=arch, F_dtype_case=indent(per_dtypes), ) - if not per_arch: - # empty string we add some ignore to suppress warning in api - per_arch = "(void)t; (void)s; (void)a;" - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch)) + return FMHA_FWD_API_BODY_TEMPLATE.format(F_dispatch=indent(per_arch)) @dataclass @@ -1246,7 +1255,8 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - 
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.render()) + content = FMHA_FWD_API_HEADER + api_pool.render() + update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) def write_blobs( From be4d12336d14f8e67381f088db9ca4210c6cad86 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 02:50:51 -0600 Subject: [PATCH 24/48] Allow specifying function name when rendering api pool items --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index ab1c00d8de..23bfa23229 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -159,8 +159,8 @@ } } // namespace """ -FMHA_FWD_API_BODY_TEMPLATE = """ -float fmha_fwd([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ +FMHA_FWD_API_FUNC_TEMPLATE = """ +float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ float r = -1; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate @@ -180,6 +180,8 @@ return r; }} """ +FMHA_FWD_API_FOOTER = """ +""" FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ {F_dtype_case} @@ -439,7 +441,9 @@ def register_traits(self, trait: FmhaFwdApiTrait) -> None: check_duplicates_and_paddings(ts, trait) ts.append(copy.copy(trait)) - def render(self, filter: Optional[Callable[[FmhaFwdApiTrait], bool]] = None) -> str: + def render( + self, func_name, filter: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> str: accept_all = lambda _trait: True if filter is None: filter = accept_all @@ -506,7 +510,9 @@ def render(self, filter: Optional[Callable[[FmhaFwdApiTrait], bool]] = None) -> F_arch=arch, F_dtype_case=indent(per_dtypes), ) - return FMHA_FWD_API_BODY_TEMPLATE.format(F_dispatch=indent(per_arch)) + return FMHA_FWD_API_FUNC_TEMPLATE.format( + F_func_name=func_name, F_dispatch=indent(per_arch) + ) @dataclass @@ -1255,7 +1261,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - content = FMHA_FWD_API_HEADER + api_pool.render() + content = FMHA_FWD_API_HEADER + api_pool.render("fmha_fwd") + FMHA_FWD_API_FOOTER update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) From b66d3f5058259bd91df494a5d27905831a91e68a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 04:09:15 -0600 Subject: [PATCH 25/48] Separate fmha fwd v3 kernel dispatching logic from v2 --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 23bfa23229..688c271840 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -181,6 +181,9 @@ }} """ FMHA_FWD_API_FOOTER = """ +float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) { + return autogen_fmha_fwd_v2(traits, args, config); +} """ FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ @@ -1260,8 +1263,22 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: update_file(autogen_dir / kernel.filename, 
kernel.render()) -def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - content = FMHA_FWD_API_HEADER + api_pool.render("fmha_fwd") + FMHA_FWD_API_FOOTER +def write_fwd_api( + api_pool: FmhaFwdApiPool, + autogen_dir: Path, +) -> None: + def accept_only_v3(trait: FmhaFwdApiTrait) -> bool: + return trait.pipeline_tag == "qr_async_trload_v3" + + def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: + return not accept_only_v3(trait) + + content = ( + FMHA_FWD_API_HEADER + + api_pool.render("autogen_fmha_fwd_v2", filter=accept_only_v2) + + api_pool.render("autogen_fmha_fwd_v3", filter=accept_only_v3) + + FMHA_FWD_API_FOOTER + ) update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) From 48487b53ff866c2233c76a4d95780797a2df3950 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 04:14:00 -0600 Subject: [PATCH 26/48] Remove lambda assignment --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 688c271840..b5ff7ed327 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -447,8 +447,11 @@ def register_traits(self, trait: FmhaFwdApiTrait) -> None: def render( self, func_name, filter: Optional[Callable[[FmhaFwdApiTrait], bool]] = None ) -> str: - accept_all = lambda _trait: True if filter is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + filter = accept_all per_arch = str() From fd8312ca3455c7f00fcd99aa5efcd5bbb6a1a985 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 11:40:49 -0600 Subject: [PATCH 27/48] Add simple v2/v3 dispatch logic --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 52 +++++++++++++++---- 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index b5ff7ed327..6fee447f07 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -160,6 +160,7 @@ } // namespace """ FMHA_FWD_API_FUNC_TEMPLATE = """ +namespace {{ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ float r = -1; @@ -179,11 +180,17 @@ {F_dispatch} return r; }} +}} // namespace """ -FMHA_FWD_API_FOOTER = """ -float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) { - return autogen_fmha_fwd_v2(traits, args, config); -} +FMHA_FWD_API_FOOTER_TEMPLATE = """ +float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ + const bool can_dispatch_v3 = false; + if ({F_is_v3_enabled} and can_dispatch_v3) {{ + return fmha_fwd_v3(traits, args, config); + }} else {{ + return fmha_fwd_v2(traits, args, config); + }} +}} """ FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ @@ -444,16 +451,35 @@ def register_traits(self, trait: FmhaFwdApiTrait) -> None: check_duplicates_and_paddings(ts, trait) ts.append(copy.copy(trait)) + def get_num_traits( + self, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> int: + if filter_fn is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + return sum( + sum(1 for trait in pool_by_hdim if filter_fn(trait)) + for pool_by_arch in self.pool.values() + for pool_by_dtype in 
pool_by_arch.values() + for pool_by_hdim in pool_by_dtype.values() + ) + def render( - self, func_name, filter: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + self, func_name, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None ) -> str: - if filter is None: + if filter_fn is None: def accept_all(trait: FmhaFwdApiTrait) -> bool: return True - filter = accept_all + filter_fn = accept_all + # TODO: Stop generating empty if-clauses. To fix this, skip iterating over + # dictionaries that have no traits, and avoid assigning i_* to them. per_arch = str() for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): per_dtypes = str() @@ -465,7 +491,7 @@ def accept_all(trait: FmhaFwdApiTrait) -> bool: max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) inners = str() for i_trait, trait in enumerate( - [trait for trait in pool_by_hdim if filter(trait)] + [trait for trait in pool_by_hdim if filter_fn(trait)] ): inners += FMHA_FWD_API_INNER_DISPATCH.format( F_if=if_(i_trait), @@ -1278,9 +1304,13 @@ def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: content = ( FMHA_FWD_API_HEADER - + api_pool.render("autogen_fmha_fwd_v2", filter=accept_only_v2) - + api_pool.render("autogen_fmha_fwd_v3", filter=accept_only_v3) - + FMHA_FWD_API_FOOTER + + api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2) + + api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3) + + FMHA_FWD_API_FOOTER_TEMPLATE.format( + F_is_v3_enabled=BOOL_MAP[ + 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + ] + ) ) update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) From 0a3cfe1f141976a8d795b8eea6f5738ceb5708a6 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 12:01:33 -0600 Subject: [PATCH 28/48] Stop generating empty if-clauses Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them. --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 6fee447f07..0055df3288 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -478,15 +478,25 @@ def accept_all(trait: FmhaFwdApiTrait) -> bool: filter_fn = accept_all - # TODO: Stop generating empty if-clauses. To fix this, skip iterating over - # dictionaries that have no traits, and avoid assigning i_* to them. 
+ def has_traits(node) -> bool: + """Recursively traverse nested OrderedDicts and lists to determine if any FmhaFwdApiTrait satisfies filter_fn().""" + if isinstance(node, list): + return any(filter_fn(elem) for elem in node) + elif isinstance(node, OrderedDict): + return any(has_traits(val) for val in node.values()) + return False + per_arch = str() - for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): + for i_arch, (arch, pool_by_arch) in enumerate( + item for item in self.pool.items() if has_traits(item[1]) + ): per_dtypes = str() - for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): + for i_dtype, (dtype, pool_by_dtype) in enumerate( + item for item in pool_by_arch.items() if has_traits(item[1]) + ): per_hdim_case = str() for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate( - pool_by_dtype.items() + item for item in pool_by_dtype.items() if has_traits(item[1]) ): max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) inners = str() From 9da8cbbc00281e73e232021fcbd35653b72ee3c1 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 12:11:07 -0600 Subject: [PATCH 29/48] Use "".join() to concatenate fmha fwd api string content --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 0055df3288..5b15d6545f 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1312,15 +1312,17 @@ def accept_only_v3(trait: FmhaFwdApiTrait) -> bool: def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: return not accept_only_v3(trait) - content = ( - FMHA_FWD_API_HEADER - + api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2) - + api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3) - + FMHA_FWD_API_FOOTER_TEMPLATE.format( - F_is_v3_enabled=BOOL_MAP[ - 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) - ] - ) + content = "".join( + [ + FMHA_FWD_API_HEADER, + api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2), + api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3), + FMHA_FWD_API_FOOTER_TEMPLATE.format( + F_is_v3_enabled=BOOL_MAP[ + 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + ] + ), + ] ) update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) From 679387703b4263eebd930adcc7ea163a72ddee81 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 21:10:28 -0600 Subject: [PATCH 30/48] Add more feature checks for fmha fwd v3 pipeline --- .../ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 065a13a6c2..f4d46a27ba 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -293,9 +293,10 @@ struct BlockFmhaFwdV3Pipeline static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - - static_assert((kHasLogitsSoftCap == false && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && - kStoreLSE == false), + static constexpr bool kDoFp8StaticQuant = Problem::kDoFp8StaticQuant; + static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; + static_assert((!kHasLogitsSoftCap && BiasEnum 
== BlockAttentionBiasEnum::NO_BIAS && + !kStoreLSE && !kHasDropout && !kDoFp8StaticQuant && !kSkipMinSeqlenQ), "enable unsupported features"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) From 772c30f22e766aa974efea05ca79dd9c3c9a6ce5 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 21:13:36 -0600 Subject: [PATCH 31/48] Check features before dispatch to fmha_fwd_v3() --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 5b15d6545f..cac2ad6575 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -184,7 +184,14 @@ """ FMHA_FWD_API_FOOTER_TEMPLATE = """ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ - const bool can_dispatch_v3 = false; + const bool is_swa = (0 < args.window_size_left) or (0 < args.window_size_right); + const bool can_dispatch_v3 = + (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and + (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and + (not traits.has_dropout) and (not traits.do_fp8_static_quant) and + (not traits.skip_min_seqlen_q) and (not is_swa) and (args.nhead_q % args.nhead_k == 0) and + (args.hdim_q == 128) and (args.hdim_v == 128); if ({F_is_v3_enabled} and can_dispatch_v3) {{ return fmha_fwd_v3(traits, args, config); }} else {{ From eebe510849674c50b792bfb290684fe8cc7dda61 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Nov 2025 21:41:39 -0600 Subject: [PATCH 32/48] Add more feature checks for fmha_fwd_v3() --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index cac2ad6575..28eb9c44b1 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -184,14 +184,17 @@ """ FMHA_FWD_API_FOOTER_TEMPLATE = """ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ + const std::string device_name = ck_tile::get_device_name(); + const bool is_swa = (0 < args.window_size_left) or (0 < args.window_size_right); const bool can_dispatch_v3 = + (device_name.compare(0, 6, "gfx950") == 0) and (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and - traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and + (not traits.is_group_mode) and traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and (not traits.has_dropout) and (not traits.do_fp8_static_quant) and (not traits.skip_min_seqlen_q) and (not is_swa) and (args.nhead_q % args.nhead_k == 0) and - (args.hdim_q == 128) and (args.hdim_v == 128); + (args.hdim_q == 128) and (args.hdim_v == 128) and (4096 <= args.max_seqlen_q); if ({F_is_v3_enabled} and can_dispatch_v3) {{ return fmha_fwd_v3(traits, args, config); }} else {{ From 17308757054a72b8f2b99303b3d81b9e54183c5c Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 12 Nov 2025 02:03:01 -0600 Subject: [PATCH 33/48] Add missing filter call --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 4 +++- 1 file changed, 3 
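[editor note] For readability, the FMHA_FWD_API_FOOTER_TEMPLATE at this point in the series (after the two feature-check patches above) expands to roughly the C++ below. This is an illustrative expansion only: it assumes at least one v3 trait is registered so that {F_is_v3_enabled} is substituted with "true", and fmha_fwd_v2()/fmha_fwd_v3() are the anonymous-namespace dispatchers emitted by render().

    float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {
        const std::string device_name = ck_tile::get_device_name();

        const bool is_swa = (0 < args.window_size_left) or (0 < args.window_size_right);
        // every feature check must pass before routing into the v3 kernels
        const bool can_dispatch_v3 =
            (device_name.compare(0, 6, "gfx950") == 0) and
            (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
            (not traits.is_group_mode) and traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and
            (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and
            (not traits.has_dropout) and (not traits.do_fp8_static_quant) and
            (not traits.skip_min_seqlen_q) and (not is_swa) and (args.nhead_q % args.nhead_k == 0) and
            (args.hdim_q == 128) and (args.hdim_v == 128) and (4096 <= args.max_seqlen_q);
        if (true /* {F_is_v3_enabled} */ and can_dispatch_v3) {
            return fmha_fwd_v3(traits, args, config);
        } else {
            return fmha_fwd_v2(traits, args, config);
        }
    }
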
insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 28eb9c44b1..f7b3dc223d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -508,7 +508,9 @@ def has_traits(node) -> bool: for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate( item for item in pool_by_dtype.items() if has_traits(item[1]) ): - max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) + max_bm0 = max( + (t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0 + ) inners = str() for i_trait, trait in enumerate( [trait for trait in pool_by_hdim if filter_fn(trait)] From a62afeed109fec6e3148b5f58a73fc1b191852b5 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 12 Nov 2025 04:46:18 -0600 Subject: [PATCH 34/48] Use Tuple to reserve the dtype orders --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f7b3dc223d..4c514a1341 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -885,16 +885,14 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" ) - _DT_FP32 = frozenset({"fp32"}) - _DT_FP16_BF16 = frozenset({"fp16", "bf16"}) - _DT_FP8_FP8BF16 = frozenset({"fp8", "fp8bf16"}) - _DT_FP8FP32 = frozenset({"fp8fp32"}) + _DT_FP32 = ("fp32",) + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8_FP8BF16 = ("fp8", "fp8bf16") + _DT_FP8FP32 = ("fp8fp32",) @classmethod - def supported_dtypes(cls) -> frozenset[str]: - return frozenset().union( - cls._DT_FP32, cls._DT_FP16_BF16, cls._DT_FP8_FP8BF16, cls._DT_FP8FP32 - ) + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_FP32 + cls._DT_FP16_BF16 + cls._DT_FP8_FP8BF16 + cls._DT_FP8FP32 # TODO: design a more practical way to do it # this is current supported tile size per hdim @@ -991,7 +989,7 @@ def get_pipelines( pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip - elif dtype in cls._DT_FP8_FP8BF16 | cls._DT_FP8FP32: + elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product( ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() @@ -1055,15 +1053,13 @@ def get_pipelines( class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") - _DT_FP16_BF16 = frozenset({"fp16", "bf16"}) - _DT_FP8_FP8BF16 = frozenset({"fp8", "fp8bf16"}) - _DT_FP8FP32 = frozenset({"fp8fp32"}) + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8_FP8BF16 = ("fp8", "fp8bf16") + _DT_FP8FP32 = ("fp8fp32",) @classmethod - def supported_dtypes(cls) -> frozenset[str]: - return frozenset().union( - cls._DT_FP16_BF16, cls._DT_FP8_FP8BF16, cls._DT_FP8FP32 - ) + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_FP16_BF16 + cls._DT_FP8_FP8BF16 + cls._DT_FP8FP32 @classmethod def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: From 9c8922057cfe6b658d6faa54e2664affe4741ff6 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: 
Wed, 12 Nov 2025 05:05:03 -0600 Subject: [PATCH 35/48] Fix wrong pipeline matching logic --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 4c514a1341..12bf00f4bc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -815,9 +815,8 @@ def check_hdim_tile( ) -> bool: if problem_ctx.dtype != "fp32": # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support - if ( - kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES - and ( + if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and ( + ( (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) and kernel_ctx.tile.F_bn0 != 128 ) From 23c00220ec34beddb1b42cd5b16220963a405026 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 12 Nov 2025 22:26:15 -0600 Subject: [PATCH 36/48] Add fmha fwd v3 group mode instances --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 18 +-- example/ck_tile/01_fmha/fmha_fwd.hpp | 105 +++++++++----- .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 132 +++++++++++------- 3 files changed, 155 insertions(+), 100 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 12bf00f4bc..7bcd0ec7c1 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -190,11 +190,11 @@ const bool can_dispatch_v3 = (device_name.compare(0, 6, "gfx950") == 0) and (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and - (not traits.is_group_mode) and traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and + traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and (not traits.has_dropout) and (not traits.do_fp8_static_quant) and (not traits.skip_min_seqlen_q) and (not is_swa) and (args.nhead_q % args.nhead_k == 0) and - (args.hdim_q == 128) and (args.hdim_v == 128) and (4096 <= args.max_seqlen_q); + (args.hdim_q == 128) and (args.hdim_v == 128); if ({F_is_v3_enabled} and can_dispatch_v3) {{ return fmha_fwd_v3(traits, args, config); }} else {{ @@ -867,15 +867,7 @@ def check_tile_pipeline( is_v3_pipeline = kernel_ctx.pipeline.tag == "qr_async_trload_v3" return is_v3_dedicated_tile == is_v3_pipeline - # qr_async_trload_v3 only support batch mode for now - def check_mode_pipeline( - problem_ctx: ProblemContext, kernel_ctx: KernelContext - ) -> bool: - if kernel_ctx.pipeline.tag == "qr_async_trload_v3": - return problem_ctx.mode == "batch" - return True - - rules.extend([check_tile_pipeline, check_mode_pipeline]) + rules.extend([check_tile_pipeline]) return rules @@ -1326,7 +1318,9 @@ def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3), FMHA_FWD_API_FOOTER_TEMPLATE.format( F_is_v3_enabled=BOOL_MAP[ - 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + # NOTE: enable v3 pipelines when ready + # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + False ] ), ] diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 137e5173e3..7c11da45ef 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -702,40 +702,77 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) } } - auto kargs = FmhaKernel::MakeKargs(args.q_ptr, - 
args.k_ptr, - args.v_ptr, - nullptr, // lse_ptr - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - 0, // nhead_stride_lse - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - 0, // batch_stride_lse - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - remap_opt, - static_cast(args.cu_seqlen_q_ptr), - static_cast(args.cu_seqlen_k_ptr)); - - dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); + auto kargs = [&] { + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_q_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + else + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + 0, // batch_stride_lse + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); return ck_tile::make_tuple(kargs, grids); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index e9115b14df..29f65993fe 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -103,8 +103,8 @@ struct FmhaFwdV3Kernel // Optional cumulative sequence length pointers for batch mode // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; struct FmhaFwdGroupModeKargs @@ -114,12 +114,13 @@ struct FmhaFwdV3Kernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; const int32_t* seqlen_k_ptr; // Optional cumulative padded sequence starts (including PAD tokens) // Used solely to compute memory offsets when sequences are physically padded. 
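[editor note] A small host-side sketch of the cu_seqlen_* convention these [batch+1] pointers follow: exclusive prefix sums of the per-batch effective (unpadded) lengths, from which the kernel recovers each batch's length as cu[i + 1] - cu[i]. The vector contents below are made-up values; only the layout and the recovery rule come from the kernel code in this patch.

    #include <cstdio>
    #include <vector>

    int main()
    {
        // hypothetical per-batch effective lengths (excluding PAD tokens)
        const std::vector<int> eff_q{3000, 4096, 1024};

        // cu_seqlen_q has batch + 1 entries: exclusive prefix sums of eff_q
        std::vector<int> cu_seqlen_q(eff_q.size() + 1, 0);
        for(std::size_t i = 0; i < eff_q.size(); ++i)
            cu_seqlen_q[i + 1] = cu_seqlen_q[i] + eff_q[i];

        // the kernel recovers each batch's effective seqlen_q as cu[i + 1] - cu[i]
        for(std::size_t i = 0; i < eff_q.size(); ++i)
            std::printf("batch %zu: seqlen_q = %d\n", i, cu_seqlen_q[i + 1] - cu_seqlen_q[i]);
        return 0;
    }
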
- const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1] - const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; using Kargs = std::conditional_t; @@ -156,8 +157,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -199,8 +200,8 @@ struct FmhaFwdV3Kernel kargs.batch_stride_lse = batch_stride_lse; } - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } @@ -213,6 +214,7 @@ struct FmhaFwdV3Kernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -232,8 +234,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -258,6 +260,7 @@ struct FmhaFwdV3Kernel {}, // placeholder for lse reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; if constexpr(kHasMask) @@ -273,30 +276,29 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; } - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { - // TODO: this may need tuning - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + batch_size, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1)); } else { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1), + batch_size); } } @@ -344,13 +346,20 @@ struct FmhaFwdV3Kernel // FmhaPipeline::kN1); // assume that num_tile_n1 is always 1 - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { const index_t i_nhead = blockIdx.x; - const index_t i_block = blockIdx.y; - const index_t i_batch = blockIdx.z; + const index_t i_batch = blockIdx.y; + const index_t i_block = blockIdx.z; - return 
ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } else { @@ -358,7 +367,14 @@ struct FmhaFwdV3Kernel const index_t i_block = blockIdx.y; const index_t i_batch = blockIdx.z; - return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } } @@ -390,32 +406,36 @@ struct FmhaFwdV3Kernel if constexpr(kIsGroupMode) { - // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + // Use seqstart_q_ptr and seqstart_k_ptr for physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; if constexpr(kStoreLSE) { // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; + batch_offset_lse = query_start; } - batch_offset_o = query_start_padded * kargs.stride_o; - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + batch_offset_o = query_start * kargs.stride_o; + // real logical lengths (exclude PAD) + // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) @@ -427,10 +447,14 @@ struct FmhaFwdV3Kernel { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } else { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; } } else @@ -450,10 +474,10 @@ struct FmhaFwdV3Kernel kargs.seqlen_q = kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; } - if(kargs.cu_seqlen_kv_ptr != nullptr) + if(kargs.cu_seqlen_k_ptr != nullptr) { kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + kargs.cu_seqlen_k_ptr[i_batch + 1] - 
kargs.cu_seqlen_k_ptr[i_batch]; } } From 6526b5956da41820a6d3440188b0d04d665fda37 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 13 Nov 2025 00:52:03 -0600 Subject: [PATCH 37/48] Add functor_transform<> --- .../core/algorithm/coordinate_transform.hpp | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 7511413bba..e49b40f01f 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1552,6 +1552,81 @@ CK_TILE_HOST_DEVICE static void print(const indexing& printf("}"); } +template +struct functor_transform : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + Functor functor_; + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr functor_transform() = default; + + CK_TILE_HOST_DEVICE constexpr functor_transform(const Functor& functor, + const LowLength& low_length) + : functor_{functor}, up_lengths_{make_tuple(low_length)} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = functor_(idx_up[number<0>{}]); + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& up_idx) const + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + calculate_lower_index(idx_low, up_idx); + idx_diff_low = idx_low - idx_low_old; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + // Note: When using functor_transform, ensure that the transformed coordinates + // are always valid for vectorized load/store operations. 
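[editor note] A minimal usage sketch for the make_functor_transform() helper introduced in this patch: wrap an arbitrary index remapping in a lambda, here reversing a dimension of length n. The reverse-index lambda and the make_reverse_transform() wrapper are hypothetical illustrations, not part of the change; only make_functor_transform(functor, low_length) itself comes from the diff above.

    #include "ck_tile/core/algorithm/coordinate_transform.hpp"

    // Hypothetical: view a dimension of length n through the remapping i -> n - 1 - i.
    // calculate_lower_index() simply evaluates the lambda on each upper index.
    CK_TILE_HOST_DEVICE auto make_reverse_transform(ck_tile::index_t n)
    {
        return ck_tile::make_functor_transform(
            [n](ck_tile::index_t i) { return n - 1 - i; }, n);
    }
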
+ template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + return make_tuple(low_vector_lengths, low_vector_strides); + } +}; + //******************************************************************************************************* template @@ -1671,6 +1746,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le return offset{low_length, offset_length}; } +template +CK_TILE_HOST_DEVICE constexpr auto make_functor_transform(const Functor& functor, + const LowLength& low_length) +{ + return functor_transform{functor, low_length}; +} + } // namespace ck_tile #include "ck_tile/core/algorithm/indexing_adaptor.hpp" From 291cea6a47a6873d03e825779480af01c82a3e0c Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 13 Nov 2025 00:53:29 -0600 Subject: [PATCH 38/48] Add type constraints to make_tile_window() --- include/ck_tile/core/tensor/tile_window.hpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 1123ce7604..9e243197fe 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1205,7 +1205,9 @@ struct tile_window_with_static_lengths } }; -template +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const TensorView_& tensor_view, const WindowLengths_& window_lengths, @@ -1252,7 +1254,10 @@ make_tile_window(const tile_window_with_static_lengths +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution, From f4d92f1b04d3b2675cef1903852ceeebfa9c63e2 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 13 Nov 2025 01:49:59 -0600 Subject: [PATCH 39/48] Remove fmha fwd v3 example --- example/ck_tile/01_fmha/CMakeLists.txt | 35 - .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 624 ------------------ example/ck_tile/01_fmha/fmha_fwd_v3.cpp | 46 -- example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 18 - example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 138 ---- .../instances/fmha_fwd_v3_d128_bf16_mask.cpp | 9 - .../instances/fmha_fwd_v3_d128_bf16_nmask.cpp | 9 - .../instances/fmha_fwd_v3_d128_fp16_mask.cpp | 9 - .../instances/fmha_fwd_v3_d128_fp16_nmask.cpp | 9 - 9 files changed, 897 deletions(-) delete mode 100644 example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.cpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.hpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index eea67dfc85..7c3e58ed55 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -204,41 +204,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -# add fmha_fwd_v3 example 
-set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") -message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") - -add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) -target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" -) -target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE - fmha_fwd_v3.cpp - ${FMHA_FWD_V3_INSTANCES} -) - -set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) -list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -fgpu-flush-denormals-to-zero - -Wno-undefined-func-template - --save-temps - -Wno-gnu-line-marker -) -set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) - -check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) -if(HAS_DISABLE_PACKED_FP32) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -mllvm --amdgpu-disable-packed-fp32=1 - ) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS - -DCK_TILE_DISABLE_PACKED_FP32=1 - ) -endif() - -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) -target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp deleted file mode 100644 index c713560045..0000000000 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ /dev/null @@ -1,624 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "fmha_fwd.hpp" -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -auto parse_cmd_args(int argc, char* argv[]) -> std::pair -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("prec", "fp16", "data type. fp16/bf16") - .insert("b", "2", "batch size") - .insert("h", "8", "num of head, for q") - .insert("h_k", - "-1", - "num of head, for k/v, -1 means equal to h\n" - "if not equal to h, then this is GQA/MQA case") - .insert("s", "3328", "seqlen_q") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q & k") - .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") - .insert("iperm", - "0", - "permute input\n" - "if true, will be b*h*s*d, else b*s*h*d") - .insert("operm", "0", "permute output") - .insert("causal", "0", "0: no mask, 1: causal mask") - .insert("v", "1", "0:no verify, 1:verify") - .insert("seed", - "11939", - "random seed used for initializing input tensors. 0 for " - "non-deterministic seed") - .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "30", "number of iterations to benchmark the kernel") - // Optional effective seqlen override (exclude PAD) for batch mode - .insert("q_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override.") - .insert("kv_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. 
If empty, no override."); - - bool result = arg_parser.parse(argc, argv); - return std::make_pair(result, arg_parser); -} - -enum class TensorLayout -{ - bhsd, - bshd, -}; - -std::ostream& operator<<(std::ostream& stream, TensorLayout layout) -{ - switch(layout) - { - case TensorLayout::bhsd: return stream << "bhsd"; - case TensorLayout::bshd: return stream << "bshd"; - default: return stream << "unknown"; - } -} - -struct Problem -{ - explicit Problem(const ck_tile::ArgParser& args) - { - prec = args.get_str("prec") == "fp16" ? "fp16" : "bf16"; - batch = args.get_int("b"); - seqlen_q = args.get_int("s"); - seqlen_k = args.get_int("s_k"); - if(seqlen_k < 0) - { - seqlen_k = seqlen_q; - } - nhead_q = args.get_int("h"); - nhead_kv = args.get_int("h_k"); - if(nhead_kv < 0) - { - nhead_kv = nhead_q; - } - hdim = args.get_int("d"); - softmax_scale = args.get_float("scale_s"); - if(softmax_scale == .0f) - softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); - - const auto is_causal = args.get_bool("causal"); - if(is_causal) - { - mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); - } - else - { - mask = mask_info::decode("0", seqlen_q, seqlen_k); - } - - input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; - output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; - q_eff_lens = args.get_int_vec("q_eff_lens"); - kv_eff_lens = args.get_int_vec("kv_eff_lens"); - } - - std::vector get_query_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - std::vector get_key_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_value_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_output_shape() const - { - if(output_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - std::string prec; - ck_tile::index_t batch; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_kv; - ck_tile::index_t hdim; - float softmax_scale; - mask_info mask; - TensorLayout input_layout; - TensorLayout output_layout; - std::vector q_eff_lens; - std::vector kv_eff_lens; -}; - -struct RunConfig -{ - explicit RunConfig(const ck_tile::ArgParser& args) - { - seed = args.get_uint32("seed"); - if(*seed == 0) - { - seed.reset(); - } - - kernel_warmup = args.get_int("warmup"); - kernel_repeat = args.get_int("repeat"); - verify = args.get_bool("v"); - } - - std::optional seed; - int kernel_warmup; - int kernel_repeat; - bool verify; -}; - -template -auto generate_qkv(const Problem& problem, - [[maybe_unused]] std::optional seed = std::nullopt) - -> std::tuple, - ck_tile::HostTensor, - ck_tile::HostTensor> -{ - ck_tile::HostTensor q(problem.get_query_shape()); - ck_tile::HostTensor k(problem.get_key_shape()); - ck_tile::HostTensor v(problem.get_value_shape()); - - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v); - - return std::make_tuple(q, k, v); -} - -namespace host { -template -CK_TILE_HOST void fmha_fwd(const 
ck_tile::HostTensor& q_bshd, - const ck_tile::HostTensor& k_bshd, - const ck_tile::HostTensor& v_bshd, - const mask_info& mask, - ck_tile::HostTensor& o_bshd, - const QElementOp& q_element_op = {}, - const KElementOp& k_element_op = {}, - const VElementOp& v_element_op = {}, - const SAccElementOp& s_acc_element_op = {}) -{ - const int batch_size = q_bshd.mDesc.get_lengths()[0]; - const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; - const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; - const int nhead_q = q_bshd.mDesc.get_lengths()[2]; - const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; - const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; - const int hdim_v = v_bshd.mDesc.get_lengths()[3]; - - const int nr = nhead_q / nhead_kv; - - ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); - ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); - ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); - ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); - - ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); - ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - - // do computation for each batch - for(int b = 0; b < batch_size; ++b) - { - // copy per-batch data from input tensors - // clang-format off - q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); - k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); - v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); - // clang-format on - ck_tile::reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, seqlen_q, seqlen_kv)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - } - - ck_tile::reference_batched_softmax( - s_host_ref, p_host_ref, ck_tile::identity{}); - - ck_tile::reference_batched_gemm( - p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - - // copy resulting per-batch data to the output tensor - o_host_ref.ForEach( - [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); - } -} -} // namespace host - -template -bool run_impl(const Problem& problem, const RunConfig& run_config) -{ - auto [q, k, v] = generate_qkv(problem, run_config.seed); - - ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes()); - /// FIXME: use correct size for output tensor. 
just use q size for now since hidm_qk = hdim_v - ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes()); - - q_buf.ToDevice(q.data()); - k_buf.ToDevice(k.data()); - v_buf.ToDevice(v.data()); - // Ensure output buffer is zero-initialized so padded regions compare cleanly - o_buf.SetZero(); - - fmha_fwd_traits traits{}; - traits.hdim_q = problem.hdim; - traits.hdim_v = problem.hdim; - traits.data_type = problem.prec; - traits.is_v_rowmajor = true; - traits.is_group_mode = false; - traits.has_logits_soft_cap = false; - traits.mask_type = mask_enum::mask_bottom_right; - traits.bias_type = bias_enum::no_bias; - traits.has_lse = false; - traits.do_fp8_static_quant = false; - - fmha_fwd_args args{}; - args.batch = problem.batch; - args.seqlen_q = problem.seqlen_q; - args.seqlen_k = problem.seqlen_k; - args.nhead_q = problem.nhead_q; - args.nhead_k = problem.nhead_kv; - args.hdim_q = problem.hdim; - args.hdim_v = problem.hdim; - args.scale_s = problem.softmax_scale; - - args.window_size_left = problem.mask.left; - args.window_size_right = problem.mask.right; - args.mask_type = static_cast(problem.mask.type); - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.q_ptr = q_buf.GetDeviceBuffer(); - args.stride_q = - problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_q = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim; - args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.k_ptr = k_buf.GetDeviceBuffer(); - args.stride_k = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_k = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.v_ptr = v_buf.GetDeviceBuffer(); - args.stride_v = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_v = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.o_ptr = o_buf.GetDeviceBuffer(); - args.stride_o = - problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_o = problem.output_layout == TensorLayout::bshd - ? 
problem.hdim - : problem.seqlen_q * problem.hdim; - args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // Optional cumulative seqlen overrides (exclude PAD) - const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; - const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; - - auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { - std::vector eff; - if(!opt_vec.empty() && opt_vec[0] != -1) - { - eff.assign(opt_vec.begin(), opt_vec.end()); - if(eff.size() < static_cast(problem.batch)) - { - eff.resize(problem.batch, eff.back()); - } - } - else - { - eff.assign(problem.batch, fallback); - } - return eff; - }; - - const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); - const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); - - // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cuq_cum, cukv_cum; - auto calculate_cumulative = [&](const std::vector& per_batch_vec, - std::vector& cum_vec) { - cum_vec.resize(per_batch_vec.size() + 1); - cum_vec[0] = 0; - for(std::size_t i = 0; i < per_batch_vec.size(); ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - }; - - if(has_varlen_q) - { - calculate_cumulative(eff_q_vec, cuq_cum); - } - if(has_varlen_k) - { - calculate_cumulative(eff_kv_vec, cukv_cum); - } - - ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); - ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); - cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); - cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); - args.cu_seqlen_q_ptr = - !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) - : nullptr; - args.cu_seqlen_k_ptr = - !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) - : nullptr; - - ck_tile::stream_config stream_config{nullptr, - true, - /*log_level=*/0, - run_config.kernel_warmup, - run_config.kernel_repeat}; - - float time = ck_tile::fmha_fwd_v3(traits, args, stream_config); - if(time < 0.f) - { - std::cerr << "faild to run fmha_fwd_v3()" << std::endl; - return false; - } - - std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) - { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else - { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. 
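    /// (The unmasked figure above counts the two GEMMs, Q*K^T and P*V, at
    /// 2 * batch * nhead_q * seqlen_q * seqlen_k * hdim flops each; a causal mask keeps
    /// roughly half of the score matrix, hence the halving below.)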
- return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - }(); - float tflops = static_cast(flop) / 1.e9 / time; - - std::cout << "[" << problem.prec << "|"; - if(problem.input_layout == problem.output_layout) - { - std::cout << problem.input_layout; - } - else - { - std::cout << problem.input_layout << "-" << problem.output_layout; - } - std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim - << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed - << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops - << " TFlops" << std::endl; - - if(!run_config.verify) - { - return true; - } - - // transpose tensor descriptors from bhsd to bshd if necessary - if(problem.input_layout != TensorLayout::bshd) - { - q = q.transpose({0, 2, 1, 3}); - k = k.transpose({0, 2, 1, 3}); - v = v.transpose({0, 2, 1, 3}); - } - - ck_tile::HostTensor o_ref(problem.get_output_shape()); - if(problem.output_layout != TensorLayout::bshd) - { - o_ref = o_ref.transpose({0, 2, 1, 3}); - } - - // If variable lengths are provided, compute per-batch references - // with the effective lengths; else compute a single full reference. - if(has_varlen_q || has_varlen_k) - { - // Variable-length aware verification: zero-fill padded region and only compute valid part. - o_ref.SetZero(); - - for(int b = 0; b < problem.batch; ++b) - { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; - - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) - { - for(int h = 0; h < problem.nhead_q; ++h) - { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } - } - } - } - } - else - { - // No varlen override: compute the full reference once - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - } - - ck_tile::HostTensor o(problem.get_output_shape()); - o_buf.FromDevice(o.data()); - - const auto [rtol, atol] = [&] { - if constexpr(std::is_same_v) - return std::make_tuple(1e-3, 1e-3); - else - return std::make_tuple(1e-2, 1e-2); - }(); - return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); -} - -int main(int argc, char* argv[]) -{ - auto 
[parse_result, args] = parse_cmd_args(argc, argv); - if(!parse_result) - { - std::cerr << "failed to parse command line arguments" << std::endl; - } - - Problem problem(args); - RunConfig run_config(args); - - const auto run = [&] { - if(problem.prec == "fp16") - { - return run_impl(problem, run_config); - } - else - { - return run_impl(problem, run_config); - } - }; - - return !run(); -} diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp deleted file mode 100644 index 13f26d07db..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp +++ /dev/null @@ -1,46 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" -#include "mask.hpp" - -namespace ck_tile { - -float fmha_fwd_v3(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) -{ - if(traits.data_type.compare("fp16") == 0) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = fmha_fwd_v3_kernel_traits; - - return fmha_fwd_(config, args); - } - else - { - using kernel_traits = fmha_fwd_v3_kernel_traits; - - return fmha_fwd_(config, args); - } - } - else if(traits.data_type.compare("bf16") == 0) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = fmha_fwd_v3_kernel_traits; - - return fmha_fwd_(config, args); - } - else - { - using kernel_traits = fmha_fwd_v3_kernel_traits; - - return fmha_fwd_(config, args); - } - } - - return -1.; -} - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp deleted file mode 100644 index c3a0d0d8f3..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ /dev/null @@ -1,18 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/host/stream_config.hpp" - -#include "fmha_fwd.hpp" - -namespace ck_tile { - -float fmha_fwd_v3(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp deleted file mode 100644 index ca54c9eca8..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ /dev/null @@ -1,138 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include - -#include "ck_tile/core/numeric/bfloat16.hpp" -#include "ck_tile/core/numeric/half.hpp" -#include "ck_tile/core/container/sequence.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" -#include "ck_tile/ops/fmha/block/block_masking.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" - -#include "fmha_fwd.hpp" -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ - template <> \ - float fmha_fwd_(const ck_tile::stream_config& config, \ - fmha_fwd_args args) \ - { \ - using kernel = typename ck_tile::get_fmha_fwd_v3_kernel::type; \ - auto [kargs, grids] = fmha_fwd_v3_create_kargs_and_grids(args); \ - const dim3 blocks = kernel::BlockSize(); \ - constexpr ck_tile::index_t kBlockPerCu = kernel::kBlockPerCu; \ - return ck_tile::launch_kernel(config, \ - ck_tile::make_kernel( \ - kernel{}, grids, blocks, 0, kargs)); \ - } - -namespace ck_tile { - -template -using fmha_fwd_v3_kernel_traits = - fmha_fwd_traits_<128, - DataType, - kIsGroupMode, - 256, - 32, - 128, - 128, - 32, - 128, - true, - ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3, - false, - ck_tile::GenericAttentionMask, - ck_tile::BlockAttentionBiasEnum::NO_BIAS, - false, - false, - false, - true, - true, - false, - false, - true, - false>; - -template -struct get_fmha_fwd_v3_kernel -{ - using fmha_dtype = KernelTraits::DataType; - static constexpr bool kIsGroupMode = KernelTraits::kIsGroupMode; - - // M0 N0 K0 N1 K1 - using fmha_block_tile = sequence; - using fmha_warp_gemm_shape = sequence<32, 32, 16>; - using fmha_block_warps = sequence<8, 1, 1>; - - using fmha_shape = TileFmhaShape; - - using fmha_traits = ck_tile::TileFmhaTraits; - - using fmha_variant = ck_tile::ComposedAttention; - - using fmha_mask = KernelTraits::FmhaMask; - - using fmha_pipeline_problem = - BlockFmhaPipelineProblem::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape, - kIsGroupMode, - fmha_variant, - fmha_mask, - true, - fmha_traits>; - - using fmha_pipeline = BlockFmhaFwdV3Pipeline; - - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - true, // kPadM - true, // kPadM - true // UseRawStore - >>; - - using type = FmhaFwdV3Kernel; -}; - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp deleted file mode 100644 index 2aa465c385..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -using kernel_traits = ck_tile::fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp deleted file mode 100644 index a9f9e4d7ef..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -using kernel_traits = ck_tile::fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp deleted file mode 100644 index 425ba8f8fd..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -using kernel_traits = ck_tile::fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp deleted file mode 100644 index bce6d1842b..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -using kernel_traits = ck_tile::fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) From 2df5019fc08ad7987305a638184a3419f5c67b23 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 13 Nov 2025 02:36:13 -0600 Subject: [PATCH 40/48] Fix wrong product(aiter mha_fwd()) config --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 7bcd0ec7c1..d4bdda691d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1173,7 +1173,7 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] - cond &= problem_ctx.mode == "group" + cond &= problem_ctx.mode == "batch" cond &= kernel_ctx.pipeline.F_vlayout == "row" if problem_ctx.dtype == "fp8bf16": cond &= problem_ctx.hdim == 128 From 1df098d7cc9efc293aafc1ed9ccded3e597b9c98 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sun, 16 Nov 2025 03:07:37 -0600 Subject: [PATCH 41/48] Fix wrong fmha fwd v2/v3 selection logic --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d4bdda691d..a825ac04a3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -186,15 +186,15 @@ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& 
config) {{ const std::string device_name = ck_tile::get_device_name(); - const bool is_swa = (0 < args.window_size_left) or (0 < args.window_size_right); + const bool is_swa = (traits.mask_type != mask_enum::no_mask) and + ((0 < args.window_size_left) or (0 < args.window_size_right)); const bool can_dispatch_v3 = (device_name.compare(0, 6, "gfx950") == 0) and (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and - (not traits.has_dropout) and (not traits.do_fp8_static_quant) and - (not traits.skip_min_seqlen_q) and (not is_swa) and (args.nhead_q % args.nhead_k == 0) and - (args.hdim_q == 128) and (args.hdim_v == 128); + (not traits.has_dropout) and (not traits.do_fp8_static_quant) and (not is_swa) and + (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128); if ({F_is_v3_enabled} and can_dispatch_v3) {{ return fmha_fwd_v3(traits, args, config); }} else {{ From 8e0d9dd8e0cd61c96cf8164225fd506e65910419 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 17 Nov 2025 00:50:07 -0600 Subject: [PATCH 42/48] Fix formatting --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index a825ac04a3..1bbc7da571 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1095,7 +1095,7 @@ def get_pipelines( ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - elif dtype in cls._DT_FP8_FP8BF16 | cls._DT_FP8FP32: + elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product( ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() From cf1f13554838e9c94fb445d40d461c5d9c329666 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 3 Dec 2025 01:23:58 -0600 Subject: [PATCH 43/48] Add comment to warning v3 kernel users --- include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 1 + .../ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index be015d45dd..36280c657c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -12,6 +12,7 @@ namespace ck_tile { +/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and instruction scheduling optimizations. 
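/// At present fmha_fwd() dispatches to this kernel only on gfx950 for fp16/bf16 with
/// hdim_q == hdim_v == 128, row-major V, nhead_q divisible by nhead_k, and no
/// bias/LSE/dropout or sliding-window mask (see can_dispatch_v3 in codegen/ops/fmha_fwd.py).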
template struct FmhaFwdV3Kernel { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index ded385fbfb..28f515d106 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -247,6 +248,7 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) } } // namespace detail +/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and instruction scheduling optimizations. template struct BlockFmhaFwdV3Pipeline { @@ -293,10 +295,10 @@ struct BlockFmhaFwdV3Pipeline static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static constexpr bool kDoFp8StaticQuant = Problem::kDoFp8StaticQuant; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && - !kStoreLSE && !kHasDropout && !kDoFp8StaticQuant && !kSkipMinSeqlenQ), + !kStoreLSE && !kHasDropout && (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && !kSkipMinSeqlenQ), "enable unsupported features"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) From 608a25380668b242e1f0d71e0d35edb80fdb8cb3 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 3 Dec 2025 01:29:42 -0600 Subject: [PATCH 44/48] Fix wrong codegen logics --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index fa94e536a8..8ade97aa4a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -195,8 +195,9 @@ (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and - (not traits.has_dropout) and (not traits.do_fp8_static_quant) and (not is_swa) and - (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128); + (not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and + (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and + (args.hdim_v == 128); if ({F_is_v3_enabled} and can_dispatch_v3) {{ return fmha_fwd_v3(traits, args, config); }} else {{ @@ -881,12 +882,19 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): _DT_FP32 = ("fp32",) _DT_FP16_BF16 = ("fp16", "bf16") - _DT_FP8_FP8BF16 = ("fp8", "fp8bf16") + _DT_FP8 = ("fp8",) + _DT_FP8BF16 = ("fp8bf16",) _DT_FP8FP32 = ("fp8fp32",) @classmethod def supported_dtypes(cls) -> Tuple[str]: - return cls._DT_FP32 + cls._DT_FP16_BF16 + cls._DT_FP8_FP8BF16 + cls._DT_FP8FP32 + return ( + cls._DT_FP32 + + cls._DT_FP16_BF16 + + cls._DT_FP8 + + cls._DT_FP8BF16 + + cls._DT_FP8FP32 + ) # TODO: design a more 
practical way to do it # this is current supported tile size per hdim @@ -921,7 +929,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip - elif dtype in cls._DT_FP8_FP8BF16: + elif dtype in cls._DT_FP8 or dtype in cls._DT_FP8BF16: return { ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], @@ -983,7 +991,7 @@ def get_pipelines( pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip - elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: + elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], @@ -1186,7 +1194,7 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond &= problem_ctx.mode == "batch" cond &= kernel_ctx.pipeline.F_vlayout == "row" if problem_ctx.dtype == "fp8bf16": - cond &= problem_ctx.hdim == 128 + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 256 return cond return Product(name="Aiter(mha_fwd) integration", rule=fit) @@ -1198,7 +1206,7 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond &= problem_ctx.mode == "group" cond &= kernel_ctx.pipeline.F_vlayout == "row" if problem_ctx.dtype == "fp8bf16": - cond &= problem_ctx.hdim == 128 + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 256 return cond return Product(name="Aiter(mha_varlen_fwd) integration", rule=fit) @@ -1209,16 +1217,16 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] cond &= kernel_ctx.pipeline.F_vlayout == "row" if problem_ctx.dtype == "fp8bf16": - cond &= problem_ctx.hdim == 128 + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 256 return cond return Product(name="aiter::mha_fwd C++ api integration", rule=fit) elif receipt == 888: def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: - cond = problem_ctx.dtype in ["fp8", "fp8bf16", "fp8fp32"] + cond = problem_ctx.dtype in ["fp8bf16", "fp8fp32"] cond &= kernel_ctx.pipeline.F_vlayout == "row" - cond &= problem_ctx.hdim == 128 + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 256 return cond return Product(name="receipt = 888", rule=fit) From 02ed663352989a1b4e5e2a0e6971c3b3ac70d3b0 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 3 Dec 2025 01:44:45 -0600 Subject: [PATCH 45/48] Remove unnecessary param --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 8ade97aa4a..9161d915c1 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -743,14 +743,7 @@ def is_compatible( problem_ctx: ProblemContext, kernel_ctx: KernelContext, rules: 
Iterable[CompatibilityRule], - *, - short_circuit: bool = True, ) -> bool: - if short_circuit: - for rule in rules: - if not rule(problem_ctx, kernel_ctx): - return False - return True return all(rule(problem_ctx, kernel_ctx) for rule in rules) @@ -1299,9 +1292,7 @@ def get_fwd_blobs( rules = factory.get_rules() product = get_product(receipt) - if not is_compatible( - problem_ctx, kernel_ctx, [*rules, product], short_circuit=True - ): + if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]): continue k = create_kernel(factory.arch, problem_ctx, kernel_ctx) From 0e29033ccf877546d039c2e413fb556b813b19bb Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 3 Dec 2025 03:42:43 -0600 Subject: [PATCH 46/48] Fix format --- include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 3 ++- .../ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | 7 +++++-- include/ck_tile/remod.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 36280c657c..f981c54bd8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -12,7 +12,8 @@ namespace ck_tile { -/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and instruction scheduling optimizations. +/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and +/// instruction scheduling optimizations. template struct FmhaFwdV3Kernel { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 28f515d106..68ec349694 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -248,7 +248,8 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) } } // namespace detail -/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and instruction scheduling optimizations. +/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and +/// instruction scheduling optimizations. template struct BlockFmhaFwdV3Pipeline { @@ -298,7 +299,9 @@ struct BlockFmhaFwdV3Pipeline static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && - !kStoreLSE && !kHasDropout && (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && !kSkipMinSeqlenQ), + !kStoreLSE && !kHasDropout && + (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && + !kSkipMinSeqlenQ), "enable unsupported features"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index affa6d987b..aeec7bd471 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -90,7 +90,7 @@ def gen_header(hpath, include_list): # formatting format_procs = [] for x in all_files: - dos2unix = f"python -m dos2unix {str(x)} {str(x)}" + dos2unix = f"python3 -m dos2unix {str(x)} {str(x)}" clang_format = f"clang-format -style=file -i {str(x)}" # One process to avoid race conditions. 
cmd = f"{dos2unix} && {clang_format}" From 162481919af1ecc459eb4c77ecce3c41dc55c3aa Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 4 Dec 2025 11:08:31 -0600 Subject: [PATCH 47/48] Add logits soft-capping support for fmha fwd v3 pipeline (WIP) --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 13 ++- example/ck_tile/01_fmha/fmha_fwd.hpp | 2 + .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 83 +++++++++++++++--- .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 85 ++++++++++++++++--- 4 files changed, 154 insertions(+), 29 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c00bdcea3b..0494193939 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -201,11 +201,10 @@ const bool can_dispatch_v3 = (device_name.compare(0, 6, "gfx950") == 0) and (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and - traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and - (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and - (not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and - (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and - (args.hdim_v == 128); + traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and + (not traits.has_lse) and (not traits.has_dropout) and + (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and + (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128); if ({F_is_v3_enabled} and can_dispatch_v3) {{ return fmha_fwd_v3(traits, args, config); }} else {{ @@ -1054,9 +1053,9 @@ def get_pipelines( # qr_async_trload_v3 only supports hdim=hdim_v=128 for now if (hdim, hdim_v) == (128, 128): # qr_async_trload_v3 only supports (generic) causal mask - for mask in ["no", "causal"]: + for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", - F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip + F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip return pipelines diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 002d0a1035..bef1cbd006 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -722,6 +722,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -752,6 +753,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index f981c54bd8..7bc7bcfedc 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" #include #include @@ -30,14 +31,16 @@ struct FmhaFwdV3Kernel using ODataType = ck_tile::remove_cvref_t; using SaccDataType = 
ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; - static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; - using FmhaMask = ck_tile::remove_cvref_t; + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; template // to avoid duplicated base class prblem, introduce an template @@ -93,10 +96,33 @@ struct FmhaFwdV3Kernel ck_tile::index_t batch_stride_lse = 0; }; + struct FmhaFwdLogitsSoftCapKargs + { + FmhaFwdLogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -111,8 +137,8 @@ struct FmhaFwdV3Kernel struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -127,6 +153,13 @@ struct FmhaFwdV3Kernel using Kargs = std::conditional_t; + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -141,6 +174,7 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -183,6 +217,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, batch_stride_v, @@ -201,6 +236,10 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); @@ -223,6 +262,7 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -260,6 +300,7 @@ struct FmhaFwdV3Kernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for lse + {}, // 
placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_q_ptr), @@ -277,6 +318,10 @@ struct FmhaFwdV3Kernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); @@ -594,6 +639,21 @@ struct FmhaFwdV3Kernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + auto o_acc_tile = [&]() { return FmhaPipeline{}(q_dram_window, k_dram_window, @@ -601,6 +661,9 @@ struct FmhaFwdV3Kernel lse_dram_window, mask, kargs.scale_s, + variant, + variant_params, + block_indices, smem_ptr); }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 68ec349694..c25f57632f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -264,6 +264,7 @@ struct BlockFmhaFwdV3Pipeline using PDataType = ck_tile::remove_cvref_t; using OaccDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; + using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; static_assert(is_generic_attention_mask_v); @@ -298,8 +299,7 @@ struct BlockFmhaFwdV3Pipeline static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; - static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && - !kStoreLSE && !kHasDropout && + static_assert((BiasEnum == BlockAttentionBiasEnum::NO_BIAS && !kStoreLSE && !kHasDropout && (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && !kSkipMinSeqlenQ), "enable unsupported features"); @@ -401,7 +401,9 @@ struct BlockFmhaFwdV3Pipeline typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -415,6 +417,9 @@ struct BlockFmhaFwdV3Pipeline const OAccElementFunction& o_acc_element_func, FmhaMask mask, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr) const { using namespace ck_tile; @@ -721,6 +726,22 @@ struct BlockFmhaFwdV3Pipeline /// TODO: remove the sp_delta and use sp_compute directly statically_indexed_array{}).sp_compute), 2> sp_delta; + auto fmha_logits_trans = [&](auto sp_reg_idx) { + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = [&variant, &variant_params, &block_indices]( + auto& logits) { + logits = variant.LogitsTransform(variant_params, + 
variant.QueryTransform(variant_params, logits), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; + + tile_elementwise_inout(apply_logits_transform, sp(sp_reg_idx).sp_compute); + } + }; + auto fmha_alu0 = [&](auto sp_reg_idx) { m_old = m; // m{j-1} static_assert(m.thread_buf_.size() == 1, @@ -746,9 +767,17 @@ struct BlockFmhaFwdV3Pipeline std::decay_t::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv( - sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx)); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + if constexpr(kHasLogitsSoftCap) + { + sp_delta(sp_reg_idx)(i_j_idx) = + sp(sp_reg_idx).sp_compute(i_j_idx) - m(i_j_idx); + } + else + { + sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv( + sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx)); + } }); }); /// TODO: move some fmha_alu1() code here if necessary @@ -793,8 +822,16 @@ struct BlockFmhaFwdV3Pipeline constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx])); - + const auto tmp = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::exp2(m_old[i_idx] - m[i_idx]); + } + else + { + return ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx])); + } + }(); l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]); }); @@ -880,7 +917,16 @@ struct BlockFmhaFwdV3Pipeline }; auto fmha_alu_D_upd = [&] { - o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0])); + o_acc_scale = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::exp2(m_old.thread_buf_[0] - m.thread_buf_[0]); + } + else + { + return ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0])); + } + }(); fp32x2_t pk_o_acc_scale; pk_o_acc_scale.x = o_acc_scale; @@ -928,7 +974,12 @@ struct BlockFmhaFwdV3Pipeline const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = kv_token_start + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); }); } } @@ -992,6 +1043,7 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p01_reg_idx, gemm0); fmha_alu1(xdl_SP_p23_reg_idx); + fmha_logits_trans(xdl_SP_p01_reg_idx); Scheduler::schedule(cl_p, number<0>{}); __builtin_amdgcn_sched_barrier(0); @@ -1066,6 +1118,7 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p01_reg_idx, gemm0); fmha_alu1(xdl_SP_p23_reg_idx); + fmha_logits_trans(xdl_SP_p01_reg_idx); Scheduler::schedule(cl_p, number<1>{}); __builtin_amdgcn_sched_barrier(0); @@ -1149,7 +1202,7 @@ struct BlockFmhaFwdV3Pipeline // (3) mfma (Q*K0) + softmax gemm(number<0>{}, /*gemm_idx=*/number<0>{}); - + fmha_logits_trans(number<0>{}); fmha_mask(number<0>{}); /// TODO: find better way to map fmha_alu(0,96) call fmha_alu0(number<0>{}); @@ -1244,13 +1297,18 @@ struct BlockFmhaFwdV3Pipeline template + typename LSEDramBlockWindowTmp, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& 
k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr) const { using namespace ck_tile; @@ -1268,6 +1326,9 @@ struct BlockFmhaFwdV3Pipeline identity{}, mask, scale_s, + variant, + variant_params, + block_indices, smem_ptr); } }; From db3d5245327bd7a3d419946b2b709645aba6dc02 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 4 Dec 2025 23:14:29 -0600 Subject: [PATCH 48/48] Add missing Kargs base type --- include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 7bc7bcfedc..6fe1de634d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -137,6 +137,7 @@ struct FmhaFwdV3Kernel struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs, + std::conditional_t>, std::conditional_t>, std::conditional_t> {