
Commit 679a15c

Fix max_seqlens_q/k initialization for Navi GPUs (#310)
- max_seqlens_q/k variables were not correctly initialized for Navi GPUs, leading to incorrect outputs.
- Ensure that the correct values are passed to the attn_fwd kernel based on the GPU type.
1 parent 44212d7 commit 679a15c

File tree: 1 file changed (+4, -5 lines)


vllm/attention/ops/triton_flash_attention.py

Lines changed: 4 additions & 5 deletions
@@ -912,9 +912,8 @@ def check_and_convert(t, scale):
     p_descale = 1.0 / p_scale
     o_descale = 1.0 / o_scale

-    if is_navi():
-        max_seqlens_q = 0
-        max_seqlens_k = 0
+    arg_max_seqlens_q = 0 if is_navi() else max_seqlens_q
+    arg_max_seqlens_k = 0 if is_navi() else max_seqlens_k

     attn_fwd[grid](
         q,
@@ -944,8 +943,8 @@ def check_and_convert(t, scale):
         HQ=nheads_q,
         HK=nheads_k,
         ACTUAL_BLOCK_DMODEL=head_size,
-        MAX_SEQLENS_Q=max_seqlens_q,
-        MAX_SEQLENS_K=max_seqlens_k,
+        MAX_SEQLENS_Q=arg_max_seqlens_q,
+        MAX_SEQLENS_K=arg_max_seqlens_k,
         IS_CAUSAL=causal,
         VARLEN=True,
         BLOCK_DMODEL=padded_d_model,
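
Below is a minimal, self-contained sketch of the pattern this commit applies, not the actual vLLM code: instead of overwriting max_seqlens_q/k in place when running on a Navi GPU, the kernel-only values are bound to new arg_* names so the original variables are preserved. The names launch and fake_attn_fwd are hypothetical stand-ins for the real Triton attn_fwd kernel launch.

# Minimal sketch (not the actual vLLM code) of the pattern in this commit:
# bind the kernel-only MAX_SEQLENS_* values to new names instead of
# overwriting max_seqlens_q/k in place on Navi GPUs.
# `fake_attn_fwd` and `launch` are hypothetical stand-ins for the real
# Triton attn_fwd kernel launch.

def fake_attn_fwd(**kwargs):
    # Stand-in for the attn_fwd kernel: just report which MAX_SEQLENS_*
    # values it was launched with.
    return kwargs["MAX_SEQLENS_Q"], kwargs["MAX_SEQLENS_K"]

def launch(max_seqlens_q, max_seqlens_k, is_navi):
    # Navi GPUs get 0 for the MAX_SEQLENS_* kernel arguments; other GPUs
    # get the real maxima. The originals are never mutated (the pre-fix
    # code zeroed max_seqlens_q/k themselves on Navi).
    arg_max_seqlens_q = 0 if is_navi else max_seqlens_q
    arg_max_seqlens_k = 0 if is_navi else max_seqlens_k

    passed = fake_attn_fwd(
        MAX_SEQLENS_Q=arg_max_seqlens_q,
        MAX_SEQLENS_K=arg_max_seqlens_k,
    )
    # max_seqlens_q / max_seqlens_k still hold their original values here.
    return passed, (max_seqlens_q, max_seqlens_k)

print(launch(512, 512, is_navi=True))   # ((0, 0), (512, 512))
print(launch(512, 512, is_navi=False))  # ((512, 512), (512, 512))

Run as a plain script, the first call (Navi path) passes 0/0 to the stand-in kernel while the original maxima remain 512/512; the second call passes the real maxima through unchanged.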
