File tree Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -607,18 +607,17 @@ def forward(
607
607
assert attn_metadata .num_encoder_tokens is not None
608
608
num_prefill_tokens = attn_metadata .num_encoder_tokens
609
609
610
+ output = torch .empty_like (query )
610
611
# Query for decode. KV is not needed because it is already cached.
611
612
decode_query = query [num_prefill_tokens :]
612
-
613
613
# QKV for prefill.
614
614
query = query [:num_prefill_tokens ]
615
+
615
616
if key is not None and value is not None :
616
617
key = key [:num_prefill_tokens ]
617
618
value = value [:num_prefill_tokens ]
618
619
619
620
if prefill_meta := attn_metadata .prefill_metadata :
620
- output = torch .empty_like (query )
621
-
622
621
# Prompt run.
623
622
# normal attention and DECODER
624
623
if attn_type == AttentionType .DECODER and (
@@ -735,7 +734,6 @@ def forward(
735
734
if decode_meta := attn_metadata .decode_metadata :
736
735
# Decoding run.
737
736
# Whether to use rocm custom paged attention or not
738
- output = torch .empty_like (decode_query )
739
737
num_seqs , num_heads , head_size = decode_query .shape
740
738
block_size = value_cache .shape [3 ]
741
739
gqa_ratio = num_heads // self .num_kv_heads
You can’t perform that action at this time.
0 commit comments