Commit dcd4721

Merge branch 'vllm-project:main' into main
2 parents: 790ad45 + eb6d3c2

Note: large commits have some content hidden by default, so only a subset of the changed files is shown below.

54 files changed: +666 −372 lines

benchmarks/benchmark_latency.py

Lines changed: 6 additions & 8 deletions
@@ -153,15 +153,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
         action='store_true',
         help='enforce eager mode and disable CUDA graph')
     parser.add_argument(
-        "--kv-cache-dtype",
+        '--kv-cache-dtype',
         type=str,
-        choices=['auto', 'fp8'],
-        default='auto',
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
-        'instead supported for common inference criteria.')
+        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
+        default="auto",
+        help='Data type for kv cache storage. If "auto", will use model '
+        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
     parser.add_argument(
         '--quantization-param-path',
         type=str,
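
Note (not part of the diff): the same expanded set of choices is plumbed through the Python API as well — the updated tests/models/test_fp8.py below passes kv_cache_dtype straight to LLM(). A minimal, hedged usage sketch (the model name is illustrative):

    from vllm import LLM, SamplingParams

    # kv_cache_dtype accepts "auto", "fp8", "fp8_e5m2", or "fp8_e4m3"
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
              kv_cache_dtype="fp8",
              enforce_eager=True)
    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(out[0].outputs[0].text)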

benchmarks/benchmark_throughput.py

Lines changed: 5 additions & 7 deletions
@@ -323,15 +323,13 @@ def main(args: argparse.Namespace):
         action="store_true",
         help="enforce eager execution")
     parser.add_argument(
-        "--kv-cache-dtype",
+        '--kv-cache-dtype',
         type=str,
-        choices=["auto", "fp8"],
+        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
         default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
-        'common inference criteria.')
+        help='Data type for kv cache storage. If "auto", will use model '
+        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
     parser.add_argument(
         '--quantization-param-path',
         type=str,

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 4 additions & 6 deletions
@@ -183,13 +183,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     parser.add_argument(
         "--kv-cache-dtype",
         type=str,
-        choices=["auto", "fp8"],
+        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
         default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
-        'common inference criteria.')
+        help="Data type for kv cache storage. If 'auto', will use model "
+        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
+        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
     args = parser.parse_args()
     print(args)

csrc/punica/bgmv/bgmv_config.h

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 2752) \
   f(in_T, out_T, W_T, narrow, 2816) \
   f(in_T, out_T, W_T, narrow, 3072) \
+  f(in_T, out_T, W_T, narrow, 3328) \
   f(in_T, out_T, W_T, narrow, 3456) \
   f(in_T, out_T, W_T, narrow, 3584) \
   f(in_T, out_T, W_T, narrow, 4096) \
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 5504) \
   f(in_T, out_T, W_T, narrow, 5632) \
   f(in_T, out_T, W_T, narrow, 6144) \
+  f(in_T, out_T, W_T, narrow, 6400) \
   f(in_T, out_T, W_T, narrow, 6848) \
   f(in_T, out_T, W_T, narrow, 6912) \
   f(in_T, out_T, W_T, narrow, 7168) \
@@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, 2752, narrow) \
   f(in_T, out_T, W_T, 2816, narrow) \
   f(in_T, out_T, W_T, 3072, narrow) \
+  f(in_T, out_T, W_T, 3328, narrow) \
   f(in_T, out_T, W_T, 3456, narrow) \
   f(in_T, out_T, W_T, 3584, narrow) \
   f(in_T, out_T, W_T, 4096, narrow) \
@@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, 5504, narrow) \
   f(in_T, out_T, W_T, 5632, narrow) \
   f(in_T, out_T, W_T, 6144, narrow) \
+  f(in_T, out_T, W_T, 6400, narrow) \
   f(in_T, out_T, W_T, 6848, narrow) \
   f(in_T, out_T, W_T, 6912, narrow) \
   f(in_T, out_T, W_T, 7168, narrow) \

tests/lora/test_punica.py

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,7 @@ def _lora_ref_impl(
     2560,
     2752,
     3072,
+    3328,
     3456,
     3584,
     4096,
@@ -66,6 +67,7 @@ def _lora_ref_impl(
     5504,
     5632,
     6144,
+    6400,
     6848,
     6912,
     7168,

tests/models/test_fp8.py

Lines changed: 52 additions & 28 deletions
@@ -16,31 +16,55 @@
 MAX_MODEL_LEN = 1024

 MODELS = [
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8",
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
     "meta-llama/Meta-Llama-3-8B-Instruct",
 ]

 EXPECTED_STRS_MAP = {
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
-        'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
-        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-        'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-        'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
-        'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
-        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-        'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
-    ],
-    "meta-llama/Meta-Llama-3-8B-Instruct": [
-        'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-        'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-        'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
-        'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
-        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-        'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
-    ],
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
+        "auto": [
+            'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
+            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
+            'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
+        ],
+        "fp8": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system made up of several basic components that work together to enable it to',
+            'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
+        ]
+    },
+    "meta-llama/Meta-Llama-3-8B-Instruct": {
+        "auto": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
+            'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
+        ],
+        "fp8": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
+            'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
+        ]
+    },
 }

 capability = torch.cuda.get_device_capability()
@@ -52,14 +76,14 @@
 @pytest.mark.skipif(fp8_not_supported,
                     reason="fp8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_name", MODELS)
-def test_models(
-    example_prompts,
-    model_name,
-) -> None:
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
+def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
     model = LLM(model=model_name,
                 max_model_len=MAX_MODEL_LEN,
+                trust_remote_code=True,
                 enforce_eager=True,
-                quantization="fp8")
+                quantization="fp8",
+                kv_cache_dtype=kv_cache_dtype)

     tokenizer = AutoTokenizer.from_pretrained(model_name)
     formatted_prompts = [
@@ -81,8 +105,8 @@ def test_models(
         generations.append(outputs[0].outputs[0].text)
     del model

-    print(generations)
-    expected_strs = EXPECTED_STRS_MAP[model_name]
+    print(model_name, kv_cache_dtype, generations)
+    expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
     for i in range(len(example_prompts)):
         generated_str = generations[i]
         expected_str = expected_strs[i]

vllm/attention/layer.py

Lines changed: 25 additions & 2 deletions
@@ -7,6 +7,8 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)


 class Attention(nn.Module):
@@ -30,6 +32,7 @@ def __init__(
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         if cache_config is not None:
@@ -40,6 +43,27 @@ def __init__(
             block_size = 16
         if num_kv_heads is None:
             num_kv_heads = num_heads
+
+        # The default kv_scale is set to 1.0. This is ignored
+        # when kv-cache is not fp8, and should be used with
+        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+        # expect the pre-quantized kv_scale to be loaded along
+        # with the model weights.
+        self.kv_cache_dtype = kv_cache_dtype
+        self._kv_scale = 1.0
+        quant_method = quant_config.get_quant_method(
+            self) if quant_config else None
+        if quant_method is not None:
+            if self.kv_cache_dtype == "fp8_e5m2":
+                raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                                 "fp8 checkpoints.")
+            # When FP8 quantization is enabled, we make a parameter
+            # "kv_scale" so that it can be loaded from FP8 checkpoint.
+            # The kv_scale will then be converted back
+            # to self._kv_scale in a native float32 value after weight loading.
+            self.quant_method = quant_method
+            self.quant_method.create_weights(self)
+
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
@@ -57,10 +81,9 @@ def forward(
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        kv_scale: float = 1.0,
     ) -> torch.Tensor:
         return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                                 kv_scale)
+                                 self._kv_scale)

     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore

vllm/config.py

Lines changed: 3 additions & 5 deletions
@@ -355,14 +355,12 @@ def _verify_args(self) -> None:
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
-        elif self.cache_dtype == "fp8":
+        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
-                "But it may cause slight accuracy drop without scaling "
-                "factors. FP8_E5M2 (without scaling) is only supported on "
-                "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
-                "is instead supported for common inference criteria.")
+                "Meanwhile, it may cause accuracy drop without a proper "
+                "scaling factor")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

vllm/engine/arg_utils.py

Lines changed: 3 additions & 4 deletions
@@ -191,12 +191,11 @@ def add_cli_args(
         parser.add_argument(
             '--kv-cache-dtype',
             type=str,
-            choices=['auto', 'fp8'],
+            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
-            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
-            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
-            'supported for common inference criteria.')
+            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
         parser.add_argument(
             '--quantization-param-path',
             type=nullable_str,
type=nullable_str,

vllm/engine/async_llm_engine.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,14 @@ async def step_async(
234234
# Log stats.
235235
self.do_log_stats(scheduler_outputs, output)
236236

237+
if not request_outputs:
238+
# Stop the execute model loop in parallel workers until there are
239+
# more requests to process. This avoids waiting indefinitely in
240+
# torch.distributed ops which may otherwise timeout, and unblocks
241+
# the RPC thread in the workers so that they can process any other
242+
# queued control plane messages, such as add/remove lora adapters.
243+
await self.model_executor.stop_remote_worker_execution_loop_async()
244+
237245
return request_outputs
238246

239247
async def encode_request_async(
@@ -687,7 +695,7 @@ async def encode(
687695
multi_modal_data: Multi modal data per request.
688696
689697
Yields:
690-
The output `EmbeddingRequestOutput` objects from the LLMEngine
698+
The output `EmbeddingRequestOutput` objects from the LLMEngine
691699
for the request.
692700
693701
Details:

vllm/engine/llm_engine.py

Lines changed: 8 additions & 0 deletions
@@ -692,6 +692,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)

+        if not request_outputs:
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise timeout, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as add/remove lora adapters.
+            self.model_executor.stop_remote_worker_execution_loop()
+
         return request_outputs

     def do_log_stats(
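
For intuition, a hedged sketch of the worker-side loop that step() now stops when idle (names are illustrative, not vLLM's actual worker API): non-driver workers block on collective execute-model broadcasts from the driver until a stop signal arrives, after which their RPC threads are free to handle control-plane messages such as adding or removing LoRA adapters.

    def remote_worker_execution_loop(worker):
        while True:
            work = worker.receive_model_input()   # blocking collective with the driver
            if work is None:                      # driver signalled "stop the loop"
                return                            # worker goes idle, RPC thread unblocked
            worker.execute_model(work)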
