Skip to content

Commit

Permalink
Mllama kv scale fix (#335)
Browse files Browse the repository at this point in the history
* Using tensors in the explicit cache function calls from mllama implementation

* Properly creating the tensor

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
  • Loading branch information
gshtras committed Dec 18, 2024
1 parent b302bfb commit ef181a9
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions vllm/model_executor/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ def _attention_with_mask(
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32)
if self.attn.backend in (_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
Expand All @@ -839,8 +840,8 @@ def _attention_with_mask(
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
i,
i,
)
elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
key_cache, value_cache = PagedAttention.split_kv_cache(
Expand All @@ -849,7 +850,7 @@ def _attention_with_mask(
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
attn_metadata.cross_slot_mapping, "auto", i, i)
else:
raise ValueError(
f"Unsupported Attention backend {self.attn.backend} "
Expand Down

0 comments on commit ef181a9

Please sign in to comment.