From a5251b87df19c84b7e7d611f09e772b1206417ee Mon Sep 17 00:00:00 2001
From: Danny Mccormick
Date: Tue, 2 Dec 2025 16:23:29 -0500
Subject: [PATCH] [WIP] Add vllm Dynamo support

---
 .../ml/inference/vllm_inference.py            | 38 ++++++++++++++++---
 .../ml/inference/vllm_tests_requirements.txt  |  1 +
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py
index bdbee9e51fd5..6462152629c3 100644
--- a/sdks/python/apache_beam/ml/inference/vllm_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -109,13 +109,20 @@ def getAsyncVLLMClient(port) -> AsyncOpenAI:
 
 
 class _VLLMModelServer():
-  def __init__(self, model_name: str, vllm_server_kwargs: dict[str, str]):
+  def __init__(
+      self,
+      model_name: str,
+      vllm_server_kwargs: dict[str, str],
+      vllm_executable: Optional[str] = None):
     self._model_name = model_name
     self._vllm_server_kwargs = vllm_server_kwargs
     self._server_started = False
     self._server_process = None
     self._server_port: int = -1
     self._server_process_lock = threading.RLock()
+    self._vllm_executable = 'vllm.entrypoints.openai.api_server'
+    if vllm_executable is not None:
+      self._vllm_executable = vllm_executable
 
     self.start_server()
 
@@ -125,7 +132,7 @@ def start_server(self, retries=3):
       server_cmd = [
           sys.executable,
           '-m',
-          'vllm.entrypoints.openai.api_server',
+          self._vllm_executable,
           '--model',
           self._model_name,
           '--port',
@@ -175,7 +182,8 @@ class VLLMCompletionsModelHandler(ModelHandler[str,
   def __init__(
       self,
       model_name: str,
-      vllm_server_kwargs: Optional[dict[str, str]] = None):
+      vllm_server_kwargs: Optional[dict[str, str]] = None,
+      use_dynamo: bool = False):
     """Implementation of the ModelHandler interface for vLLM using text as
     input.
 
@@ -194,13 +202,22 @@ def __init__(
         `{'echo': 'true'}` to prepend new messages with the previous message.
         For a list of possible kwargs, see
         https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-completions-api
+      use_dynamo: Whether to use Nvidia Dynamo as the underlying vLLM engine.
+        Requires installing dynamo in your runtime environment
+        (`pip install ai-dynamo[vllm]`)
     """
     self._model_name = model_name
     self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {}
     self._env_vars = {}
+    self._vllm_executable = None
+    if use_dynamo:
+      self._vllm_executable = 'dynamo.vllm'
 
   def load_model(self) -> _VLLMModelServer:
-    return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
+    return _VLLMModelServer(
+        self._model_name,
+        self._vllm_server_kwargs,
+        self._vllm_executable)
 
   async def _async_run_inference(
       self,
@@ -253,7 +270,8 @@ def __init__(
       self,
       model_name: str,
       chat_template_path: Optional[str] = None,
-      vllm_server_kwargs: Optional[dict[str, str]] = None):
+      vllm_server_kwargs: Optional[dict[str, str]] = None,
+      use_dynamo: bool = False):
     """
     Implementation of the ModelHandler interface for vLLM using previous
     messages as input.
@@ -277,12 +295,17 @@
         `{'echo': 'true'}` to prepend new messages with the previous message.
         For a list of possible kwargs, see
         https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api
+      use_dynamo: Whether to use Nvidia Dynamo as the underlying vLLM engine.
+        Requires installing dynamo in your runtime environment
+        (`pip install ai-dynamo[vllm]`)
     """
     self._model_name = model_name
     self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {}
     self._env_vars = {}
     self._chat_template_path = chat_template_path
     self._chat_file = f'template-{uuid.uuid4().hex}.jinja'
+    self._vllm_executable = (
+        'dynamo.vllm' if use_dynamo else None)
 
   def load_model(self) -> _VLLMModelServer:
     chat_template_contents = ''
@@ -295,7 +318,10 @@ def load_model(self) -> _VLLMModelServer:
           f.write(chat_template_contents)
       self._vllm_server_kwargs['chat_template'] = local_chat_template_path
 
-    return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
+    return _VLLMModelServer(
+        self._model_name,
+        self._vllm_server_kwargs,
+        self._vllm_executable)
 
   async def _async_run_inference(
       self,
diff --git a/sdks/python/apache_beam/ml/inference/vllm_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/vllm_tests_requirements.txt
index 0f8c6a6a673d..cd969734230f 100644
--- a/sdks/python/apache_beam/ml/inference/vllm_tests_requirements.txt
+++ b/sdks/python/apache_beam/ml/inference/vllm_tests_requirements.txt
@@ -20,3 +20,4 @@ pillow>=8.0.0
 transformers>=4.18.0
 google-cloud-monitoring>=2.27.0
 openai>=1.52.2
+ai-dynamo[vllm]>=0.1.1
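
A minimal pipeline sketch of how the new flag would be used, assuming a worker environment that has both vLLM and `ai-dynamo[vllm]` installed; the model name and prompt below are illustrative placeholders:

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler

# use_dynamo=True makes _VLLMModelServer launch `python -m dynamo.vllm`
# instead of `python -m vllm.entrypoints.openai.api_server`.
model_handler = VLLMCompletionsModelHandler(
    model_name='facebook/opt-125m',  # placeholder model
    use_dynamo=True)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['What is Apache Beam?'])  # placeholder prompt
      | RunInference(model_handler))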