
Commit 549d741

[webgpu] Optimize Attention by enhancing flash attention support (#26715)
This pull request improves the WebGPU BERT attention implementation by enhancing FlashAttention support, generalizing tensor layout handling, and increasing batch size flexibility. The changes focus on supporting both BSNH and BNSH tensor layouts, enabling FlashAttention for multi-batch scenarios, and ensuring correct broadcasting and dispatch sizing for attention bias and batch dimensions.

Key improvements include:

**FlashAttention Support & Generalization:**

* Added support for both BSNH and BNSH tensor layouts by introducing the `q_BNSH` parameter and updating shader code, program classes, and kernel logic to handle either layout correctly. This includes changes in the WGSL template and in the C++ logic for offset calculations and program instantiation.
* Updated the `CanApplyFlashAttention` and `ApplyFlashAttention` logic to allow multi-batch operation by removing the restriction to batch size 1 and ensuring present key/value tensors are always created for FlashAttention.

**Batch & Bias Handling:**

* Modified dispatch group size calculations and uniform variables throughout the FlashAttention pipeline to properly account for batch size, ensuring correct parallelization for multi-batch scenarios.
* Added logic to extract and pass attention bias dimensions as uniforms for correct broadcasting in both the compute and shader code (see the sketch below).

**Other Enhancements:**

* Improved QKV format detection and generalized the code to support more format variants in `CopyKVCache`.
* Updated includes and dependencies to ensure all necessary headers for FlashAttention are present.

These changes collectively make the WebGPU BERT attention implementation more robust, flexible, and performant across different tensor layouts and batch sizes.

**Performance: phi-4-mm-vision.onnx**

Before:

| Kernel | Time (ms) | Percentage (%) |
| -- | -- | -- |
| Attention\|AttentionProbs | 159.66 | 11.14 |
| Attention\|VxAttentionScore | 122.56 | 8.55 |
| Attention\|InPlaceSoftmax | 51.83 | 3.62 |

After:

| Kernel | Time (ms) | Percentage (%) |
| -- | -- | -- |
| Attention\|FlashAttention | 60.23 | 5.38 |
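As a rough illustration of the batch and bias handling described above (not code from this commit), the host-side calculation reduces to folding `batch_size` into the dispatch group count and forwarding the first two attention-bias dimensions as uniforms:

```cpp
// Illustrative sketch only: batch-aware dispatch sizing and attention-bias
// broadcast dimensions, as applied in ApplyFlashAttention and
// ComputeFlashAttentionDecodeQKT. The helper below is not part of the codebase.
#include <cstdint>

struct FlashDispatchInfo {
  uint32_t dispatch_groups;  // one workgroup per (batch, head, sequence tile)
  uint32_t attn_bias_dim0;   // 1 (broadcast) or batch_size
  uint32_t attn_bias_dim1;   // 1 (broadcast) or num_heads
};

FlashDispatchInfo ComputeDispatchInfo(uint32_t batch_size, uint32_t num_heads,
                                      uint32_t sequence_length, uint32_t tile_size,
                                      const int64_t* bias_shape /* nullptr if no bias */) {
  FlashDispatchInfo info{};
  const uint32_t num_seq_tile = (sequence_length + tile_size - 1) / tile_size;
  // Previously the dispatch covered num_heads * num_seq_tile; batch_size is now included.
  info.dispatch_groups = batch_size * num_heads * num_seq_tile;
  // Bias dims default to 1 so a missing bias broadcasts over batch and head.
  info.attn_bias_dim0 = bias_shape ? static_cast<uint32_t>(bias_shape[0]) : 1;
  info.attn_bias_dim1 = bias_shape ? static_cast<uint32_t>(bias_shape[1]) : 1;
  return info;
}
```

A bias dimension of 1 signals broadcasting across the corresponding batch or head axis on the shader side.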
1 parent 07bf9a0 commit 549d741

File tree: 9 files changed (+185 / -98 lines)


onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 14 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 #include "contrib_ops/webgpu/bert/attention.h"
 
 #include "contrib_ops/cpu/bert/multihead_attention_helper.h"
+#include "contrib_ops/webgpu/bert/flash_attention.h"
 #include "contrib_ops/webgpu/bert/multihead_attention.h"
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
@@ -736,6 +737,19 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context)
   // Compute Q, K, V from input, weights, and bias
   ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V));
 
+  // Check if we can use flash attention
+  // For Attention operator, we need to create present_key and present_value tensors for flash attention
+  // even though they are not exposed as outputs
+  TensorShapeVector present_kv_shape({parameters.batch_size_, parameters.num_heads_,
+                                      parameters.total_sequence_length_, parameters.head_size_});
+  Tensor present_key = context.CreateGPUTensor(input->DataType(), present_kv_shape);
+  Tensor present_value = context.CreateGPUTensor(input->DataType(), present_kv_shape);
+
+  if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) {
+    return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value,
+                               parameters, context, nullptr);
+  }
+
   // Apply the actual attention computation
   return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr,
                         /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1);
```

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 52 additions & 19 deletions
```diff
@@ -76,7 +76,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   } else {
     shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n";
   }
-  shader.MainFunctionBody() << "let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n";
+  shader.MainFunctionBody() << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n";
+  if (past_present_share_buffer_) {
+    shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n";
+  } else {
+    shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n";
+  }
 
   // Add indirect dispatch logic for thread 0
   if (prepare_indirect_dispatch_) {
@@ -93,8 +98,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (has_past_) {
     const auto& past_key = shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
     shader.AddInput("past_value", ShaderUsage::UseUniform);
-    shader.MainFunctionBody() << "let present_offset = global_idx;"
-                              << "if (sequence_id < past_sequence_length) {\n"
+    shader.MainFunctionBody() << "if (sequence_id < past_sequence_length) {\n"
                               << " let pastOffset = " << past_key.IndicesToOffset("past_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n"
                               << " " << present_key.SetByOffset("present_offset", "past_key[pastOffset]") << ";\n"
                               << " " << present_value.SetByOffset("present_offset", "past_value[pastOffset]") << ";\n"
@@ -104,8 +108,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
                               << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n"
                               << "}";
   } else {
-    shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"
-                              << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n"
+    shader.MainFunctionBody() << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n"
                               << " " << present_key.SetByOffset("present_offset", "key[offset]") << ";\n"
                               << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n";
   }
```
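The two `IndicesToOffset` calls above differ only in which sequence row of the present KV tensor a thread writes. A minimal sketch of that row choice, assuming the shared-buffer case appends new KV rows after the rows already held in the past buffer (illustrative names, not the generated WGSL):

```cpp
// Sketch only: mirrors the index strings passed to IndicesToOffset above.
// All names here are illustrative, not the actual shader identifiers.
#include <cstdint>

uint32_t PresentKVRow(bool past_present_share_buffer,
                      uint32_t past_sequence_length,
                      uint32_t sequence_id) {
  // When past and present share one buffer, past rows are already in place,
  // so newly copied KV rows land right after them.
  if (past_present_share_buffer) {
    return past_sequence_length + sequence_id;
  }
  // Otherwise the shader walks the whole present buffer (copying past rows
  // first, then new ones), so the row is simply the thread's sequence index.
  return sequence_id;
}
```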
```diff
@@ -134,10 +137,10 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
   // Determine if we need to prepare indirect dispatch
   bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
   bool use_seqlen_k = (seqlen_k != nullptr);
-
-  CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH,
+  bool kv_BNSH = parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH || parameters.qkv_format_ == Q_K_V_BNSH;
+  CopyKVCacheProgram program{"CopyKVCache", has_past, kv_BNSH, parameters.past_present_share_buffer_,
                              prepare_indirect_dispatch, use_seqlen_k};
-  if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {
+  if (kv_BNSH) {
    program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
                       {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
   } else {
@@ -207,6 +210,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
                                      WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
                                      WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
                                      WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_),
+                                     WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_),
                                      WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_),
                                      WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_),
                                      WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_));
@@ -256,10 +260,20 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
                      {metadata, ProgramTensorMetadataDependency::Rank, 2}});
 
   const uint32_t vectorized_head_size = parameters.head_size_ / components;
+
+  // Get attention bias dimensions for broadcasting
+  uint32_t attn_bias_dim0 = 1;
+  uint32_t attn_bias_dim1 = 1;
+  if (has_attention_bias) {
+    const auto& bias_shape = attention_bias->Shape();
+    attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
+    attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
+  }
+
   if (use_indirect_dispatch) {
     program.SetIndirectDispatchTensor(indirect_buffer);
   } else {
-    program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
+    program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile);
   }
   program.SetWorkgroupSize(64)
       .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
@@ -269,7 +283,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
                             present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {num_present_sequence_length_tile},
-                            {static_cast<uint32_t>(parameters.num_heads_)}});
+                            {static_cast<uint32_t>(parameters.num_heads_)},
+                            {static_cast<uint32_t>(parameters.batch_size_)},
+                            {attn_bias_dim0},
+                            {attn_bias_dim1}});
 
   return context.RunProgram(program);
 }
```
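The new `batch_size`, `attn_bias_dim0`, and `attn_bias_dim1` uniforms let the shader recover the batch/head pair from the flattened dispatch index and broadcast a `[dim0, dim1, S, T]` bias over batches and heads. The WGSL template itself is not part of this diff, so the following sketch only illustrates the assumed index math; all names are hypothetical:

```cpp
// Illustrative only: how a flattened workgroup id could map back to
// (batch, head) and how a [dim0, dim1, S, T] bias would broadcast.
#include <cstdint>

struct BiasIndex {
  uint32_t batch_head;  // flattened index into the bias's first two dims
};

BiasIndex BroadcastBias(uint32_t workgroup_id, uint32_t num_heads,
                        uint32_t num_seq_tiles,
                        uint32_t attn_bias_dim0, uint32_t attn_bias_dim1) {
  // Dispatch size is batch_size * num_heads * num_seq_tiles, so:
  uint32_t batch_head = workgroup_id / num_seq_tiles;  // = batch * num_heads + head
  uint32_t batch = batch_head / num_heads;
  uint32_t head = batch_head % num_heads;

  // A bias dimension of 1 broadcasts across that axis.
  uint32_t bias_b = (attn_bias_dim0 == 1) ? 0 : batch;
  uint32_t bias_n = (attn_bias_dim1 == 1) ? 0 : head;
  return BiasIndex{bias_b * attn_bias_dim1 + bias_n};
}
```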
```diff
@@ -313,11 +330,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                      {qk, ProgramTensorMetadataDependency::TypeAndRank},
                      {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
   program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});  // [B, N, split_k, head_size]
+  const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
   if (use_indirect_dispatch) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None})
         .SetIndirectDispatchTensor(indirect_buffer);
   } else {
-    program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
+    program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile);
   }
   program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch)
       .SetWorkgroupSize(64)
@@ -326,7 +344,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                             present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
                             num_present_sequence_length_tile,
-                            {static_cast<uint32_t>(parameters.num_heads_)}});
+                            {batch_heads}});
 
   return context.RunProgram(program);
 }
@@ -363,14 +381,15 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
   }
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
   const uint32_t num_head_size_tile = static_cast<uint32_t>((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
-  program.SetDispatchGroupSize(parameters.num_heads_ * num_head_size_tile)
+  const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
+  program.SetDispatchGroupSize(batch_heads * num_head_size_tile)
       .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch)
       .SetWorkgroupSize(tile_size * tile_size)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.v_head_size_ / components)},
                             num_total_seq_length_tile,
                             num_present_sequence_length_tile,
                             {num_head_size_tile},
-                            {static_cast<uint32_t>(parameters.num_heads_)}});
+                            {batch_heads}});
 
   return context.RunProgram(program);
 }
@@ -429,6 +448,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
   bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
   bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+  bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
   FlashAttentionProgram program{"FlashAttention",
                                 has_attention_bias,
                                 is_qualcomm,
@@ -437,6 +457,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                 parameters.num_heads_,
                                 parameters.is_unidirectional_,
                                 is_nvidia,
+                                q_BNSH,
                                 use_seqlen_k};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
@@ -451,15 +472,28 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                 : parameters.scale_;
   const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
-  program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)
+
+  // Get attention bias dimensions for broadcasting
+  uint32_t attn_bias_dim0 = 1;
+  uint32_t attn_bias_dim1 = 1;
+  if (has_attention_bias) {
+    const auto& bias_shape = attention_bias->Shape();
+    attn_bias_dim0 = static_cast<uint32_t>(bias_shape[0]);
+    attn_bias_dim1 = static_cast<uint32_t>(bias_shape[1]);
+  }
+
+  program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile)
       .SetWorkgroupSize(tile_size)
-      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k)
+      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k)
      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                            {static_cast<uint32_t>(parameters.total_sequence_length_)},
                            {static_cast<uint32_t>(present_sequence_length)},
+                           {static_cast<uint32_t>(parameters.batch_size_)},
                            {static_cast<uint32_t>(parameters.n_reps)},
                            {alpha},
-                           {num_seq_tile}});
+                           {num_seq_tile},
+                           {attn_bias_dim0},
+                           {attn_bias_dim1}});
 
   return context.RunProgram(program);
 }
@@ -500,8 +534,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
 
 bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
                             const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
-  return parameters.batch_size_ == 1 &&
-         !parameters.is_packed_qkv_ &&
+  return !parameters.is_packed_qkv_ &&
          parameters.head_size_ == parameters.v_head_size_ &&
         bias == nullptr &&
         context.HasFeature(wgpu::FeatureName::Subgroups) &&
```
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h

Lines changed: 16 additions & 6 deletions
```diff
@@ -43,9 +43,9 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program<S
 
 class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
  public:
-  CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH,
+  CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer,
                      bool prepare_indirect_dispatch = false, bool use_seqlen_k = false)
-      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch), use_seqlen_k_(use_seqlen_k) {
+      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer), prepare_indirect_dispatch_(prepare_indirect_dispatch), use_seqlen_k_(use_seqlen_k) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -59,6 +59,7 @@ class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
  private:
   bool has_past_;
   bool kv_BNSH_;
+  bool past_present_share_buffer_;
   bool prepare_indirect_dispatch_;
   bool use_seqlen_k_;
 };
@@ -73,6 +74,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
                         int qkv_num_heads,
                         bool is_unidirectional,
                         bool is_nvidia,
+                        bool q_BNSH,
                         bool use_seqlen_k = false)
       : Program{kernel_name},
         has_attention_bias_(has_attention_bias),
@@ -82,6 +84,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
         qkv_num_heads_(qkv_num_heads),
         is_unidirectional_(is_unidirectional),
         is_nvidia_(is_nvidia),
+        q_BNSH_(q_BNSH),
         use_seqlen_k_(use_seqlen_k) {
   }
 
@@ -90,9 +93,12 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"batch_size", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"alpha", ProgramUniformVariableDataType::Float32},
-                                          {"num_seq_tile", ProgramUniformVariableDataType::Uint32});
+                                          {"num_seq_tile", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32});
 
  private:
   bool has_attention_bias_;
@@ -102,6 +108,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
   int qkv_num_heads_;
   bool is_unidirectional_;
   bool is_nvidia_;
+  bool q_BNSH_;
   bool use_seqlen_k_;
 };
 
@@ -120,7 +127,10 @@ class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecode
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
-                                          {"num_heads", ProgramUniformVariableDataType::Uint32});
+                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
+                                          {"batch_size", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
+                                          {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32});
 
  private:
   bool has_attention_bias_;
@@ -141,7 +151,7 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDe
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
-                                          {"num_heads", ProgramUniformVariableDataType::Uint32});
+                                          {"batch_heads", ProgramUniformVariableDataType::Uint32});
 
  private:
   uint32_t tile_size_;
@@ -161,7 +171,7 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionD
                                           {"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
                                           {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
                                           {"num_head_size_tile", ProgramUniformVariableDataType::Uint32},
-                                          {"num_heads", ProgramUniformVariableDataType::Uint32});
+                                          {"batch_heads", ProgramUniformVariableDataType::Uint32});
 
  private:
   uint32_t tile_size_;
```