
Commit 6a12481

Fixing DCO issue and format checker issue
Co-authored-by: KuntaiDu <kuntai@uchicago.edu>
Co-authored-by: YaoJiayi <1200040070@link.cuhk.edu.cn>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
1 parent 1cab43c commit 6a12481

19 files changed, +844 -43 lines changed
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Read prompts written by the prefill instance from output.txt
prompts = []
try:
    with open("output.txt") as f:
        for line in f:
            prompts.append(line.strip())
    print(f"Loaded {len(prompts)} prompts from output.txt")
except FileNotFoundError:
    print("Error: output.txt file not found")
    exit(-1)

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

llm = LLM(
    model="meta-llama/llama-3.1-8b-instruct",
    enforce_eager=True,
    gpu_memory_utilization=0.8,
    kv_transfer_config=KVTransferConfig.from_cli(
        '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",'
        '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}'
    ))  # , max_model_len=2048, max_num_batched_tokens=2048

# 2nd generation (decode instance): reuses the KV cache saved by the prefill run
outputs = llm.generate(prompts, sampling_params)

new_prompts = []
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    new_prompts.append(prompt + generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

context = "Hi " * 1000
context2 = "Hey " * 500
prompts = [
    context + "Hello, my name is",
    context + "The capital of France is",
    context2 + "Your name is",
    context2 + "The capital of China is",
]

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

llm = LLM(model="meta-llama/llama-3.1-8b-instruct",
          enforce_eager=True,
          gpu_memory_utilization=0.8,
          kv_transfer_config=KVTransferConfig.from_cli(
              '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", '
              '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}')
          )  # , max_model_len=2048, max_num_batched_tokens=2048

# 1ST generation (prefill instance): max_tokens=1, so only the prefill runs
outputs = llm.generate(
    prompts,
    sampling_params,
)

new_prompts = []
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    new_prompts.append(prompt + generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write new_prompts to output.txt for the decode instance to consume
with open("output.txt", "w") as f:
    for prompt in new_prompts:
        f.write(prompt + "\n")
print(f"Saved {len(new_prompts)} prompts to output.txt")
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
find /tmp -iname "*attn.pt" 2>/dev/null | cut -d'/' -f1,2,3 | uniq | xargs rm -r

VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=1 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=1 python3 decode_example.py
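The same sequence can also be driven from a single Python launcher instead of the shell script above. A minimal sketch, assuming prefill_example.py and decode_example.py sit in the current directory; the subprocess wrapper is illustrative and not part of this commit:

import os
import subprocess

# Match the environment used by the shell script: keep the v1 engine in a
# single process and pin both runs to the same GPU.
env = dict(os.environ,
           VLLM_ENABLE_V1_MULTIPROCESSING="0",
           CUDA_VISIBLE_DEVICES="1")

# Prefill first (fills local_storage and writes output.txt), then decode
# (reads output.txt and reuses the stored KV cache).
for script in ("prefill_example.py", "decode_example.py"):
    subprocess.run(["python3", script], env=env, check=True)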

requirements/test.txt

Lines changed: 19 additions & 3 deletions
@@ -23,6 +23,10 @@ anyio==4.6.2.post1
     # via httpx
 argcomplete==3.5.1
     # via datamodel-code-generator
+async-timeout==5.0.1
+    # via
+    #   aiohttp
+    #   redis
 attrs==24.2.0
     # via
     #   aiohttp
@@ -117,6 +121,10 @@ encodec==0.1.1
     # via vocos
 evaluate==0.4.3
     # via lm-eval
+exceptiongroup==1.2.2
+    # via
+    #   anyio
+    #   pytest
 fastparquet==2024.11.0
     # via genai-perf
 fastrlock==0.8.2
@@ -556,9 +564,7 @@ sentence-transformers==3.2.1
 sentencepiece==0.2.0
     # via mistral-common
 setuptools==75.8.0
-    # via
-    #   pytablewriter
-    #   torch
+    # via pytablewriter
 shellingham==1.5.4
     # via typer
 six==1.16.0
@@ -605,6 +611,12 @@ timm==1.0.11
     # via -r requirements/test.in
 tokenizers==0.21.0
     # via transformers
+toml==0.10.2
+    # via datamodel-code-generator
+tomli==2.2.1
+    # via
+    #   black
+    #   pytest
 torch==2.6.0
     # via
     #   -r requirements/test.in
@@ -670,12 +682,16 @@ typer==0.15.2
     # via fastsafetensors
 typing-extensions==4.12.2
     # via
+    #   anyio
+    #   black
     #   huggingface-hub
     #   librosa
     #   mistral-common
+    #   multidict
     #   pqdm
     #   pydantic
     #   pydantic-core
+    #   rich
     #   torch
     #   typer
 tzdata==2024.2

vllm/attention/layer.py

Lines changed: 28 additions & 3 deletions
@@ -10,6 +10,7 @@
 from vllm.attention import AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.distributed import get_kv_transfer_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
@@ -179,6 +180,7 @@ def forward(
         context using
         `vllm.forward_context.get_forward_context().attn_metadata`.
         """
+        get_kv_transfer_group().wait_for_layer_load(self.layer_name)
         if self.calculate_kv_scales:
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:
@@ -214,20 +216,26 @@ def forward(
                                   self_kv_cache,
                                   attn_metadata,
                                   output=output)
+                save_kv_layer_to_connector(self.layer_name, self.kv_cache)
             else:
                 torch.ops.vllm.unified_attention_with_output(
                     query, key, value, output, self.layer_name)
+                save_kv_layer_to_connector(self.layer_name, self.kv_cache)
             return output.view(-1, hidden_size)
         else:
             if self.use_direct_call:
                 forward_context = get_forward_context()
                 attn_metadata = forward_context.attn_metadata
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-                return self.impl.forward(self, query, key, value,
-                                         self_kv_cache, attn_metadata)
+                output = self.impl.forward(self, query, key, value,
+                                           self_kv_cache, attn_metadata)
+                save_kv_layer_to_connector(self.layer_name, self.kv_cache)
+                return output
             else:
-                return torch.ops.vllm.unified_attention(
+                output = torch.ops.vllm.unified_attention(
                     query, key, value, self.layer_name)
+                save_kv_layer_to_connector(self.layer_name, self.kv_cache)
+                return output
 
     def calc_kv_scales(self, query, key, value):
         self._q_scale.copy_(torch.abs(query).max() / self.q_range)
@@ -329,6 +337,23 @@ def forward(
         return out.reshape(bsz, q_len, -1)
 
 
+def save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache: List[torch.Tensor],
+):
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+
+    connector = get_kv_transfer_group()
+    if connector is None:
+        return
+
+    kv_cache_layer = kv_cache[forward_context.virtual_engine]
+    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
+
+
 def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
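Taken together, the layer.py changes give every attention layer the same bracketing pattern: block on the connector before attention, run attention, then hand the layer's KV cache back to the connector. A condensed sketch of that per-layer flow; the function and the attn_kernel callable are illustrative stand-ins, and only wait_for_layer_load and save_kv_layer come from the diff above:

from typing import Any, Callable

import torch


def attention_forward_with_kv_connector(
        layer_name: str,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: Any,
        connector: Any,
        attn_kernel: Callable[..., torch.Tensor],
) -> torch.Tensor:
    # 1. Block until any externally stored KV for this layer has been loaded
    #    into kv_cache (mirrors get_kv_transfer_group().wait_for_layer_load).
    connector.wait_for_layer_load(layer_name)
    # 2. Run the real attention computation, which also writes this step's new
    #    key/value entries into kv_cache.
    output = attn_kernel(query, key, value, kv_cache, attn_metadata)
    # 3. Hand the updated layer cache to the connector so it can be persisted
    #    layer by layer (mirrors save_kv_layer_to_connector).
    connector.save_kv_layer(layer_name, kv_cache, attn_metadata)
    return output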

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 45 additions & 8 deletions
@@ -1,16 +1,28 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import importlib
-from typing import TYPE_CHECKING, Callable, Dict, Type
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
+
+import vllm.envs as envs
+# NOTE(Kuntai): We prefer not to directly the classes with "_V1" suffix.
+# This makes it easier for us to deprecate code in v0 (which will happen soon).
+# yapf: disable
+from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
+                                                          KVConnectorRole)
+# yapf: enable
+from vllm.logger import init_logger
 
 from .base import KVConnectorBase
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 
+logger = init_logger(__name__)
+
 
 class KVConnectorFactory:
-    _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
+    _registry: Dict[str, Callable[[], Type[Union[KVConnectorBase,
+                                                 KVConnectorBase_V1]]]] = {}
 
     @classmethod
     def register_connector(cls, name: str, module_path: str,
@@ -19,21 +31,41 @@ def register_connector(cls, name: str, module_path: str,
         if name in cls._registry:
             raise ValueError(f"Connector '{name}' is already registered.")
 
-        def loader() -> Type[KVConnectorBase]:
+        def loader() -> Type[Union[KVConnectorBase, KVConnectorBase_V1]]:
             module = importlib.import_module(module_path)
             return getattr(module, class_name)
 
         cls._registry[name] = loader
 
     @classmethod
-    def create_connector(cls, rank: int, local_rank: int,
-                         config: "VllmConfig") -> KVConnectorBase:
+    def create_connector(
+        cls, rank: Optional[int], local_rank: Optional[int],
+        config: "VllmConfig", role: KVConnectorRole
+    ) -> Union[KVConnectorBase, KVConnectorBase_V1]:
         connector_name = config.kv_transfer_config.kv_connector
         if connector_name not in cls._registry:
             raise ValueError(f"Unsupported connector type: {connector_name}")
 
-        connector_cls = cls._registry[connector_name]()
-        return connector_cls(rank, local_rank, config)
+        if envs.VLLM_USE_V1:
+            # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
+            # Scheduler connector:
+            # - Co-locate with scheduler process
+            # - Should only be used inside the Scheduler class
+            # Worker connector:
+            # - Co-locate with worker process
+            # - Should only be used inside the forward context & attention layer
+            # We build these two connectors separately to enforce strict
+            # separation
+            connector_cls_v1 = cls._registry[connector_name]()
+            assert issubclass(connector_cls_v1, KVConnectorBase_V1)
+            logger.info("Creating v1 connector with name: %s", connector_name)
+            return connector_cls_v1(rank, local_rank, config, role)
+        else:
+            assert rank is not None
+            assert local_rank is not None
+            connector_cls = cls._registry[connector_name]()
+            assert issubclass(connector_cls, KVConnectorBase)
+            return connector_cls(rank, local_rank, config)
 
 
 # Register various connectors here.
@@ -57,4 +89,9 @@ def create_connector(cls, rank: int, local_rank: int,
 KVConnectorFactory.register_connector(
     "MooncakeStoreConnector",
     "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
-    "MooncakeStoreConnector")
+    "MooncakeStoreConnector")
+
+KVConnectorFactory.register_connector(
+    "SharedStorageConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
+    "SharedStorageConnector")
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# yapf: disable
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorRole)

# yapf: enable

__all__ = [
    "KVConnectorRole",
    "KVConnectorBase_V1",
]
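The new v1 package only exports the role enum and the abstract base class. Below is a skeleton of what a subclass might look like, limited to the two worker-side hooks exercised by the attention-layer changes in this commit; the constructor arguments mirror the factory call connector_cls_v1(rank, local_rank, config, role), the base-class constructor signature is assumed, and any further abstract methods required by KVConnectorBase_V1 are not shown here and would also need to be implemented:

import torch

from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                           KVConnectorRole)


class NoOpConnector(KVConnectorBase_V1):
    """Hypothetical do-nothing connector used only to illustrate the hooks."""

    def __init__(self, rank, local_rank, config, role: KVConnectorRole):
        # Mirrors connector_cls_v1(rank, local_rank, config, role) from the
        # factory; the base-class constructor signature is an assumption.
        super().__init__(rank, local_rank, config, role)

    def wait_for_layer_load(self, layer_name: str) -> None:
        # Called at the top of each attention layer's forward pass; a real
        # connector would block here until that layer's KV is in place.
        pass

    def save_kv_layer(self, layer_name: str, kv_cache_layer: torch.Tensor,
                      attn_metadata) -> None:
        # Called after each attention layer; a real connector would copy
        # kv_cache_layer out to shared storage keyed by layer_name.
        pass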
