
Commit be448fb

Merge branch 'main' into upstream_merge_24_10_21
2 parents 87e3970 + 69d5e1d commit be448fb

File tree

1 file changed: +2 -4 lines changed


vllm/attention/backends/rocm_flash_attn.py

Lines changed: 2 additions & 4 deletions
@@ -607,18 +607,17 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens

+        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
-
         # QKV for prefill.
         query = query[:num_prefill_tokens]
+
         if key is not None and value is not None:
             key = key[:num_prefill_tokens]
             value = value[:num_prefill_tokens]

         if prefill_meta := attn_metadata.prefill_metadata:
-            output = torch.empty_like(query)
-
             # Prompt run.
             # normal attention and DECODER
             if attn_type == AttentionType.DECODER and (
@@ -735,7 +734,6 @@ def forward(
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
             # Whether to use rocm custom paged attention or not
-            output = torch.empty_like(decode_query)
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
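
The net effect of the diff is that the output buffer is now allocated once for the full query (prefill plus decode tokens) before slicing, instead of a separate torch.empty_like inside the prefill and decode branches. A minimal sketch of that pattern follows, assuming the two paths write into disjoint slices of the shared buffer; forward_sketch and the slice writes are illustrative placeholders, not the vLLM kernels.

import torch

def forward_sketch(query: torch.Tensor, num_prefill_tokens: int) -> torch.Tensor:
    # One allocation sized like the full query (prefill + decode tokens),
    # replacing a separate torch.empty_like() inside each branch.
    output = torch.empty_like(query)

    prefill_query = query[:num_prefill_tokens]   # QKV for prefill
    decode_query = query[num_prefill_tokens:]    # query for decode

    if prefill_query.numel() > 0:
        # Placeholder for the prefill attention path: it fills the first
        # num_prefill_tokens rows of the shared buffer.
        output[:num_prefill_tokens] = prefill_query
    if decode_query.numel() > 0:
        # Placeholder for the decode (paged attention) path: it fills the
        # remaining rows of the same buffer.
        output[num_prefill_tokens:] = decode_query
    return output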
