@@ -829,14 +829,15 @@ def _attention_with_mask(
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
+            i = torch.ones(1, dtype=torch.float32)
             if current_platform.is_rocm():
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 PagedAttention.write_to_paged_cache(
                     cached_k, cached_v, key_cache, value_cache,
-                    attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                    attn_metadata.cross_slot_mapping, "auto", i, i)
             else:
                 if self.attn.backend in (_Backend.FLASH_ATTN,
                                          _Backend.FLASH_ATTN_VLLM_V1):
@@ -852,8 +853,8 @@ def _attention_with_mask(
                         attn_metadata.
                         cross_slot_mapping,  # type: ignore[union-attr]
                         "auto",
-                        1.0,
-                        1.0,
+                        i,
+                        i,
                     )
                 elif self.attn.backend in (_Backend.XFORMERS,
                                            _Backend.TORCH_SDPA):
@@ -866,7 +867,7 @@ def _attention_with_mask(
                         [v[s:e] for s, e in kv_range_for_decode])
                     PagedAttention.write_to_paged_cache(
                         cached_k, cached_v, key_cache, value_cache,
-                        attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                        attn_metadata.cross_slot_mapping, "auto", i, i)
                 else:
                     raise ValueError(
                         f"Unsupported Attention backend {self.attn.backend} "