Skip to content

Commit

Permalink
Ingest FP8 attn scales and use them in ROCm FlashAttention (#338)
Browse files Browse the repository at this point in the history
* Ingest FP8 attn scales and use them in Triton FA, if present

* Disabling calc_kv_scales if the checkoint has them. Enabling fp8 attention for dynamic quantization

* q_range as an env

* format

* Dedupe FA/PA attn toggles, set FA off by default

* Lint again, to fixed point

* Don't calculate KV scales dynamically if Q scale is included

---------

Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
  • Loading branch information
mawong-amd and gshtras authored Dec 20, 2024
1 parent d08b78b commit 1dcd9fe
Show file tree
Hide file tree
Showing 20 changed files with 157 additions and 81 deletions.
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,6 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
raise NotImplementedError
2 changes: 1 addition & 1 deletion vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:

# TODO: directly write to output tensor
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/ipex_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Expand Down
12 changes: 10 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: torch.Tensor = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Expand Down Expand Up @@ -601,6 +601,8 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None,
None)

query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
Expand Down Expand Up @@ -681,6 +683,12 @@ def forward(
query.dtype,
seq_lens,
make_attn_mask=False) # type: ignore
full_scales = (
1.0 / q_scale.item(), 1.0 / k_scale.item(),
1.0 / v_scale.item(), 1.0 / prob_scale.item(),
fp8_out_scale.item()) if (
fp8_out_scale
and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None
out, _ = self.attn_func(
query,
key,
Expand All @@ -694,7 +702,7 @@ def forward(
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
None,
full_scales,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Expand Down
19 changes: 11 additions & 8 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Attention layer."""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -75,6 +75,8 @@ def __init__(
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None:
Expand Down Expand Up @@ -106,11 +108,11 @@ def __init__(
self.num_kv_heads = num_kv_heads
self.backend = backend_name_to_enum(attn_backend.get_name())

# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# For cuda and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.is_cuda_alike(
self.use_direct_call = not current_platform.is_cuda(
) and not current_platform.is_cpu()

# For some attention backends, we allocate an output tensor before
Expand All @@ -124,6 +126,7 @@ def __init__(
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix

self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

Expand All @@ -135,12 +138,11 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
if self.calculate_kv_scales and \
attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)

self.calc_kv_scales(query, key, value)
if self.use_direct_call:
return self.impl.forward(query,
key,
Expand All @@ -150,7 +152,7 @@ def forward(
self._k_scale,
self._v_scale,
attn_type=attn_type,
fp8_out_scale=fp8_out_scale)
fp8_comp_scales=fp8_comp_scales)
elif self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
Expand All @@ -172,7 +174,8 @@ def forward(
kv_cache, attn_type,
self.layer_name)

def calc_kv_scales(self, key, value):
def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
# We only calculate the scales once
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def attn_fwd(
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None] >=
out_mask_boundary[None, :])
z = 0.0
z = tl.zeros((1, ), tl.float32)
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
Expand Down
29 changes: 21 additions & 8 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
VLLM_USE_ROCM_SKINNY_GEMM: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False
RANK: int = 0
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
Expand Down Expand Up @@ -83,8 +84,9 @@
VLLM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
K_SCALE_CONSTANT: int = 200
V_SCALE_CONSTANT: int = 100
Q_SCALE_CONSTANT: int = 20
K_SCALE_CONSTANT: int = 20
V_SCALE_CONSTANT: int = 10


def get_default_cache_root():
Expand Down Expand Up @@ -242,13 +244,18 @@ def get_default_config_root():
# custom paged attention implemented for MI3* cards
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN":
lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1") != "0"),
("true", "1")),

# have custom paged attention implemented for MI3* cards write out fp8
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT":
lambda:
(os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in
("true", "1") != "0"),
("true", "1")),

# use quantized q,k,v,softmax(qk^T), attn output during prefill
"VLLM_USE_ROCM_FP8_FLASH_ATTN":
lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in
("true", "1")),

# rank of the process in the distributed setting, used to determine
# the driver worker
Expand Down Expand Up @@ -530,13 +537,19 @@ def get_default_config_root():
"VLLM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))),

# Divisor for dynamic key scale factor calculation for FP8 KV Cache
# Divisor for dynamic query scale factor calculation for FP8 attention
"Q_SCALE_CONSTANT":
lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")),

# Divisor for dynamic key scale factor calculation
# for FP8 KV Cache and attention
"K_SCALE_CONSTANT":
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
lambda: int(os.getenv("K_SCALE_CONSTANT", "20")),

# Divisor for dynamic value scale factor calculation for FP8 KV Cache
# Divisor for dynamic value scale factor calculation
# for FP8 KV Cache and attention
"V_SCALE_CONSTANT":
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
lambda: int(os.getenv("V_SCALE_CONSTANT", "10")),

# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None

Expand Down
Loading

0 comments on commit 1dcd9fe

Please sign in to comment.