11# SPDX-License-Identifier: Apache-2.0
22
33import importlib
4- from typing import TYPE_CHECKING , Callable , Dict , Type
4+ from typing import TYPE_CHECKING , Callable , Dict , Optional , Type , Union
5+
6+ import vllm .envs as envs
7+ # NOTE(Kuntai): We prefer not to directly the classes with "_V1" suffix.
8+ # This makes it easier for us to deprecate code in v0 (which will happen soon).
9+ # yapf: disable
10+ from vllm .distributed .kv_transfer .kv_connector .v1 import (KVConnectorBase_V1 ,
11+ KVConnectorRole )
12+ # yapf: enable
13+ from vllm .logger import init_logger
514
615from .base import KVConnectorBase
716
817if TYPE_CHECKING :
918 from vllm .config import VllmConfig
1019
20+ logger = init_logger (__name__ )
21+
1122
1223class KVConnectorFactory :
13- _registry : Dict [str , Callable [[], Type [KVConnectorBase ]]] = {}
24+ _registry : Dict [str , Callable [[], Type [Union [KVConnectorBase ,
25+ KVConnectorBase_V1 ]]]] = {}
1426
1527 @classmethod
1628 def register_connector (cls , name : str , module_path : str ,
@@ -19,21 +31,41 @@ def register_connector(cls, name: str, module_path: str,
1931 if name in cls ._registry :
2032 raise ValueError (f"Connector '{ name } ' is already registered." )
2133
22- def loader () -> Type [KVConnectorBase ]:
34+ def loader () -> Type [Union [ KVConnectorBase , KVConnectorBase_V1 ] ]:
2335 module = importlib .import_module (module_path )
2436 return getattr (module , class_name )
2537
2638 cls ._registry [name ] = loader
2739
2840 @classmethod
29- def create_connector (cls , rank : int , local_rank : int ,
30- config : "VllmConfig" ) -> KVConnectorBase :
41+ def create_connector (
42+ cls , rank : Optional [int ], local_rank : Optional [int ],
43+ config : "VllmConfig" , role : KVConnectorRole
44+ ) -> Union [KVConnectorBase , KVConnectorBase_V1 ]:
3145 connector_name = config .kv_transfer_config .kv_connector
3246 if connector_name not in cls ._registry :
3347 raise ValueError (f"Unsupported connector type: { connector_name } " )
3448
35- connector_cls = cls ._registry [connector_name ]()
36- return connector_cls (rank , local_rank , config )
49+ if envs .VLLM_USE_V1 :
50+ # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
51+ # Scheduler connector:
52+ # - Co-colate with scheduler process
53+ # - Should only be used inside the Scheduler class
54+ # Worker connector:
55+ # - Co-locate with worker process
56+ # - Should only be used inside the forward context & attention layer
57+ # We build these two connectors separately to enforce strict
58+ # separation
59+ connector_cls_v1 = cls ._registry [connector_name ]()
60+ assert issubclass (connector_cls_v1 , KVConnectorBase_V1 )
61+ logger .info ("Creating v1 connector with name: %s" , connector_name )
62+ return connector_cls_v1 (rank , local_rank , config , role )
63+ else :
64+ assert rank is not None
65+ assert local_rank is not None
66+ connector_cls = cls ._registry [connector_name ]()
67+ assert issubclass (connector_cls , KVConnectorBase )
68+ return connector_cls (rank , local_rank , config )
3769
3870
3971# Register various connectors here.
@@ -57,4 +89,9 @@ def create_connector(cls, rank: int, local_rank: int,
5789KVConnectorFactory .register_connector (
5890 "MooncakeStoreConnector" ,
5991 "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector" ,
60- "MooncakeStoreConnector" )
92+ "MooncakeStoreConnector" )
93+
94+ KVConnectorFactory .register_connector (
95+ "SharedStorageConnector" ,
96+ "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector" ,
97+ "SharedStorageConnector" )
0 commit comments