Commit 3a72550

Merge branch 'main' into main

arakowsk-amd authored Dec 23, 2024
2 parents b335df8 + a264693 commit 3a72550

Showing 91 changed files with 4,066 additions and 1,704 deletions.
10 changes: 5 additions & 5 deletions Dockerfile.base
@@ -1,15 +1,15 @@
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3-complete
-ARG HIPBLASLT_BRANCH="507a649"
+ARG HIPBLASLT_BRANCH="4d40e36"
ARG LEGACY_HIPBLASLT_OPTION=
-ARG RCCL_BRANCH="dfe4a3e"
+ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
-ARG TRITON_BRANCH="e192dba"
+ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
-ARG PYTORCH_BRANCH="8bc4033"
+ARG PYTORCH_BRANCH="8d4926e"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
-ARG FA_BRANCH="c555642"
+ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"

FROM ${BASE_IMAGE} AS base
8 changes: 4 additions & 4 deletions Dockerfile.base_navi
@@ -1,11 +1,11 @@
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3-complete
-ARG HIPBLASLT_BRANCH="d43d84a"
+ARG HIPBLASLT_BRANCH="4d40e36"
ARG LEGACY_HIPBLASLT_OPTION=
-ARG RCCL_BRANCH="dfe4a3e"
+ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
-ARG TRITON_BRANCH="e192dba"
+ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
-ARG PYTORCH_BRANCH="8bc4033"
+ARG PYTORCH_BRANCH="8d4926e"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
2 changes: 1 addition & 1 deletion Dockerfile.rocm
@@ -3,7 +3,7 @@ ARG REMOTE_VLLM="0"
ARG USE_CYTHON="0"
ARG BUILD_RPD="1"
ARG COMMON_WORKDIR=/app
-ARG BASE_IMAGE=rocm/vllm-dev:base_ubuntu22.04_py3.12_ROCm6.3_hipblaslt0.12_torch2.6
+ARG BASE_IMAGE=rocm/vllm-dev:base

FROM ${BASE_IMAGE} AS base
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_moe.py
@@ -155,7 +155,7 @@ def get_rocm_tuning_space(use_fp16):
    # For now we see better perf with num_stages=0 for all gemm configs we care
    # But keep this explicit so that we do not forget we may need to set it to
    # other values in the future
-    num_stage_range = [0]
+    num_stage_range = [2]
    waves_per_eu_range = [0]
    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
    kpack_range = [1, 2] if use_fp16 else []
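Note on the benchmark_moe.py hunk above: get_rocm_tuning_space collects candidate values per kernel parameter, and the benchmark sweeps their combinations. A minimal sketch of that idea follows, assuming a dict-of-ranges layout; the helper name and dictionary keys are illustrative, not the exact ones in benchmark_moe.py.

```python
from itertools import product

def rocm_tuning_space_sketch(use_fp16: bool) -> dict:
    """Illustrative stand-in for get_rocm_tuning_space (keys are assumed)."""
    return {
        "num_stages": [2],  # bumped from [0] in this commit
        "waves_per_eu": [0],
        "matrix_instr_nonkdim": [16, 32] if use_fp16 else [],
        "kpack": [1, 2] if use_fp16 else [],
    }

space = rocm_tuning_space_sketch(use_fp16=True)
# Each combination of these ranges becomes one candidate MoE kernel config.
candidates = [dict(zip(space, values)) for values in product(*space.values())]
print(len(candidates))  # 1 * 1 * 2 * 2 = 4 combinations in this sketch
```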
2 changes: 1 addition & 1 deletion csrc/dispatch_utils.h
@@ -15,7 +15,7 @@
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// TODO(luka/varun): use FP8_TYPE macro after refactoring
-#ifndef USE_ROCM
+#ifdef USE_CUDA_FP8_FORMAT
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
  AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
2 changes: 2 additions & 0 deletions tests/kernels/utils.py
@@ -914,6 +914,7 @@ def make_test_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
@@ -963,6 +964,7 @@ def make_test_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
3 changes: 3 additions & 0 deletions tests/worker/test_model_input.py
@@ -74,6 +74,7 @@ def test_model_runner_input():
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
@@ -126,6 +127,7 @@ def test_embedding_model_runner_input():
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
@@ -177,6 +179,7 @@ def test_multi_step_model_runner_input():
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
)
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
7 changes: 3 additions & 4 deletions vllm/attention/backends/abstract.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)

@@ -126,8 +126,7 @@ class AttentionMetadata:

    # Enable/disable KV scales calculation. This is so that we can disable the
    # calculation until after prefill and cuda graph capture.
-    enable_kv_scales_calculation: bool = field(init=False,
-                                               default_factory=lambda: True)
+    enable_kv_scales_calculation: bool

    @property
    @abstractmethod
@@ -253,6 +252,6 @@ def forward(
        v_scale: torch.Tensor,
        attn_type: str = AttentionType.DECODER,
        output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
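Note on the AttentionMetadata hunk above: enable_kv_scales_calculation changes from an init=False field that always defaulted to True into an ordinary required dataclass field, which is why every backend and test in this commit now passes it explicitly. A minimal before/after sketch with a stripped-down stand-in class, not the real AttentionMetadata:

```python
from dataclasses import dataclass, field

# Old shape (sketch): the flag could not be set through the constructor.
@dataclass
class OldMetadata:
    num_prefills: int
    enable_kv_scales_calculation: bool = field(init=False,
                                               default_factory=lambda: True)

# New shape (sketch): the flag is an ordinary required field.
@dataclass
class NewMetadata:
    num_prefills: int
    enable_kv_scales_calculation: bool

OldMetadata(num_prefills=1)                                      # flag forced to True
NewMetadata(num_prefills=1, enable_kv_scales_calculation=False)  # caller decides
```

With the new shape the caller decides, which is what lets the CUDA-graph-capture and some decode paths in the hunks below pass False.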
4 changes: 3 additions & 1 deletion vllm/attention/backends/blocksparse_attn.py
@@ -222,6 +222,7 @@ def prefill_metadata(
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
+enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
@@ -251,6 +252,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=False,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
@@ -361,7 +363,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
-fp8_out_scale: Optional[torch.Tensor] = None,
+fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
5 changes: 4 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -224,6 +224,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
+enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
@@ -268,6 +269,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
@@ -550,6 +552,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
+enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
@@ -639,7 +642,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
-fp8_out_scale: Optional[torch.Tensor] = None,
+fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
4 changes: 3 additions & 1 deletion vllm/attention/backends/flashinfer.py
@@ -218,6 +218,7 @@ def graph_capture_get_metadata_for_batch(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=False,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
@@ -711,6 +712,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
+enable_kv_scales_calculation=False,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
@@ -775,7 +777,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
-fp8_out_scale: Optional[torch.Tensor] = None,
+fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:

# TODO: directly write to output tensor
2 changes: 1 addition & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -154,7 +154,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
-fp8_out_scale: Optional[torch.Tensor] = None,
+fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
2 changes: 1 addition & 1 deletion vllm/attention/backends/ipex_attn.py
@@ -174,7 +174,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
-fp8_out_scale: Optional[torch.Tensor] = None,
+fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
2 changes: 1 addition & 1 deletion vllm/attention/backends/pallas.py
@@ -152,7 +152,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
-fp8_out_scale: Optional[torch.Tensor] = None,
+fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
3 changes: 3 additions & 0 deletions vllm/attention/backends/placeholder_attn.py
@@ -140,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
+enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_decode_query_len=0,
@@ -173,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_decode_query_len=self.max_decode_query_len,
@@ -378,6 +380,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_prefills=self.num_prefills,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
+enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
14 changes: 12 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -165,6 +165,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
+enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
@@ -202,6 +203,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
@@ -549,7 +551,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
-fp8_out_scale: torch.Tensor = None,
+fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -599,6 +601,8 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
+q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None,
+None)

query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
@@ -679,6 +683,12 @@ def forward(
query.dtype,
seq_lens,
make_attn_mask=False) # type: ignore
+full_scales = (
+1.0 / q_scale.item(), 1.0 / k_scale.item(),
+1.0 / v_scale.item(), 1.0 / prob_scale.item(),
+fp8_out_scale.item()) if (
+fp8_out_scale
+and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None
out, _ = self.attn_func(
query,
key,
@@ -692,7 +702,7 @@
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
-None,
+full_scales,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
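Note on the rocm_flash_attn.py hunk above: the new fp8_comp_scales tuple is unpacked into query, softmax-probability, and output scales, and when ROCm FP8 flash attention is enabled the kernel receives reciprocal input scales plus the output scale. A minimal sketch of that flow follows; the tensor values are illustrative and the boolean argument stands in for envs.VLLM_USE_ROCM_FP8_FLASH_ATTN.

```python
import torch
from typing import Optional, Tuple

def rocm_fp8_full_scales(
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]],
    use_rocm_fp8_flash_attn: bool,  # stands in for envs.VLLM_USE_ROCM_FP8_FLASH_ATTN
) -> Optional[Tuple[float, ...]]:
    """Sketch of the scale handling added to ROCmFlashAttentionImpl.forward."""
    q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None, None)
    if fp8_out_scale is None or not use_rocm_fp8_flash_attn:
        return None
    # The kernel consumes reciprocal input scales plus the output scale.
    return (1.0 / q_scale.item(), 1.0 / k_scale.item(),
            1.0 / v_scale.item(), 1.0 / prob_scale.item(),
            fp8_out_scale.item())

# Illustrative values; in vLLM these come from the quantized model's checkpoint.
scales = rocm_fp8_full_scales(
    k_scale=torch.tensor(0.25),
    v_scale=torch.tensor(0.25),
    fp8_comp_scales=(torch.tensor(0.5), torch.tensor(1.0), torch.tensor(2.0)),
    use_rocm_fp8_flash_attn=True,
)
print(scales)  # (2.0, 4.0, 4.0, 1.0, 2.0)
```

When no output scale is provided, the sketch (like the hunk above) falls back to None, which matches the previous behaviour of passing None to attn_func.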
3 changes: 2 additions & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -372,6 +372,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
+enable_kv_scales_calculation=False,
)

return attn_metadata
@@ -433,7 +434,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
-fp8_out_scale: Optional[torch.Tensor] = None,
+fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
2 changes: 2 additions & 0 deletions vllm/attention/backends/utils.py
@@ -274,6 +274,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
+enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
@@ -326,6 +327,7 @@ def graph_capture_get_metadata_for_batch(
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
4 changes: 3 additions & 1 deletion vllm/attention/backends/xformers.py
@@ -217,6 +217,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]:
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
+enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
@@ -261,6 +262,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
+enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
@@ -418,7 +420,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
-fp8_out_scale: Optional[torch.Tensor] = None,
+fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
0 comments on commit 3a72550