Add Sambanova Inference #518

Open

wants to merge 12 commits into main
6 changes: 6 additions & 0 deletions distributions/dependencies.json
@@ -311,5 +311,11 @@
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"sambanova": [
"aiohttp",
"aiosqlite",
"fastapi",
"openai"
]
}
1 change: 1 addition & 0 deletions distributions/ssambanova/build.yaml
16 changes: 16 additions & 0 deletions distributions/ssambanova/compose.yaml
@@ -0,0 +1,16 @@
services:
  llamastack:
    image: llamastack/distribution-sambanova
    network_mode: "host"
    volumes:
      - ~/.llama:/root/.llama
      - ./run.yaml:/root/llamastack-run-sambanova.yaml
    ports:
      - "5000:5000"
    entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-sambanova.yaml"
    deploy:
      restart_policy:
        condition: on-failure
        delay: 3s
        max_attempts: 5
        window: 60s
71 changes: 71 additions & 0 deletions distributions/ssambanova/run.yaml
@@ -0,0 +1,71 @@
version: "2"
image_name: sambanova
docker_image: null
conda_env: sambanova
apis:
- inference
- safety
- agents
- memory
- datasetio
- scoring
- eval
- telemetry
providers:
inference:
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
api_token: ${env.SAMBANOVA_API_KEY}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db
memory:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
datasetio:
- provider_id: localfs
provider_type: inline::localfs
config: {}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
metadata_store: null
models:
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.1-8B-Instruct
shields:
- params: null
shield_id: meta-llama/Llama-Guard-3-8B
provider_id: null
provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
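
For context, the `remote::sambanova` provider above points an OpenAI-compatible client at `https://api.sambanova.ai/v1` with `SAMBANOVA_API_KEY` as the credential. Below is a minimal sketch of the equivalent direct call against that endpoint, illustrative only; the adapter in this PR actually renders the chat messages to a prompt and goes through the legacy completions API (see `sambanova.py` further down).

```python
# Sketch: hit the same OpenAI-compatible endpoint the provider config points at.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://api.sambanova.ai/v1",   # url from the provider config above
    api_key=os.environ["SAMBANOVA_API_KEY"],  # api_token: ${env.SAMBANOVA_API_KEY}
)

response = client.chat.completions.create(
    model="Meta-Llama-3.1-8B-Instruct",  # provider_model_id from the models section
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)
```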
11 changes: 11 additions & 0 deletions llama_stack/providers/registry/inference.py
@@ -150,4 +150,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=[
"openai",
],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambanovaImplConfig",
),
),
]
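
The `AdapterSpec` above registers the provider by dotted path only; the stack resolves `module` and `config_class` at runtime. A rough sketch of that resolution, assuming the package from this PR is installed (this is not the stack's actual resolver code):

```python
# Sketch: resolve the dotted strings from the AdapterSpec above into live objects.
import importlib

module = importlib.import_module("llama_stack.providers.remote.inference.sambanova")

config_class_path = (
    "llama_stack.providers.remote.inference.sambanova.SambanovaImplConfig"
)
module_name, _, class_name = config_class_path.rpartition(".")
config_class = getattr(importlib.import_module(module_name), class_name)

print(module.__name__, config_class.__name__)
```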
17 changes: 17 additions & 0 deletions llama_stack/providers/remote/inference/sambanova/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import SambanovaImplConfig
from .sambanova import SambanovaInferenceAdapter


async def get_adapter_impl(config: SambanovaImplConfig, _deps):
    assert isinstance(
        config, SambanovaImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = SambanovaInferenceAdapter(config)
    await impl.initialize()
    return impl
21 changes: 21 additions & 0 deletions llama_stack/providers/remote/inference/sambanova/config.py
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class SambanovaImplConfig(BaseModel):
    url: str = Field(
        default="https://api.sambanova.ai/v1",
        description="The URL for the Sambanova model serving endpoint",
    )
    api_token: str = Field(
        default=None,
        description="The Sambanova API token",
    )
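
A minimal sketch of constructing this config and instantiating the adapter through `get_adapter_impl` from the package `__init__.py` above; normally the stack's provider resolver does this, and the API key is assumed to be set in the environment:

```python
# Sketch: build the config and instantiate the adapter the way the resolver would.
import asyncio
import os

from llama_stack.providers.remote.inference.sambanova import (
    SambanovaImplConfig,
    get_adapter_impl,
)

config = SambanovaImplConfig(api_token=os.environ["SAMBANOVA_API_KEY"])
adapter = asyncio.run(get_adapter_impl(config, {}))  # _deps is unused by this provider
print(adapter.config.url)  # https://api.sambanova.ai/v1
```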
153 changes: 153 additions & 0 deletions llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator

from llama_models.datatypes import CoreModelId

from llama_models.llama3.api.chat_format import ChatFormat

from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from openai import OpenAI

from llama_stack.apis.inference import * # noqa: F403

from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

from .config import SambanovaImplConfig


model_aliases = [
    build_model_alias(
        "Meta-Llama-3.1-8B-Instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.1-70B-Instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.1-405B-Instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.2-1B-Instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.2-3B-Instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
]


class SambanovaInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: SambanovaImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self,
            model_aliases=model_aliases,
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            model=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
        if stream:
            return self._stream_chat_completion(request, client)
        else:
            return await self._nonstream_chat_completion(request, client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.completions.create(**params)
        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> AsyncGenerator:
        params = self._get_params(request)

        async def _to_async_generator():
            s = client.completions.create(**params)
            for chunk in s:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        return {
            "model": request.model,
            "prompt": chat_completion_request_to_prompt(
                request, self.get_llama_model(request.model), self.formatter
            ),
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
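
The streaming path above wraps the OpenAI client's synchronous chunk iterator in a local async generator so `process_chat_completion_stream_response` can consume it with `async for`. A self-contained sketch of that pattern, with a stand-in for the blocking stream:

```python
# Sketch of the sync-iterator-to-async-generator pattern used in
# _stream_chat_completion above (names here are illustrative).
import asyncio


def fake_stream():
    # stand-in for client.completions.create(..., stream=True)
    yield from ["Hello", ", ", "world", "\n"]


async def to_async_generator(chunks):
    for chunk in chunks:
        yield chunk


async def main():
    async for chunk in to_async_generator(fake_stream()):
        print(chunk, end="")


asyncio.run(main())
```

As in the adapter, the underlying iteration is still blocking; control only returns to the event loop between yielded chunks.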
5 changes: 4 additions & 1 deletion llama_stack/providers/utils/inference/openai_compat.py
@@ -47,7 +47,10 @@ def text_from_choice(choice) -> str:
        return choice.delta.content

    if hasattr(choice, "message"):
        return choice.message.content
        try:
            return choice.message.content
        except:
            return choice.text

    return choice.text
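
A small sketch of the fallback this change introduces: when a choice carries a `message` without usable `content`, `text_from_choice` now falls back to `choice.text` instead of raising. Duck-typed stand-ins are used purely for illustration:

```python
# Sketch: exercise the new fallback in text_from_choice with duck-typed choices.
from types import SimpleNamespace

from llama_stack.providers.utils.inference.openai_compat import text_from_choice

with_content = SimpleNamespace(message=SimpleNamespace(content="hi"), text="raw")
print(text_from_choice(with_content))  # "hi"

without_content = SimpleNamespace(message=SimpleNamespace(), text="raw")
print(text_from_choice(without_content))  # falls back to "raw"
```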

7 changes: 7 additions & 0 deletions llama_stack/templates/sambanova/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .sambanova import get_distribution_template # noqa: F401
17 changes: 17 additions & 0 deletions llama_stack/templates/sambanova/build.yaml
@@ -0,0 +1,17 @@
version: "2"
name: sambanova
distribution_spec:
description: Use Sambanova for running LLM inference
docker_image: null
providers:
inference:
- remote::sambanova
memory:
- inline::faiss
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
image_type: conda