Commit fa1ff83

Mllama kv scale fix (#335)
* Using tensors in the explicit cache function calls from mllama implementation
* Properly creating the tensor
1 parent: 27f53a2

File tree: 1 file changed, +5 -4 lines

vllm/model_executor/models/mllama.py

Lines changed: 5 additions & 4 deletions
@@ -829,14 +829,15 @@ def _attention_with_mask(
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
+            i = torch.ones(1, dtype=torch.float32)
             if current_platform.is_rocm():
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 PagedAttention.write_to_paged_cache(
                     cached_k, cached_v, key_cache, value_cache,
-                    attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                    attn_metadata.cross_slot_mapping, "auto", i, i)
             else:
                 if self.attn.backend in (_Backend.FLASH_ATTN,
                                          _Backend.FLASH_ATTN_VLLM_V1):
@@ -852,8 +853,8 @@ def _attention_with_mask(
                         attn_metadata.
                         cross_slot_mapping,  # type: ignore[union-attr]
                         "auto",
-                        1.0,
-                        1.0,
+                        i,
+                        i,
                     )
                 elif self.attn.backend in (_Backend.XFORMERS,
                                            _Backend.TORCH_SDPA):
@@ -866,7 +867,7 @@ def _attention_with_mask(
                         [v[s:e] for s, e in kv_range_for_decode])
                     PagedAttention.write_to_paged_cache(
                         cached_k, cached_v, key_cache, value_cache,
-                        attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                        attn_metadata.cross_slot_mapping, "auto", i, i)
                 else:
                     raise ValueError(
                         f"Unsupported Attention backend {self.attn.backend} "
