66#include " ck_tile/core.hpp"
77#include " ck_tile/ops/common.hpp"
88#include " ck_tile/ops/fmha/block/block_masking.hpp"
9+ #include " ck_tile/ops/fmha/block/variants.hpp"
910
1011#include < type_traits>
1112#include < utility>
@@ -30,14 +31,16 @@ struct FmhaFwdV3Kernel
3031 using ODataType = ck_tile::remove_cvref_t <typename FmhaPipeline::ODataType>;
3132 using SaccDataType = ck_tile::remove_cvref_t <typename FmhaPipeline::SaccDataType>;
3233
33- static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode ;
34- static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ ;
35- static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK ;
36- static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ ;
37- static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV ;
38- static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE ;
34+ static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode ;
35+ static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ ;
36+ static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK ;
37+ static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ ;
38+ static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV ;
39+ static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap ;
40+ static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE ;
3941
40- using FmhaMask = ck_tile::remove_cvref_t <typename FmhaPipeline::FmhaMask>;
42+ using AttentionVariant = ck_tile::remove_cvref_t <typename FmhaPipeline::AttentionVariant>;
43+ using FmhaMask = ck_tile::remove_cvref_t <typename FmhaPipeline::FmhaMask>;
4144 static constexpr bool kHasMask = FmhaMask::IsMasking;
4245
4346 template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
@@ -93,10 +96,33 @@ struct FmhaFwdV3Kernel
9396 ck_tile::index_t batch_stride_lse = 0 ;
9497 };
9598
/// Kernel arguments for logits soft-capping.
///
/// Stores the cap together with its reciprocal, so the reciprocal is computed
/// once in init_logits_soft_cap() instead of at every point of use.
struct FmhaFwdLogitsSoftCapKargs
{
    FmhaFwdLogitsSoftCapKargs() = default;

    /// Stores the cap and its reciprocal.
    ///
    /// @param logits_soft_cap_ Requested cap. Any value that is not strictly
    ///        positive (zero, negative, NaN) stores 0 in both fields.
    void init_logits_soft_cap(float logits_soft_cap_)
    {
        if(0 < logits_soft_cap_)
        {
            logits_soft_cap     = logits_soft_cap_;
            logits_soft_cap_rcp = 1.f / logits_soft_cap;
        }
        else
        {
            logits_soft_cap     = 0.f;
            logits_soft_cap_rcp = 0.f;
        }
    }

    // Zero-initialized so a default-constructed object never exposes
    // indeterminate values before init_logits_soft_cap() is called.
    float logits_soft_cap     = 0.f;
    float logits_soft_cap_rcp = 0.f;
};
120+
96121 struct FmhaFwdBatchModeKargs
97122 : FmhaFwdCommonKargs,
98123 std::conditional_t <kHasMask , FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0 >>,
99- std::conditional_t <kStoreLSE , FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1 >>
124+ std::conditional_t <kStoreLSE , FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1 >>,
125+ std::conditional_t <kHasLogitsSoftCap , FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2 >>
100126 {
101127 ck_tile::index_t batch_stride_q;
102128 ck_tile::index_t batch_stride_k;
@@ -111,8 +137,8 @@ struct FmhaFwdV3Kernel
111137
112138 struct FmhaFwdGroupModeKargs
113139 : FmhaFwdCommonKargs,
114- std::conditional_t <kHasMask , FmhaFwdMaskKargs , FmhaFwdEmptyKargs<0 >>,
115- std::conditional_t <kStoreLSE , FmhaFwdCommonLSEKargs , FmhaFwdEmptyKargs<1 >>
140+ std::conditional_t <kStoreLSE , FmhaFwdCommonLSEKargs , FmhaFwdEmptyKargs<1 >>,
141+ std::conditional_t <kHasLogitsSoftCap , FmhaFwdLogitsSoftCapKargs , FmhaFwdEmptyKargs<2 >>
116142 {
117143 const int32_t * seqstart_q_ptr;
118144 const int32_t * seqstart_k_ptr;
@@ -127,6 +153,13 @@ struct FmhaFwdV3Kernel
127153
128154 using Kargs = std::conditional_t <kIsGroupMode , FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
129155
156+ struct BlockIndices
157+ {
158+ ck_tile::index_t batch_idx;
159+ ck_tile::index_t qo_head_idx;
160+ ck_tile::index_t kv_head_idx;
161+ };
162+
130163 template <bool Cond = !kIsGroupMode >
131164 CK_TILE_HOST static constexpr std::enable_if_t <Cond, Kargs>
132165 MakeKargs (const void * q_ptr,
@@ -141,6 +174,7 @@ struct FmhaFwdV3Kernel
141174 ck_tile::index_t num_head_q,
142175 ck_tile::index_t nhead_ratio_qk,
143176 float scale_s,
177+ float logits_soft_cap,
144178 ck_tile::index_t stride_q,
145179 ck_tile::index_t stride_k,
146180 ck_tile::index_t stride_v,
@@ -183,6 +217,7 @@ struct FmhaFwdV3Kernel
183217 nhead_stride_o}, // args for common karg
184218 {}, // placeholder for mask
185219 {}, // placeholder for lse
220+ {}, // placeholder for logits_soft_cap
186221 batch_stride_q,
187222 batch_stride_k,
188223 batch_stride_v,
@@ -201,6 +236,10 @@ struct FmhaFwdV3Kernel
201236 kargs.nhead_stride_lse = nhead_stride_lse;
202237 kargs.batch_stride_lse = batch_stride_lse;
203238 }
239+ if constexpr (kHasLogitsSoftCap )
240+ {
241+ kargs.init_logits_soft_cap (logits_soft_cap);
242+ }
204243
205244 kargs.cu_seqlen_q_ptr = reinterpret_cast <const int32_t *>(cu_seqlen_q_ptr);
206245 kargs.cu_seqlen_k_ptr = reinterpret_cast <const int32_t *>(cu_seqlen_k_ptr);
@@ -223,6 +262,7 @@ struct FmhaFwdV3Kernel
223262 ck_tile::index_t num_head_q,
224263 ck_tile::index_t nhead_ratio_qk,
225264 float scale_s,
265+ float logits_soft_cap,
226266 ck_tile::index_t stride_q,
227267 ck_tile::index_t stride_k,
228268 ck_tile::index_t stride_v,
@@ -260,6 +300,7 @@ struct FmhaFwdV3Kernel
260300 nhead_stride_o}, // args for common karg
261301 {}, // placeholder for mask
262302 {}, // placeholder for lse
303+ {}, // placeholder for logits_soft_cap
263304 reinterpret_cast <const int32_t *>(seqstart_q_ptr),
264305 reinterpret_cast <const int32_t *>(seqstart_k_ptr),
265306 reinterpret_cast <const int32_t *>(seqlen_q_ptr),
@@ -277,6 +318,10 @@ struct FmhaFwdV3Kernel
277318 kargs.lse_ptr = lse_ptr;
278319 kargs.nhead_stride_lse = nhead_stride_lse;
279320 }
321+ if constexpr (kHasLogitsSoftCap )
322+ {
323+ kargs.init_logits_soft_cap (logits_soft_cap);
324+ }
280325
281326 kargs.cu_seqlen_q_ptr = reinterpret_cast <const int32_t *>(cu_seqlen_q_ptr);
282327 kargs.cu_seqlen_k_ptr = reinterpret_cast <const int32_t *>(cu_seqlen_k_ptr);
@@ -594,13 +639,31 @@ struct FmhaFwdV3Kernel
594639 return FmhaMask{kargs.seqlen_q , kargs.seqlen_k };
595640 }();
596641
642+ AttentionVariant variant;
643+ const auto variant_params = [&] {
644+ if constexpr (kHasLogitsSoftCap )
645+ {
646+ return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
647+ mask, kargs.scale_s , kargs.logits_soft_cap , kargs.logits_soft_cap_rcp };
648+ }
649+ else
650+ {
651+ return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s };
652+ }
653+ }();
654+
655+ BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk };
656+
597657 auto o_acc_tile = [&]() {
598658 return FmhaPipeline{}(q_dram_window,
599659 k_dram_window,
600660 v_dram_window,
601661 lse_dram_window,
602662 mask,
603663 kargs.scale_s ,
664+ variant,
665+ variant_params,
666+ block_indices,
604667 smem_ptr);
605668 }();
606669
0 commit comments