From 3602c5af0a07827c9339542dfc0b56168c3ca423 Mon Sep 17 00:00:00 2001
From: Rashad Philizaire
Date: Fri, 14 Jun 2024 12:17:46 -0700
Subject: [PATCH] using litellm!

---
 poetry.lock                    | 43 ++++++++++++++++++-
 pyproject.toml                 |  1 +
 src/backend/chat.py            | 73 ++++++++-----------------------
 src/backend/constants.py       | 36 +++++++++-------
 src/backend/llm/base.py        | 50 ++++++++++++++++++++++
 src/backend/related_queries.py | 78 ++--------------------------------
 6 files changed, 135 insertions(+), 146 deletions(-)
 create mode 100644 src/backend/llm/base.py

diff --git a/poetry.lock b/poetry.lock
index fb66125..040197e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1524,6 +1524,32 @@ mongodb = ["pymongo (>4.1,<5)"]
 redis = ["redis (>3,!=4.5.2,!=4.5.3,<6.0.0)"]
 rediscluster = ["redis (>=4.2.0,!=4.5.2,!=4.5.3)"]
 
+[[package]]
+name = "litellm"
+version = "1.40.10"
+description = "Library to easily interface with LLM API providers"
+optional = false
+python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
+files = [
+    {file = "litellm-1.40.10-py3-none-any.whl", hash = "sha256:46b77c49593f4e5e7bd9c1291f5896549b6ff6ebdad457af3ce0f4937bcfc17d"},
+    {file = "litellm-1.40.10.tar.gz", hash = "sha256:cdfc86f1de60491cd85155d98d64beec019c35dd0d8c8784037655cb7c0bdddc"},
+]
+
+[package.dependencies]
+aiohttp = "*"
+click = "*"
+importlib-metadata = ">=6.8.0"
+jinja2 = ">=3.1.2,<4.0.0"
+openai = ">=1.27.0"
+python-dotenv = ">=0.2.0"
+requests = ">=2.31.0,<3.0.0"
+tiktoken = ">=0.7.0"
+tokenizers = "*"
+
+[package.extras]
+extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "resend (>=0.8.0,<0.9.0)"]
+proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "cryptography (>=42.0.5,<43.0.0)", "fastapi (>=0.111.0,<0.112.0)", "fastapi-sso (>=0.10.0,<0.11.0)", "gunicorn (>=22.0.0,<23.0.0)", "orjson (>=3.9.7,<4.0.0)", "python-multipart (>=0.0.9,<0.0.10)", "pyyaml (>=6.0.1,<7.0.0)", "rq", "uvicorn (>=0.22.0,<0.23.0)"]
+
 [[package]]
 name = "llama-index"
 version = "0.10.37"
@@ -1719,6 +1745,21 @@ files = [
 llama-index-core = ">=0.10.1,<0.11.0"
 llama-index-llms-openai-like = ">=0.1.3,<0.2.0"
 
+[[package]]
+name = "llama-index-llms-litellm"
+version = "0.1.4"
+description = "llama-index llms litellm integration"
+optional = false
+python-versions = "<4.0,>=3.8.1"
+files = [
+    {file = "llama_index_llms_litellm-0.1.4-py3-none-any.whl", hash = "sha256:5a4150ffcad38c874ba5d58aca7a53d85f4faede67ecfe93f40cac4416f90725"},
+    {file = "llama_index_llms_litellm-0.1.4.tar.gz", hash = "sha256:5a8a396de97b779f96e21d1d2b8554a67904ec1daf461b486ed26edc2be74508"},
+]
+
+[package.dependencies]
+litellm = ">=1.18.13,<2.0.0"
+llama-index-core = ">=0.10.1,<0.11.0"
+
 [[package]]
 name = "llama-index-llms-ollama"
 version = "0.1.3"
@@ -5059,4 +5100,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "27c3c5ee924aa0079b97197ac70089876fa251f95528e1620ee2574d5bfee8de"
+content-hash = "aec7c9df55261292eeffe9c97b6f7d80b88ea785808949a546a6a1cf574144e7"
diff --git a/pyproject.toml b/pyproject.toml
index 94a6238..4113a08 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ groq = "^0.5.0"
 slowapi = "^0.1.9"
 redis = "^5.0.4"
 llama-index-llms-ollama = "^0.1.3"
+llama-index-llms-litellm = "^0.1.4"
 
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.7.1"
diff --git a/src/backend/chat.py b/src/backend/chat.py
index 257e73c..a9c7213 100644
--- a/src/backend/chat.py
+++ b/src/backend/chat.py
@@ -1,15 +1,10 @@
 import asyncio
-import os
 from typing import AsyncIterator, List
 
 from fastapi import HTTPException
-from llama_index.core.llms import LLM
-from llama_index.llms.azure_openai import AzureOpenAI
-from llama_index.llms.groq import Groq
-from llama_index.llms.ollama import Ollama
-from llama_index.llms.openai import OpenAI
 
-from backend.constants import ChatModel, model_mappings
+from backend.constants import get_model_string
+from backend.llm.base import BaseLLM, EveryLLM
 from backend.prompts import CHAT_PROMPT, HISTORY_QUERY_REPHRASE
 from backend.related_queries import generate_related_queries
 from backend.schemas import (
@@ -29,16 +24,18 @@ from backend.utils import is_local_model
 
 
 
-def rephrase_query_with_history(question: str, history: List[Message], llm: LLM) -> str:
+def rephrase_query_with_history(
+    question: str, history: List[Message], llm: BaseLLM
+) -> str:
+    if not history:
+        return question
+
     try:
-        if history:
-            history_str = "\n".join([f"{msg.role}: {msg.content}" for msg in history])
-            question = llm.complete(
-                HISTORY_QUERY_REPHRASE.format(
-                    chat_history=history_str, question=question
-                )
-            ).text
-            question = question.replace('"', "")
+        history_str = "\n".join(f"{msg.role}: {msg.content}" for msg in history)
+        formatted_query = HISTORY_QUERY_REPHRASE.format(
+            chat_history=history_str, question=question
+        )
+        question = llm.complete(formatted_query).text.replace('"', "")
         return question
     except Exception:
         raise HTTPException(
@@ -46,42 +43,6 @@ def rephrase_query_with_history(question: str, history: List[Message], llm: LLM)
     )
 
 
-def get_openai_model(model: ChatModel) -> LLM:
-    openai_mode = os.environ.get("OPENAI_MODE", "openai")
-    if openai_mode == "azure":
-        return AzureOpenAI(
-            deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME"),
-            api_key=os.environ.get("AZURE_API_KEY"),
-            azure_endpoint=os.environ.get("AZURE_CHAT_ENDPOINT"),
-            api_version="2024-04-01-preview",
-        )
-    elif openai_mode == "openai":
-        return OpenAI(model=model_mappings[model])
-    else:
-        raise ValueError(f"Unknown model: {model}")
-
-
-def get_llm(model: ChatModel) -> LLM:
-    if model == ChatModel.GPT_3_5_TURBO:
-        return get_openai_model(model)
-    elif model == ChatModel.GPT_4o:
-        return OpenAI(model=model_mappings[model])
-    elif model in [
-        ChatModel.LOCAL_GEMMA,
-        ChatModel.LOCAL_LLAMA_3,
-        ChatModel.LOCAL_MISTRAL,
-        ChatModel.LOCAL_PHI3_14B,
-    ]:
-        return Ollama(
-            base_url=os.environ.get("OLLAMA_HOST", "http://localhost:11434"),
-            model=model_mappings[model],
-        )
-    elif model == ChatModel.LLAMA_3_70B:
-        return Groq(model=model_mappings[model])
-    else:
-        raise ValueError(f"Unknown model: {model}")
-
-
 def format_context(search_results: List[SearchResult]) -> str:
     return "\n\n".join(
         [f"Citation {i+1}. {str(result)}" for i, result in enumerate(search_results)]
{str(result)}" for i, result in enumerate(search_results)] @@ -90,7 +51,7 @@ def format_context(search_results: List[SearchResult]) -> str: async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseEvent]: try: - llm = get_llm(request.model) + llm = EveryLLM(model=get_model_string(request.model)) yield ChatResponseEvent( event=StreamEvent.BEGIN_STREAM, @@ -108,7 +69,7 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE related_queries_task = None if not is_local_model(request.model): related_queries_task = asyncio.create_task( - generate_related_queries(query, search_results, request.model) + generate_related_queries(query, search_results, llm) ) yield ChatResponseEvent( @@ -125,7 +86,7 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE ) full_response = "" - response_gen = await llm.astream_complete(fmt_qa_prompt) + response_gen = await llm.astream(fmt_qa_prompt) async for completion in response_gen: full_response += completion.delta or "" yield ChatResponseEvent( @@ -136,7 +97,7 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE related_queries = await ( related_queries_task if related_queries_task - else generate_related_queries(query, search_results, request.model) + else generate_related_queries(query, search_results, llm) ) yield ChatResponseEvent( diff --git a/src/backend/constants.py b/src/backend/constants.py index 4f705bb..590b31a 100644 --- a/src/backend/constants.py +++ b/src/backend/constants.py @@ -1,14 +1,9 @@ +import os from enum import Enum -GPT4_MODEL = "gpt-4o" -GPT3_MODEL = "gpt-3.5-turbo" -LLAMA_8B_MODEL = "llama3-8b-8192" -LLAMA_70B_MODEL = "llama3-70b-8192" +from dotenv import load_dotenv -LOCAL_LLAMA3_MODEL = "llama3" -LOCAL_GEMMA_MODEL = "gemma:7b" -LOCAL_MISTRAL_MODEL = "mistral" -LOCAL_PHI3_14B = "phi3:14b" +load_dotenv() class ChatModel(str, Enum): @@ -24,11 +19,22 @@ class ChatModel(str, Enum): model_mappings: dict[ChatModel, str] = { - ChatModel.GPT_3_5_TURBO: GPT3_MODEL, - ChatModel.GPT_4o: GPT4_MODEL, - ChatModel.LLAMA_3_70B: LLAMA_70B_MODEL, - ChatModel.LOCAL_LLAMA_3: LOCAL_LLAMA3_MODEL, - ChatModel.LOCAL_GEMMA: LOCAL_GEMMA_MODEL, - ChatModel.LOCAL_MISTRAL: LOCAL_MISTRAL_MODEL, - ChatModel.LOCAL_PHI3_14B: LOCAL_PHI3_14B, + ChatModel.GPT_3_5_TURBO: "gpt-3.5-turbo", + ChatModel.GPT_4o: "gpt-4o", + ChatModel.LLAMA_3_70B: "groq/llama3-70b-8192", + ChatModel.LOCAL_LLAMA_3: "ollama_chat/llama3", + ChatModel.LOCAL_GEMMA: "ollama_chat/gemma", + ChatModel.LOCAL_MISTRAL: "ollama_chat/mistral", + ChatModel.LOCAL_PHI3_14B: "ollama_chat/phi3:14b", } + + +def get_model_string(model: ChatModel) -> str: + if model in {ChatModel.GPT_3_5_TURBO, ChatModel.GPT_4o}: + openai_mode = os.environ.get("OPENAI_MODE", "openai") + if openai_mode == "azure": + # Currently deployments are named "gpt-35-turbo" and "gpt-4o" + name = model_mappings[model].replace(".", "") + return f"azure/{name}" + + return model_mappings[model] diff --git a/src/backend/llm/base.py b/src/backend/llm/base.py new file mode 100644 index 0000000..dc31f07 --- /dev/null +++ b/src/backend/llm/base.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod + +import instructor +from dotenv import load_dotenv +from instructor.client import T +from litellm import completion +from llama_index.core.base.llms.types import ( + CompletionResponse, + CompletionResponseAsyncGen, +) +from llama_index.llms.litellm import LiteLLM + +load_dotenv() + + +class BaseLLM(ABC): + @abstractmethod + async def 
+        pass
+
+    @abstractmethod
+    def complete(self, prompt: str) -> CompletionResponse:
+        pass
+
+    @abstractmethod
+    def structured_complete(self, response_model: type[T], prompt: str) -> T:
+        pass
+
+
+class EveryLLM(BaseLLM):
+    def __init__(
+        self,
+        model: str,
+    ):
+        self.llm = LiteLLM(model=model)
+
+        self.client = instructor.from_litellm(completion)
+
+    async def astream(self, prompt: str) -> CompletionResponseAsyncGen:
+        return await self.llm.astream_complete(prompt)
+
+    def complete(self, prompt: str) -> CompletionResponse:
+        return self.llm.complete(prompt)
+
+    def structured_complete(self, response_model: type[T], prompt: str) -> T:
+        return self.client.chat.completions.create(
+            model=self.llm.model,
+            messages=[{"role": "user", "content": prompt}],
+            response_model=response_model,
+        )
diff --git a/src/backend/related_queries.py b/src/backend/related_queries.py
index 3312391..a0a920c 100644
--- a/src/backend/related_queries.py
+++ b/src/backend/related_queries.py
@@ -1,85 +1,15 @@
-import os
-
-import groq
-import instructor
-import openai
-from dotenv import load_dotenv
-
-from backend.constants import ChatModel, model_mappings
+from backend.llm.base import BaseLLM
 from backend.prompts import RELATED_QUESTION_PROMPT
 from backend.schemas import RelatedQueries, SearchResult
 
-load_dotenv()
-
-
-OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
-
-
-def get_openai_client() -> openai.AsyncOpenAI:
-    openai_mode = os.environ.get("OPENAI_MODE", "openai")
-    if openai_mode == "openai":
-        return openai.AsyncOpenAI()
-    elif openai_mode == "azure":
-        return openai.AsyncAzureOpenAI(
-            azure_deployment=os.environ.get("AZURE_DEPLOYMENT_NAME"),
-            azure_endpoint=os.environ["AZURE_CHAT_ENDPOINT"],
-            api_key=os.environ.get("AZURE_API_KEY"),
-            api_version="2024-04-01-preview",
-        )
-    else:
-        raise ValueError(f"Unknown openai mode: {openai_mode}")
-
-
-def instructor_client(model: ChatModel) -> instructor.AsyncInstructor:
-    if model == ChatModel.GPT_3_5_TURBO:
-        return instructor.from_openai(
-            get_openai_client(),
-        )
-    elif model in [
-        ChatModel.GPT_3_5_TURBO,
-        ChatModel.GPT_4o,
-    ]:
-        return instructor.from_openai(openai.AsyncOpenAI())
-    elif model in [
-        ChatModel.LOCAL_GEMMA,
-        ChatModel.LOCAL_LLAMA_3,
-        ChatModel.LOCAL_MISTRAL,
-        ChatModel.LOCAL_PHI3_14B,
-    ]:
-        return instructor.from_openai(
-            openai.AsyncOpenAI(
-                base_url=f"{OLLAMA_HOST}/v1",
-                api_key="ollama",
-            ),
-            mode=instructor.Mode.JSON,
-        )
-    elif model == ChatModel.LLAMA_3_70B:
-        return instructor.from_groq(groq.AsyncGroq(), mode=instructor.Mode.JSON)  # type: ignore
-    else:
-        raise ValueError(f"Unknown model: {model}")
-
 
 async def generate_related_queries(
-    query: str, search_results: list[SearchResult], model: ChatModel
+    query: str, search_results: list[SearchResult], llm: BaseLLM
 ) -> list[str]:
     context = "\n\n".join([f"{str(result)}" for result in search_results])
-    # Truncate the context to 4000 characters (mainly for smaller models)
     context = context[:4000]
-
-    client = instructor_client(model)
-    model_name = model_mappings[model]
-
-    print(RELATED_QUESTION_PROMPT.format(query=query, context=context))
-
-    related = await client.chat.completions.create(
-        model=model_name,
-        response_model=RelatedQueries,
-        messages=[
-            {
-                "role": "user",
-                "content": RELATED_QUESTION_PROMPT.format(query=query, context=context),
-            },
-        ],
+    related = llm.structured_complete(
+        RelatedQueries, RELATED_QUESTION_PROMPT.format(query=query, context=context)
     )
 
     return [query.lower().replace("?", "") for query in related.related_questions]
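
Below is a minimal usage sketch, not part of the patch itself, showing how the pieces added above fit together: get_model_string() resolves a ChatModel to a litellm model string, and EveryLLM wraps that string with complete(), astream(), and structured_complete(). It assumes the relevant provider API key (for example OPENAI_API_KEY) is set in the environment and that the backend package is on the import path; the prompt strings are placeholders.

import asyncio

from backend.constants import ChatModel, get_model_string
from backend.llm.base import EveryLLM
from backend.schemas import RelatedQueries


async def main() -> None:
    # Resolve the enum to a litellm model string ("gpt-4o", "groq/llama3-70b-8192", ...).
    llm = EveryLLM(model=get_model_string(ChatModel.GPT_4o))

    # Blocking completion: delegates to llama-index's LiteLLM.complete().
    print(llm.complete("Say hello in one word.").text)

    # Streaming completion: astream() returns the async generator from astream_complete().
    async for chunk in await llm.astream("Count to three."):
        print(chunk.delta or "", end="", flush=True)
    print()

    # Structured output via instructor: the reply is parsed into the RelatedQueries schema.
    related = llm.structured_complete(
        RelatedQueries, "Suggest related questions about solar power."
    )
    print(related.related_questions)


if __name__ == "__main__":
    asyncio.run(main())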