39 changes: 33 additions & 6 deletions sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -109,13 +109,20 @@ def getAsyncVLLMClient(port) -> AsyncOpenAI:
 
 
 class _VLLMModelServer():
-  def __init__(self, model_name: str, vllm_server_kwargs: dict[str, str]):
+  def __init__(
+      self,
+      model_name: str,
+      vllm_server_kwargs: dict[str, str],
+      vllm_executable: Optional[str] = None):
     self._model_name = model_name
     self._vllm_server_kwargs = vllm_server_kwargs
     self._server_started = False
     self._server_process = None
     self._server_port: int = -1
     self._server_process_lock = threading.RLock()
+    self._vllm_executable = 'vllm.entrypoints.openai.api_server'
+    if vllm_executable is not None:
+      self._vllm_executable = vllm_executable
 
     self.start_server()
 
@@ -125,7 +132,7 @@ def start_server(self, retries=3):
       server_cmd = [
           sys.executable,
           '-m',
-          'vllm.entrypoints.openai.api_server',
+          self._vllm_executable,
           '--model',
           self._model_name,
           '--port',
@@ -175,7 +182,8 @@ class VLLMCompletionsModelHandler(ModelHandler[str,
   def __init__(
       self,
       model_name: str,
-      vllm_server_kwargs: Optional[dict[str, str]] = None):
+      vllm_server_kwargs: Optional[dict[str, str]] = None,
+      use_dynamo: bool = False):
     """Implementation of the ModelHandler interface for vLLM using text as
     input.
 
@@ -194,13 +202,22 @@ def __init__(
         `{'echo': 'true'}` to prepend new messages with the previous message.
         For a list of possible kwargs, see
         https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-completions-api
+      use_dynamo: Whether to use NVIDIA Dynamo as the underlying vLLM engine.
+        Requires installing dynamo in your runtime environment
+        (`pip install ai-dynamo[vllm]`).
     """
     self._model_name = model_name
     self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {}
     self._env_vars = {}
+    self._vllm_executable = None
+    if use_dynamo:
+      self._vllm_executable = 'dynamo.vllm'
 
   def load_model(self) -> _VLLMModelServer:
-    return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
+    return _VLLMModelServer(
+        self._model_name,
+        self._vllm_server_kwargs,
+        self._vllm_executable)
 
   async def _async_run_inference(
       self,
@@ -253,7 +270,8 @@ def __init__(
       self,
       model_name: str,
       chat_template_path: Optional[str] = None,
-      vllm_server_kwargs: Optional[dict[str, str]] = None):
+      vllm_server_kwargs: Optional[dict[str, str]] = None,
+      use_dynamo: bool = False):
     """ Implementation of the ModelHandler interface for vLLM using previous
     messages as input.
 
@@ -277,12 +295,18 @@
         `{'echo': 'true'}` to prepend new messages with the previous message.
         For a list of possible kwargs, see
         https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api
+      use_dynamo: Whether to use NVIDIA Dynamo as the underlying vLLM engine.
+        Requires installing dynamo in your runtime environment
+        (`pip install ai-dynamo[vllm]`).
     """
     self._model_name = model_name
     self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {}
     self._env_vars = {}
     self._chat_template_path = chat_template_path
     self._chat_file = f'template-{uuid.uuid4().hex}.jinja'
+    self._vllm_executable = None
+    if use_dynamo:
+      self._vllm_executable = 'dynamo.vllm'
 
   def load_model(self) -> _VLLMModelServer:
     chat_template_contents = ''
@@ -295,7 +319,10 @@ def load_model(self) -> _VLLMModelServer:
         f.write(chat_template_contents)
       self._vllm_server_kwargs['chat_template'] = local_chat_template_path
 
-    return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
+    return _VLLMModelServer(
+        self._model_name,
+        self._vllm_server_kwargs,
+        self._vllm_executable)
 
   async def _async_run_inference(
       self,
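For reviewers: the core of the Python change is that `_VLLMModelServer` now parameterizes the module passed to `python -m`, defaulting to the stock vLLM OpenAI-compatible server. A minimal sketch of the resulting launch command, using a hypothetical helper (not code from this PR) that mirrors the default-then-override logic; only `--model` and `--port` are shown, since the remaining flags come from `vllm_server_kwargs`:

import sys
from typing import Optional


def build_server_cmd(
    model_name: str, port: int,
    vllm_executable: Optional[str] = None) -> list[str]:
  # Same default-then-override pattern as _VLLMModelServer.__init__.
  executable = vllm_executable or 'vllm.entrypoints.openai.api_server'
  return [
      sys.executable, '-m', executable,
      '--model', model_name,
      '--port', str(port),
  ]


# Default vLLM OpenAI-compatible server:
print(build_server_cmd('facebook/opt-125m', 8000))
# With use_dynamo=True, the handlers pass 'dynamo.vllm' instead:
print(build_server_cmd('facebook/opt-125m', 8000, vllm_executable='dynamo.vllm'))

This only works if the substituted module accepts the same `--model`/`--port` flags and exposes an OpenAI-compatible endpoint on that port, which is the assumption the unchanged `getVLLMClient`/`getAsyncVLLMClient` plumbing relies on.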
@@ -20,3 +20,4 @@ pillow>=8.0.0
 transformers>=4.18.0
 google-cloud-monitoring>=2.27.0
 openai>=1.52.2
+ai-dynamo[vllm]>=0.1.1
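End to end, the new behavior is opt-in from either model handler. A minimal sketch of a pipeline exercising the flag; the model name is illustrative, and the worker environment is assumed to have a GPU plus `vllm` and `ai-dynamo[vllm]` installed:

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler

# use_dynamo=True swaps the served module from
# vllm.entrypoints.openai.api_server to dynamo.vllm.
model_handler = VLLMCompletionsModelHandler(
    model_name='facebook/opt-125m',  # illustrative model choice
    use_dynamo=True)

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create(['What is Apache Beam?'])
      | RunInference(model_handler)
      | beam.Map(print))  # each element is a PredictionResult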