custom model support through litellm
rashadphz committed Jun 14, 2024
1 parent 3602c5a commit 3a483fb
Showing 8 changed files with 31 additions and 4 deletions.
1 change: 1 addition & 0 deletions docker-compose-no-searxng.yaml
@@ -16,6 +16,7 @@ services:
       - ENABLE_LOCAL_MODELS=${ENABLE_LOCAL_MODELS:-True}
       - SEARCH_PROVIDER=${SEARCH_PROVIDER:-tavily}
       - SEARXNG_BASE_URL=${SEARXNG_BASE_URL:-http://host.docker.internal:8080}
+      - CUSTOM_MODEL=${CUSTOM_MODEL}
       - REDIS_URL=${REDIS_URL}
     develop:
       watch:
1 change: 1 addition & 0 deletions docker-compose.dev.yaml
@@ -16,6 +16,7 @@ services:
       - ENABLE_LOCAL_MODELS=${ENABLE_LOCAL_MODELS:-True}
       - SEARCH_PROVIDER=${SEARCH_PROVIDER:-tavily}
       - SEARXNG_BASE_URL=${SEARXNG_BASE_URL:-http://host.docker.internal:8080}
+      - CUSTOM_MODEL=${CUSTOM_MODEL}
       - REDIS_URL=${REDIS_URL}
     develop:
       watch:
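Both compose files forward CUSTOM_MODEL into the backend container verbatim. Since `${CUSTOM_MODEL}` has no default, and Docker Compose typically substitutes an empty string for an unset host variable, the backend may see `""` rather than a missing key; a hypothetical sanity check (not part of this commit):

```python
import os

# Hypothetical check: under docker compose, an unset host CUSTOM_MODEL
# typically arrives as an empty string, which an "is None" test won't catch.
model = os.environ.get("CUSTOM_MODEL")
if not model:
    raise ValueError("CUSTOM_MODEL is unset or empty")
print(f"litellm will receive the model string: {model}")
```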
9 changes: 9 additions & 0 deletions src/backend/constants.py
@@ -17,6 +17,9 @@ class ChatModel(str, Enum):
     LOCAL_MISTRAL = "mistral"
     LOCAL_PHI3_14B = "phi3:14b"
 
+    # Custom models
+    CUSTOM = "custom"
+
 
 model_mappings: dict[ChatModel, str] = {
     ChatModel.GPT_3_5_TURBO: "gpt-3.5-turbo",
@@ -30,6 +33,12 @@ class ChatModel(str, Enum):


 def get_model_string(model: ChatModel) -> str:
+    if model == ChatModel.CUSTOM:
+        custom_model = os.environ.get("CUSTOM_MODEL")
+        if custom_model is None:
+            raise ValueError("CUSTOM_MODEL is not set")
+        return custom_model
+
     if model in {ChatModel.GPT_3_5_TURBO, ChatModel.GPT_4o}:
         openai_mode = os.environ.get("OPENAI_MODE", "openai")
         if openai_mode == "azure":
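A minimal sketch of the resolution path this hunk adds, assuming the module is importable as `backend.constants` and using an illustrative litellm model string (neither detail is taken from the commit):

```python
import os

from backend.constants import ChatModel, get_model_string  # import path assumed

# Illustrative value; any model string litellm recognizes should work here.
os.environ["CUSTOM_MODEL"] = "groq/llama3-70b-8192"
assert get_model_string(ChatModel.CUSTOM) == "groq/llama3-70b-8192"

# With the variable missing, the new branch fails fast.
del os.environ["CUSTOM_MODEL"]
try:
    get_model_string(ChatModel.CUSTOM)
except ValueError as err:
    print(err)  # CUSTOM_MODEL is not set
```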
6 changes: 5 additions & 1 deletion src/backend/llm/base.py
@@ -4,6 +4,7 @@
 from dotenv import load_dotenv
 from instructor.client import T
 from litellm import completion
+from litellm.utils import validate_environment
 from llama_index.core.base.llms.types import (
     CompletionResponse,
     CompletionResponseAsyncGen,
@@ -32,8 +33,11 @@ def __init__(
         self,
         model: str,
     ):
-        self.llm = LiteLLM(model=model)
+        validation = validate_environment(model)
+        if validation["missing_keys"]:
+            raise ValueError(f"Missing keys: {validation['missing_keys']}")
 
+        self.llm = LiteLLM(model=model)
         self.client = instructor.from_litellm(completion)
 
     async def astream(self, prompt: str) -> CompletionResponseAsyncGen:
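litellm's `validate_environment` maps a model string to its provider and reports which credentials are absent, so a misconfigured custom model now fails at construction time instead of on the first request. A small sketch of the returned shape (the model name is illustrative):

```python
import os

from litellm.utils import validate_environment

os.environ.pop("ANTHROPIC_API_KEY", None)  # make sure the key is absent for the demo
validation = validate_environment("claude-3-opus-20240229")  # illustrative model
# validate_environment returns a dict like:
#   {"keys_in_environment": False, "missing_keys": ["ANTHROPIC_API_KEY"]}
print(validation["missing_keys"])
```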
1 change: 1 addition & 0 deletions src/backend/utils.py
@@ -7,6 +7,7 @@ def is_local_model(model: ChatModel) -> bool:
         ChatModel.LOCAL_GEMMA,
         ChatModel.LOCAL_MISTRAL,
         ChatModel.LOCAL_PHI3_14B,
+        ChatModel.CUSTOM,
     ]
 
 
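Grouping CUSTOM with the Ollama-served models means it takes the same local-model code path everywhere `is_local_model` is consulted. A sketch of the resulting helper, assuming the function is simply a membership test over the list shown above (entries cut off by the diff context are omitted):

```python
from backend.constants import ChatModel  # import path assumed


def is_local_model(model: ChatModel) -> bool:
    return model in [
        # ...earlier entries elided by the diff context...
        ChatModel.LOCAL_GEMMA,
        ChatModel.LOCAL_MISTRAL,
        ChatModel.LOCAL_PHI3_14B,
        ChatModel.CUSTOM,
    ]
```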
2 changes: 2 additions & 0 deletions src/frontend/generated/schemas.gen.ts
@@ -29,6 +29,8 @@ export const $ChatModel = {
"llama3",
"gemma",
"mistral",
"phi3:14b",
"custom",
],
title: "ChatModel",
} as const;
3 changes: 2 additions & 1 deletion src/frontend/generated/types.gen.ts
@@ -12,7 +12,8 @@ export enum ChatModel {
   LLAMA3 = "llama3",
   GEMMA = "gemma",
   MISTRAL = "mistral",
-  LOCAL_PHI3_14B = "phi3:14b",
+  PHI3_14B = "phi3:14b",
+  CUSTOM = "custom",
 }
 
 export type ChatRequest = {
12 changes: 10 additions & 2 deletions src/frontend/src/components/model-selection.tsx
@@ -16,6 +16,7 @@ import {
   FlameIcon,
   Rabbit,
   RabbitIcon,
+  SettingsIcon,
   SparklesIcon,
   WandSparklesIcon,
 } from "lucide-react";
@@ -75,13 +76,20 @@ const modelMap: Record<ChatModel, Model> = {
     smallIcon: <AtomIcon className="w-4 h-4 text-[#FF7000]" />,
     icon: <AtomIcon className="w-5 h-5 text-[#FF7000]" />,
   },
-  [ChatModel.LOCAL_PHI3_14B]: {
+  [ChatModel.PHI3_14B]: {
     name: "Phi3",
     description: "ollama/phi3:14b",
-    value: ChatModel.LOCAL_PHI3_14B,
+    value: ChatModel.PHI3_14B,
     smallIcon: <FlameIcon className="w-4 h-4 text-green-500" />,
     icon: <FlameIcon className="w-5 h-5 text-green-500" />,
   },
+  [ChatModel.CUSTOM]: {
+    name: "Custom",
+    description: "Custom model",
+    value: ChatModel.CUSTOM,
+    smallIcon: <SettingsIcon className="w-4 h-4 text-red-500" />,
+    icon: <SettingsIcon className="w-5 h-5 text-red-500" />,
+  },
 };
 
 const localModelMap: Partial<Record<ChatModel, Model>> = _.pickBy(
