
Commit 70142eb

Merge branch 'main' into library_versions_bump

2 parents: dfb1df5 + 1dcd9fe

20 files changed: +157 additions, -81 deletions


vllm/attention/backends/abstract.py

Lines changed: 1 addition & 1 deletion

@@ -252,6 +252,6 @@ def forward(
         v_scale: torch.Tensor,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
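
The abstract interface now takes a single fp8_comp_scales tuple instead of the lone fp8_out_scale tensor. Below is a minimal sketch of how a caller might bundle the three scales, assuming the (q_scale, prob_scale, fp8_out_scale) ordering that the ROCm backend unpacks further down; the helper name pack_fp8_comp_scales is illustrative, not part of vLLM.

    from typing import Optional, Tuple

    import torch


    def pack_fp8_comp_scales(
        q_scale: Optional[torch.Tensor],
        prob_scale: Optional[torch.Tensor],
        fp8_out_scale: Optional[torch.Tensor],
    ) -> Optional[Tuple[torch.Tensor, ...]]:
        """Bundle the FP8 computation scales in the order the backend expects."""
        if q_scale is None and prob_scale is None and fp8_out_scale is None:
            return None
        return (q_scale, prob_scale, fp8_out_scale)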

vllm/attention/backends/blocksparse_attn.py

Lines changed: 1 addition & 1 deletion

@@ -363,7 +363,7 @@ def forward(
         v_scale: torch.Tensor,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

vllm/attention/backends/flash_attn.py

Lines changed: 1 addition & 1 deletion

@@ -642,7 +642,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

vllm/attention/backends/flashinfer.py

Lines changed: 1 addition & 1 deletion

@@ -777,7 +777,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:

         # TODO: directly write to output tensor

vllm/attention/backends/hpu_attn.py

Lines changed: 1 addition & 1 deletion

@@ -154,7 +154,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

vllm/attention/backends/ipex_attn.py

Lines changed: 1 addition & 1 deletion

@@ -174,7 +174,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.

vllm/attention/backends/pallas.py

Lines changed: 1 addition & 1 deletion

@@ -152,7 +152,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 10 additions & 2 deletions

@@ -551,7 +551,7 @@ def forward(
         v_scale: torch.Tensor,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: torch.Tensor = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

@@ -601,6 +601,8 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None,
+                                                                 None)

         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:

@@ -681,6 +683,12 @@ def forward(
                         query.dtype,
                         seq_lens,
                         make_attn_mask=False)  # type: ignore
+                full_scales = (
+                    1.0 / q_scale.item(), 1.0 / k_scale.item(),
+                    1.0 / v_scale.item(), 1.0 / prob_scale.item(),
+                    fp8_out_scale.item()) if (
+                        fp8_out_scale
+                        and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None
                 out, _ = self.attn_func(
                     query,
                     key,

@@ -694,7 +702,7 @@ def forward(
                     self.scale,
                     attn_masks[0][None]
                     if attn_masks is not None else None,
-                    None,
+                    full_scales,
                 )
             elif self.use_naive_attn:
                 if self.num_kv_heads != self.num_heads:
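
For clarity, here is a self-contained sketch, not the vLLM implementation itself, of how the ROCm backend turns the new tuple into the reciprocal-scale tuple handed to the Triton kernel; the use_fp8_flash_attn flag stands in for the envs.VLLM_USE_ROCM_FP8_FLASH_ATTN check, and the helper name is illustrative.

    from typing import Optional, Tuple

    import torch


    def build_full_scales(
        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        use_fp8_flash_attn: bool,
    ) -> Optional[Tuple[float, ...]]:
        # Unpack with a (None, None, None) fallback, as in the diff above.
        q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None, None)
        if fp8_out_scale is None or not use_fp8_flash_attn:
            return None
        # The kernel consumes descale factors (1/scale) for q, k, v and the
        # softmax probabilities, plus the raw output scale.
        return (1.0 / q_scale.item(), 1.0 / k_scale.item(),
                1.0 / v_scale.item(), 1.0 / prob_scale.item(),
                fp8_out_scale.item())

The resulting tuple is what replaces the None that was previously hard-coded as the last argument of attn_func.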

vllm/attention/backends/torch_sdpa.py

Lines changed: 1 addition & 1 deletion

@@ -434,7 +434,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.

vllm/attention/backends/xformers.py

Lines changed: 1 addition & 1 deletion

@@ -420,7 +420,7 @@ def forward(
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

vllm/attention/layer.py

Lines changed: 11 additions & 8 deletions

@@ -1,5 +1,5 @@
 """Attention layer."""
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn

@@ -75,6 +75,8 @@ def __init__(
         self.calculate_kv_scales = calculate_kv_scales
         self._k_scale = torch.tensor(1.0, dtype=torch.float32)
         self._v_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
         if quant_method is not None:

@@ -106,11 +108,11 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.backend = backend_name_to_enum(attn_backend.get_name())

-        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
+        # For cuda and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
         # opaque custom op. For other platforms, we directly call them
         # and let torch.compile handle them.
-        self.use_direct_call = not current_platform.is_cuda_alike(
+        self.use_direct_call = not current_platform.is_cuda(
         ) and not current_platform.is_cpu()

         # For some attention backends, we allocate an output tensor before

@@ -124,6 +126,7 @@ def __init__(
         compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix

+        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
         self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
         self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

@@ -135,12 +138,11 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         attn_type: str = AttentionType.DECODER,
-        fp8_out_scale: Optional[torch.Tensor] = None,
+        fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         if self.calculate_kv_scales and \
            attn_metadata.enable_kv_scales_calculation:
-            self.calc_kv_scales(key, value)
-
+            self.calc_kv_scales(query, key, value)
         if self.use_direct_call:
             return self.impl.forward(query,
                                      key,

@@ -150,7 +152,7 @@ def forward(
                                      self._k_scale,
                                      self._v_scale,
                                      attn_type=attn_type,
-                                     fp8_out_scale=fp8_out_scale)
+                                     fp8_comp_scales=fp8_comp_scales)
         elif self.use_output:
             output = torch.empty_like(query)
             hidden_size = query.size(-1)

@@ -172,7 +174,8 @@ def forward(
                 kv_cache, attn_type,
                 self.layer_name)

-    def calc_kv_scales(self, key, value):
+    def calc_kv_scales(self, query, key, value):
+        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
         self._k_scale.copy_(torch.abs(key).max() / self.k_range)
         self._v_scale.copy_(torch.abs(value).max() / self.v_range)
         # We only calculate the scales once
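
The scale calculation itself is a simple absolute maximum divided by a range constant. Here is a standalone sketch, assuming the new default divisors (Q=20, K=20, V=10) and detached from the Attention module's buffers; the function name is illustrative.

    import torch


    def dynamic_scales(query: torch.Tensor,
                       key: torch.Tensor,
                       value: torch.Tensor,
                       q_range: float = 20.0,
                       k_range: float = 20.0,
                       v_range: float = 10.0):
        # Mirrors the copy_ updates in Attention.calc_kv_scales: each scale is
        # the activation's absolute maximum divided by its range constant.
        q_scale = torch.abs(query).max() / q_range
        k_scale = torch.abs(key).max() / k_range
        v_scale = torch.abs(value).max() / v_range
        return q_scale, k_scale, v_scale


    # Usage on dummy activations:
    q, k, v = (torch.randn(16, 128) for _ in range(3))
    print(dynamic_scales(q, k, v))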

vllm/attention/ops/triton_flash_attention.py

Lines changed: 1 addition & 1 deletion

@@ -742,7 +742,7 @@ def attn_fwd(
         mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
         out_ptrs_mask = (mask_m_offsets[:, None] >=
                          out_mask_boundary[None, :])
-        z = 0.0
+        z = tl.zeros((1, ), tl.float32)
         acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
         # write back LSE
         # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
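
This change replaces a bare Python float with an explicitly typed tl.zeros((1, ), tl.float32) so the fill value can be cast to the accumulator's element type. A rough PyTorch analogue of the masking pattern, for illustration only (the tensors are stand-ins, not the kernel's buffers):

    import torch

    acc = torch.randn(4, 4, dtype=torch.float32)   # stand-in for the accumulator
    out_ptrs_mask = torch.rand(4, 4) > 0.5         # stand-in for the bounds mask
    z = torch.zeros(1, dtype=torch.float32)        # typed zero, like tl.zeros
    acc = torch.where(out_ptrs_mask, acc, z.to(acc.dtype))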

vllm/envs.py

Lines changed: 21 additions & 8 deletions

@@ -17,6 +17,7 @@
     VLLM_USE_ROCM_SKINNY_GEMM: bool = True
     VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
     VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
+    VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False
     RANK: int = 0
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None

@@ -83,8 +84,9 @@
     VLLM_FP8_PADDING: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
-    K_SCALE_CONSTANT: int = 200
-    V_SCALE_CONSTANT: int = 100
+    Q_SCALE_CONSTANT: int = 20
+    K_SCALE_CONSTANT: int = 20
+    V_SCALE_CONSTANT: int = 10


 def get_default_cache_root():

@@ -242,13 +244,18 @@ def get_default_config_root():
     # custom paged attention implemented for MI3* cards
     "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN":
     lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
-             ("true", "1") != "0"),
+             ("true", "1")),

     # have custom paged attention implemented for MI3* cards write out fp8
     "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT":
     lambda:
     (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in
-     ("true", "1") != "0"),
+     ("true", "1")),
+
+    # use quantized q,k,v,softmax(qk^T), attn output during prefill
+    "VLLM_USE_ROCM_FP8_FLASH_ATTN":
+    lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in
+             ("true", "1")),

     # rank of the process in the distributed setting, used to determine
     # the driver worker

@@ -530,13 +537,19 @@ def get_default_config_root():
     "VLLM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))),

-    # Divisor for dynamic key scale factor calculation for FP8 KV Cache
+    # Divisor for dynamic query scale factor calculation for FP8 attention
+    "Q_SCALE_CONSTANT":
+    lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")),
+
+    # Divisor for dynamic key scale factor calculation
+    # for FP8 KV Cache and attention
     "K_SCALE_CONSTANT":
-    lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
+    lambda: int(os.getenv("K_SCALE_CONSTANT", "20")),

-    # Divisor for dynamic value scale factor calculation for FP8 KV Cache
+    # Divisor for dynamic value scale factor calculation
+    # for FP8 KV Cache and attention
     "V_SCALE_CONSTANT":
-    lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
+    lambda: int(os.getenv("V_SCALE_CONSTANT", "10")),

     # If set, enable multiprocessing in LLM for the V1 code path.
     "VLLM_ENABLE_V1_MULTIPROCESSING":

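A note on the dropped suffix: in Python, the old expression x in ("true", "1") != "0" is a chained comparison that parses as (x in ("true", "1")) and (("true", "1") != "0"); the second clause is always true, so the old lambdas evaluated to the intended result only because that clause was vacuous. A minimal sketch of the cleaned-up parsing pattern (the env_flag helper is illustrative, not a vLLM API):

    import os


    def env_flag(name: str, default: str) -> bool:
        # Same pattern as the cleaned-up lambdas in vllm/envs.py.
        return os.getenv(name, default).lower() in ("true", "1")


    print(env_flag("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False"))  # False unless set
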
vllm/model_executor/layers/quantization/compressed_tensors/utils.py

Lines changed: 4 additions & 0 deletions

@@ -146,6 +146,10 @@ def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
         return name.replace(".k_proj.output_scale", ".attn.k_scale")
     if name.endswith(".output_scale") and ".v_proj" in name:
         return name.replace(".v_proj.output_scale", ".attn.v_scale")
+    if name.endswith(".output_scale") and ".q_proj" in name:
+        return name.replace(".q_proj.output_scale", ".attn.q_scale")
+    if name.endswith("self_attn.prob_output_scale"):
+        return name.replace(".prob_output_scale", ".attn.prob_scale")
     # If no matches, return None
     return None
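
To make the remapping concrete, here is a self-contained sketch of the same name translation with the two new cases included; the sample checkpoint names are illustrative.

    from typing import Optional


    def remap_scale_name(name: str) -> Optional[str]:
        # Mirrors get_compressed_tensors_cache_scale after this change.
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        return None


    print(remap_scale_name("model.layers.0.self_attn.q_proj.output_scale"))
    # model.layers.0.self_attn.attn.q_scale
    print(remap_scale_name("model.layers.0.self_attn.prob_output_scale"))
    # model.layers.0.self_attn.attn.prob_scale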
