Changes from all commits
64 commits
57a47b0
Let fmha_fwd_v3() compatible with fmha_fwd()
poyenc Nov 3, 2025
b93a0ad
Decouple get_fwd_blobs() and FmhaFwdKernel
poyenc Nov 3, 2025
6e46366
Decouple compatibility checks from get_fwd_blobs()
poyenc Nov 4, 2025
756a1b8
Extract product feature checks out from get_fwd_blobs()
poyenc Nov 4, 2025
4c5a68e
Remove duplicated code in factories and redundant checks
poyenc Nov 4, 2025
41cd25b
Remove FmhaFwdKernel<>::GetName()
poyenc Nov 5, 2025
3e0ad2c
Let FmhaFwdApiPool support pipelines with different mask_impl
poyenc Nov 5, 2025
4e6153b
Add tile setting for fmha fwd v3 pipeline
poyenc Nov 5, 2025
6eaa880
Add fwd v3 instances to tile_example_fmha_fwd manually
poyenc Nov 5, 2025
d6a99c2
Remove unused function import
poyenc Nov 7, 2025
76b2bc0
Undo irrelevant changes
poyenc Nov 7, 2025
260908a
Remove fwd v3 instances from tile_example_fmha_fwd
poyenc Nov 7, 2025
286a24b
Finish fmha fwd v3 kernel instance codegen
poyenc Nov 7, 2025
006692f
Fix formatting
poyenc Nov 10, 2025
051a6be
Remove unused F_idx attribute
poyenc Nov 10, 2025
0b15146
Add is_generic_attention_mask<> traits
poyenc Nov 10, 2025
a176996
Add constraints to the fmha fwd v3 pipeline
poyenc Nov 10, 2025
10ecccc
Unify traits & problem used for fmha fwd v3
poyenc Nov 10, 2025
16d4573
Unify kernel launch code for fmha fwd v2 & v3
poyenc Nov 10, 2025
1810d6f
Unify kernel template selection logic
poyenc Nov 11, 2025
05ffeac
Use same kernel codegen template for both v2 & v3
poyenc Nov 11, 2025
7b9b7ee
Rename api() property as render() method
poyenc Nov 11, 2025
923a97a
Allow specifying filter for fmha fwd api pool
poyenc Nov 11, 2025
be4d123
Allow specifying function name when rendering api pool items
poyenc Nov 11, 2025
b66d3f5
Separate fmha fwd v3 kernel dispatching logic from v2
poyenc Nov 11, 2025
48487b5
Remove lambda assignment
poyenc Nov 11, 2025
fd8312c
Add simple v2/v3 dispatch logic
poyenc Nov 11, 2025
0a3cfe1
Stop generating empty if-clauses
poyenc Nov 11, 2025
9da8cbb
Use "".join() to concatenate fmha fwd api string content
poyenc Nov 11, 2025
6793877
Add more feature checks for fmha fwd v3 pipeline
poyenc Nov 12, 2025
772c30f
Check features before dispatch to fmha_fwd_v3()
poyenc Nov 12, 2025
eebe510
Add more feature checks for fmha_fwd_v3()
poyenc Nov 12, 2025
1730875
Add missing filter call
poyenc Nov 12, 2025
a62afee
Use Tuple to reserve the dtype orders
poyenc Nov 12, 2025
9c89220
Fix wrong pipeline matching logic
poyenc Nov 12, 2025
23c0022
Add fmha fwd v3 group mode instances
poyenc Nov 13, 2025
6526b59
Add functor_transform<>
poyenc Nov 13, 2025
291cea6
Add type constraints to make_tile_window()
poyenc Nov 13, 2025
f4d92f1
Remove fmha fwd v3 example
poyenc Nov 13, 2025
2df5019
Fix wrong product(aiter mha_fwd()) config
poyenc Nov 13, 2025
66a874a
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 14, 2025
1df098d
Fix wrong fmha fwd v2/v3 selection logic
poyenc Nov 16, 2025
0d0a25b
Merge branch 'poyenc/integrate-fmha-fwd-v2-v3-apis' of github.com:poy…
poyenc Nov 16, 2025
68fd415
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 16, 2025
8e0d9dd
Fix formatting
poyenc Nov 17, 2025
51c30ba
Merge branch 'poyenc/integrate-fmha-fwd-v2-v3-apis' of github.com:poy…
poyenc Nov 17, 2025
d33691d
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 18, 2025
72ad9d7
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 19, 2025
d0730ba
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 20, 2025
13aee99
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
illsilin Nov 20, 2025
7ba44fd
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 21, 2025
9c5364d
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 21, 2025
615e4b8
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 22, 2025
4464745
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 23, 2025
f8ae943
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Dec 3, 2025
cf1f135
Add comment to warning v3 kernel users
poyenc Dec 3, 2025
608a253
Fix wrong codegen logics
poyenc Dec 3, 2025
02ed663
Remove unnecessary param
poyenc Dec 3, 2025
0e29033
Fix format
poyenc Dec 3, 2025
5e1f431
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Dec 4, 2025
1624819
Add logits soft-capping support for fmha fwd v3 pipeline (WIP)
poyenc Dec 4, 2025
9dae100
Merge branch 'develop' into poyenc/fa-v3-logits-capping
poyenc Dec 4, 2025
8b61405
Merge branch 'develop' into poyenc/fa-v3-logits-capping
poyenc Dec 5, 2025
db3d524
Add missing Kargs base type
poyenc Dec 5, 2025
13 changes: 6 additions & 7 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
@@ -201,11 +201,10 @@
const bool can_dispatch_v3 =
(device_name.compare(0, 6, "gfx950") == 0) and
(traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and
(traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and
(not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and
(not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and
(args.hdim_v == 128);
traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
(not traits.has_lse) and (not traits.has_dropout) and
(traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
(args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
if ({F_is_v3_enabled} and can_dispatch_v3) {{
return fmha_fwd_v3(traits, args, config);
}} else {{
@@ -1054,9 +1053,9 @@ def get_pipelines(
# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
if (hdim, hdim_v) == (128, 128):
# qr_async_trload_v3 only supports (generic) causal mask
for mask in ["no", "causal"]:
for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip

return pipelines

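The generated dispatch above no longer rejects traits.has_logits_soft_cap, so on gfx950 a fp16/bf16 call with hdim_q == hdim_v == 128, no bias/LSE/dropout/quant-scale/SWA, and nhead_q divisible by nhead_k can now route to fmha_fwd_v3() even when capping is requested. A minimal caller-side sketch, assuming fmha_fwd(traits, args, config) keeps the usual tile_example_fmha_fwd signature (device pointers, strides, and sequence lengths omitted; all values hypothetical):

// Hypothetical caller sketch, not part of this PR; field names follow the generated
// can_dispatch_v3 check above. Pointers, strides and seqlens are omitted for brevity.
fmha_fwd_traits traits{};
traits.data_type           = "fp16";                 // v3 path accepts fp16 or bf16 only
traits.is_v_rowmajor       = true;
traits.has_logits_soft_cap = true;                   // no longer excluded from the v3 path
traits.bias_type           = bias_enum::no_bias;
traits.has_lse             = false;
traits.has_dropout         = false;

fmha_fwd_args args{};
args.hdim_q  = 128;                                   // v3 currently requires hdim_q == hdim_v == 128
args.hdim_v  = 128;
args.nhead_q = 32;
args.nhead_k = 8;                                     // nhead_q % nhead_k == 0
args.logits_soft_cap = 50.f;                          // > 0 enables capping

const float avg_time = fmha_fwd(traits, args, ck_tile::stream_config{}); // may dispatch to fmha_fwd_v3() on gfx950
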
2 changes: 2 additions & 0 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -722,6 +722,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
@@ -752,6 +753,7 @@
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
82 changes: 73 additions & 9 deletions include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp
@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"

#include <type_traits>
#include <utility>
@@ -30,14 +31,16 @@ struct FmhaFwdV3Kernel
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;

static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;

using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;

template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
@@ -93,10 +96,33 @@ struct FmhaFwdV3Kernel
ck_tile::index_t batch_stride_lse = 0;
};

struct FmhaFwdLogitsSoftCapKargs
{
FmhaFwdLogitsSoftCapKargs() = default;

void init_logits_soft_cap(float logits_soft_cap_)
{
if(0 < logits_soft_cap_)
{
logits_soft_cap = logits_soft_cap_;
logits_soft_cap_rcp = 1.f / logits_soft_cap;
}
else
{
logits_soft_cap = 0.f;
logits_soft_cap_rcp = 0.f;
}
}

float logits_soft_cap;
float logits_soft_cap_rcp;
};

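FmhaFwdLogitsSoftCapKargs stores both the cap and its reciprocal so the per-element transform inside the pipeline needs no division; a cap of 0 means capping is effectively disabled. The conventional transform these two constants support is cap * tanh(logit / cap). A minimal sketch of that element-wise form using the precomputed reciprocal (the pipeline's fused code path, particularly under CK_TILE_FMHA_FWD_FAST_EXP2, may fold this differently):

// Illustrative element-wise soft-capping built from the two kargs above; not the
// pipeline's actual fused implementation. cap == 0 follows the "disabled" convention
// set up by init_logits_soft_cap().
CK_TILE_DEVICE float apply_logits_soft_cap(float logit, float cap, float cap_rcp)
{
    return 0.f < cap ? cap * tanhf(logit * cap_rcp) : logit; // cap * tanh(logit / cap)
}
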
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -112,7 +138,8 @@ struct FmhaFwdV3Kernel
struct FmhaFwdGroupModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -127,6 +154,13 @@ struct FmhaFwdV3Kernel

using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;

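The Kargs composition above relies on the FmhaFwdEmptyKargs<I> placeholders: C++ forbids listing the same class twice as a direct base, so each disabled feature slot gets a distinct empty type via the index parameter, and the empty bases add no storage. A standalone sketch of the pattern, with hypothetical names:

// Standalone illustration of the indexed-empty-base pattern used by
// FmhaFwdBatchModeKargs / FmhaFwdGroupModeKargs above. All names are hypothetical.
#include <type_traits>

template <int I>
struct EmptySlot {}; // a distinct empty type per feature slot

struct MaskSlot { float window_left, window_right; };
struct LseSlot  { void* lse_ptr = nullptr; };
struct CapSlot  { float cap = 0.f, cap_rcp = 0.f; };

template <bool HasMask, bool HasLse, bool HasCap>
struct ExampleKargs
    : std::conditional_t<HasMask, MaskSlot, EmptySlot<0>>,
      std::conditional_t<HasLse,  LseSlot,  EmptySlot<1>>,
      std::conditional_t<HasCap,  CapSlot,  EmptySlot<2>>
{
    float scale_s; // always-present members live directly in the derived struct
};
// With every feature disabled, the three distinct empty bases are removed by the
// empty-base optimization, so the struct carries only its own members.
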
struct BlockIndices
{
ck_tile::index_t batch_idx;
ck_tile::index_t qo_head_idx;
ck_tile::index_t kv_head_idx;
};

template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
@@ -141,6 +175,7 @@
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -183,6 +218,7 @@
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for logits_soft_cap
batch_stride_q,
batch_stride_k,
batch_stride_v,
Expand All @@ -201,6 +237,10 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}

kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
@@ -223,6 +263,7 @@
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
@@ -260,6 +301,7 @@
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
@@ -277,6 +319,10 @@
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
}

kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
@@ -594,13 +640,31 @@
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();

AttentionVariant variant;
const auto variant_params = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
}
else
{
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
}
}();

BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};

auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
}();

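The BlockIndices bundle built in the operator body, block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}, encodes the grouped-query-attention head mapping: nhead_ratio_qk is nhead_q / nhead_k, so each group of consecutive query heads reads the same K/V head. A small host-side sketch of the same mapping, with hypothetical head counts:

// Host-side illustration of the kv_head_idx computation used for BlockIndices above;
// each group of nhead_ratio_qk query heads shares one K/V head (GQA / MQA).
#include <cstdio>

int main()
{
    const int nhead_q = 32;
    const int nhead_k = 8;
    const int nhead_ratio_qk = nhead_q / nhead_k; // as passed down by fmha_fwd_v3_create_kargs_and_grids()

    for(int qo_head_idx = 0; qo_head_idx < nhead_q; ++qo_head_idx)
    {
        const int kv_head_idx = qo_head_idx / nhead_ratio_qk; // same formula as i_nhead / kargs.nhead_ratio_qk
        std::printf("qo head %2d -> kv head %d\n", qo_head_idx, kv_head_idx);
    }
    return 0;
}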