Commit a9e739c

Add logits soft-capping support for fmha fwd v3 pipeline (WIP)
1 parent 5e1f431 commit a9e739c

4 files changed: +156 additions, -31 deletions

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 8 additions & 9 deletions

```diff
@@ -201,11 +201,10 @@
     const bool can_dispatch_v3 =
         (device_name.compare(0, 6, "gfx950") == 0) and
         (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
-        traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and
-        (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and
-        (not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and
-        (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and
-        (args.hdim_v == 128);
+        traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
+        (not traits.has_lse) and (not traits.has_dropout) and
+        (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
+        (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
     if ({F_is_v3_enabled} and can_dispatch_v3) {{
         return fmha_fwd_v3(traits, args, config);
     }} else {{
```
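
For context: logits soft-capping (the `traits.has_logits_soft_cap` / `F_logits` knob that is no longer excluded from the v3 dispatch above) conventionally means squashing each raw attention logit into `(-cap, +cap)` with a tanh before the softmax. A minimal standalone sketch of that math, assuming the usual Gemma-2-style definition rather than quoting this repository's pipeline code:

```cpp
#include <cmath>

// Minimal sketch of tanh-based logits soft-capping (illustrative names only);
// a non-positive cap is treated as "capping disabled".
inline float soft_cap_logit(float logit, float cap)
{
    if(cap <= 0.f)
        return logit;
    return cap * std::tanh(logit / cap);
}
```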

```diff
@@ -1054,9 +1053,9 @@ def get_pipelines(
         # qr_async_trload_v3 only supports hdim=hdim_v=128 for now
         if (hdim, hdim_v) == (128, 128):
             # qr_async_trload_v3 only supports (generic) causal mask
-            for mask in ["no", "causal"]:
+            for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
                 pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
-                    F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip
+                    F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip

     return pipelines
```

```diff
@@ -1337,8 +1336,8 @@ def accept_only_v2(trait: FmhaFwdApiTrait) -> bool:
         FMHA_FWD_API_FOOTER_TEMPLATE.format(
             F_is_v3_enabled=BOOL_MAP[
                 # NOTE: enable v3 pipelines when ready
-                # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
-                False
+                0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
+                # False
             ]
         ),
     ]
```

example/ck_tile/01_fmha/fmha_fwd.hpp

Lines changed: 2 additions & 0 deletions

```diff
@@ -722,6 +722,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
         args.nhead_q,
         args.nhead_q / args.nhead_k,
         args.scale_s,
+        args.logits_soft_cap,
         args.stride_q,
         args.stride_k,
         args.stride_v,
@@ -752,6 +753,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
         args.nhead_q,
         args.nhead_q / args.nhead_k,
         args.scale_s,
+        args.logits_soft_cap,
         args.stride_q,
         args.stride_k,
         args.stride_v,
```
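
On the host side the change is simply threading `args.logits_soft_cap` through to the v3 kernel arguments; this diff only confirms `scale_s` and `logits_soft_cap` as `fmha_fwd_args` fields, and everything else about the struct is omitted here. A hedged caller-side sketch:

```cpp
#include <cmath>

// Hedged caller-side sketch (assumes "fmha_fwd.hpp" from example/ck_tile/01_fmha
// is included). Only scale_s and logits_soft_cap are field names visible in this
// commit; fmha_fwd_args has many other members not shown.
void enable_soft_cap_example(fmha_fwd_args& args)
{
    args.scale_s         = 1.0f / std::sqrt(128.0f); // softmax scale for hdim_q == 128
    args.logits_soft_cap = 30.0f;                    // > 0 turns tanh soft-capping on
}
```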

include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp

Lines changed: 73 additions & 10 deletions

```diff
@@ -6,6 +6,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
+#include "ck_tile/ops/fmha/block/variants.hpp"

 #include <type_traits>
 #include <utility>
@@ -30,14 +31,16 @@ struct FmhaFwdV3Kernel
     using ODataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
     using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;

-    static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
-    static constexpr bool kPadSeqLenQ  = FmhaPipeline::kPadSeqLenQ;
-    static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
-    static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
-    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
-    static constexpr bool kStoreLSE    = FmhaPipeline::kStoreLSE;
+    static constexpr bool kIsGroupMode      = FmhaPipeline::kIsGroupMode;
+    static constexpr bool kPadSeqLenQ       = FmhaPipeline::kPadSeqLenQ;
+    static constexpr bool kPadSeqLenK       = FmhaPipeline::kPadSeqLenK;
+    static constexpr bool kPadHeadDimQ      = FmhaPipeline::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
+    static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
+    static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;

-    using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+    using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
+    using FmhaMask         = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask = FmhaMask::IsMasking;

     template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
@@ -93,10 +96,33 @@ struct FmhaFwdV3Kernel
         ck_tile::index_t batch_stride_lse = 0;
     };

+    struct FmhaFwdLogitsSoftCapKargs
+    {
+        FmhaFwdLogitsSoftCapKargs() = default;
+
+        void init_logits_soft_cap(float logits_soft_cap_)
+        {
+            if(0 < logits_soft_cap_)
+            {
+                logits_soft_cap     = logits_soft_cap_;
+                logits_soft_cap_rcp = 1.f / logits_soft_cap;
+            }
+            else
+            {
+                logits_soft_cap     = 0.f;
+                logits_soft_cap_rcp = 0.f;
+            }
+        }
+
+        float logits_soft_cap;
+        float logits_soft_cap_rcp;
+    };
+
     struct FmhaFwdBatchModeKargs
         : FmhaFwdCommonKargs,
           std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
-          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
+          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
+          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
```
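
The optional kargs are mixed in as conditional base classes: when a feature is compiled out, the corresponding base collapses to a distinct `FmhaFwdEmptyKargs<I>` and, via the empty-base optimization, adds no storage to the kargs struct. `init_logits_soft_cap()` additionally precomputes the reciprocal so device code can multiply rather than divide per logit. A stripped-down illustration of the base-class pattern under assumed names (not ck_tile code):

```cpp
#include <type_traits>

// Illustration of the conditional-base-class trick used above: a disabled
// feature contributes an empty, uniquely indexed base and therefore no bytes.
template <int I>
struct EmptyKargs
{
};

struct SoftCapKargs
{
    float logits_soft_cap;
    float logits_soft_cap_rcp;
};

template <bool HasSoftCap>
struct BatchKargs : std::conditional_t<HasSoftCap, SoftCapKargs, EmptyKargs<2>>
{
    int batch_stride_q; // stand-in for the real stride members
};

static_assert(sizeof(BatchKargs<false>) == sizeof(int), "empty base adds no storage");
static_assert(sizeof(BatchKargs<true>) > sizeof(BatchKargs<false>), "soft-cap fields are present");
```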

```diff
@@ -111,8 +137,8 @@ struct FmhaFwdV3Kernel

     struct FmhaFwdGroupModeKargs
         : FmhaFwdCommonKargs,
-          std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
-          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
+          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
+          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
@@ -127,6 +153,13 @@ struct FmhaFwdV3Kernel

     using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;

+    struct BlockIndices
+    {
+        ck_tile::index_t batch_idx;
+        ck_tile::index_t qo_head_idx;
+        ck_tile::index_t kv_head_idx;
+    };
+
     template <bool Cond = !kIsGroupMode>
     CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
     MakeKargs(const void* q_ptr,
```
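
`BlockIndices` is later filled with `{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}` (see the last hunk below): with GQA/MQA the third field is the KV head shared by a group of consecutive query heads, since `nhead_ratio_qk = nhead_q / nhead_k`. A small illustrative sketch of that mapping (plain `int` stands in for `ck_tile::index_t`):

```cpp
// Illustrative GQA head mapping, mirroring how the kernel constructs
// BlockIndices; names here are local to this sketch.
struct BlockIndicesSketch
{
    int batch_idx;
    int qo_head_idx;
    int kv_head_idx;
};

inline BlockIndicesSketch make_block_indices(int i_batch, int i_nhead, int nhead_ratio_qk)
{
    return {i_batch, i_nhead, i_nhead / nhead_ratio_qk};
}

// e.g. nhead_q = 32, nhead_k = 8 -> nhead_ratio_qk = 4, so query head 13 reads KV head 3.
```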

```diff
@@ -141,6 +174,7 @@ struct FmhaFwdV3Kernel
               ck_tile::index_t num_head_q,
               ck_tile::index_t nhead_ratio_qk,
               float scale_s,
+              float logits_soft_cap,
               ck_tile::index_t stride_q,
               ck_tile::index_t stride_k,
               ck_tile::index_t stride_v,
@@ -183,6 +217,7 @@
                      nhead_stride_o}, // args for common karg
                     {},               // placeholder for mask
                     {},               // placeholder for lse
+                    {},               // placeholder for logits_soft_cap
                     batch_stride_q,
                     batch_stride_k,
                     batch_stride_v,
@@ -201,6 +236,10 @@
             kargs.nhead_stride_lse = nhead_stride_lse;
             kargs.batch_stride_lse = batch_stride_lse;
         }
+        if constexpr(kHasLogitsSoftCap)
+        {
+            kargs.init_logits_soft_cap(logits_soft_cap);
+        }

         kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
         kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
```

```diff
@@ -223,6 +262,7 @@ struct FmhaFwdV3Kernel
               ck_tile::index_t num_head_q,
               ck_tile::index_t nhead_ratio_qk,
               float scale_s,
+              float logits_soft_cap,
               ck_tile::index_t stride_q,
               ck_tile::index_t stride_k,
               ck_tile::index_t stride_v,
@@ -260,6 +300,7 @@
                      nhead_stride_o}, // args for common karg
                     {},               // placeholder for mask
                     {},               // placeholder for lse
+                    {},               // placeholder for logits_soft_cap
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                     reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                     reinterpret_cast<const int32_t*>(seqlen_q_ptr),
@@ -277,6 +318,10 @@
             kargs.lse_ptr          = lse_ptr;
             kargs.nhead_stride_lse = nhead_stride_lse;
         }
+        if constexpr(kHasLogitsSoftCap)
+        {
+            kargs.init_logits_soft_cap(logits_soft_cap);
+        }

         kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
         kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
@@ -594,13 +639,31 @@
             return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
         }();

+        AttentionVariant variant;
+        const auto variant_params = [&] {
+            if constexpr(kHasLogitsSoftCap)
+            {
+                return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
+                    mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
+            }
+            else
+            {
+                return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
+            }
+        }();
+
+        BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
+
         auto o_acc_tile = [&]() {
             return FmhaPipeline{}(q_dram_window,
                                   k_dram_window,
                                   v_dram_window,
                                   lse_dram_window,
                                   mask,
                                   kargs.scale_s,
+                                  variant,
+                                  variant_params,
+                                  block_indices,
                                   smem_ptr);
         }();
```
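
Because `LogitsSoftCapParams` and `StandardAttentionParams` are different types, the selection above happens inside an immediately invoked lambda whose single surviving `return` (chosen by `if constexpr`) fixes the deduced type of `variant_params` at compile time. A stripped-down illustration of that idiom with placeholder types (not the ck_tile ones):

```cpp
#include <type_traits>

// Placeholder parameter structs standing in for StandardAttentionParams /
// LogitsSoftCapParams; field names here are illustrative only.
struct StandardParamsSketch { float scale; };
struct SoftCapParamsSketch  { float scale; float cap; float cap_rcp; };

// Same idiom as in the kernel: the immediately invoked lambda has exactly one
// return statement per instantiation, so its deduced return type follows the
// compile-time flag.
template <bool HasSoftCap>
auto make_variant_params(float scale, float cap)
{
    return [&] {
        if constexpr(HasSoftCap)
        {
            return SoftCapParamsSketch{scale, cap, 1.f / cap};
        }
        else
        {
            return StandardParamsSketch{scale};
        }
    }();
}

static_assert(std::is_same_v<decltype(make_variant_params<true>(1.f, 30.f)), SoftCapParamsSketch>);
static_assert(std::is_same_v<decltype(make_variant_params<false>(1.f, 0.f)), StandardParamsSketch>);
```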
