diff --git a/src/backend/chat.py b/src/backend/chat.py
index 24981da..257e73c 100644
--- a/src/backend/chat.py
+++ b/src/backend/chat.py
@@ -4,6 +4,7 @@
 from fastapi import HTTPException
 from llama_index.core.llms import LLM
+from llama_index.llms.azure_openai import AzureOpenAI
 from llama_index.llms.groq import Groq
 from llama_index.llms.ollama import Ollama
 from llama_index.llms.openai import OpenAI
 
@@ -45,8 +46,25 @@ def rephrase_query_with_history(question: str, history: List[Message], llm: LLM)
     )
 
 
+def get_openai_model(model: ChatModel) -> LLM:
+    openai_mode = os.environ.get("OPENAI_MODE", "openai")
+    if openai_mode == "azure":
+        return AzureOpenAI(
+            deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME"),
+            api_key=os.environ.get("AZURE_API_KEY"),
+            azure_endpoint=os.environ.get("AZURE_CHAT_ENDPOINT"),
+            api_version="2024-04-01-preview",
+        )
+    elif openai_mode == "openai":
+        return OpenAI(model=model_mappings[model])
+    else:
+        raise ValueError(f"Unknown OPENAI_MODE: {openai_mode}")
+
+
 def get_llm(model: ChatModel) -> LLM:
-    if model in [ChatModel.GPT_3_5_TURBO, ChatModel.GPT_4o]:
+    if model == ChatModel.GPT_3_5_TURBO:
+        return get_openai_model(model)
+    elif model == ChatModel.GPT_4o:
         return OpenAI(model=model_mappings[model])
     elif model in [
         ChatModel.LOCAL_GEMMA,
diff --git a/src/backend/prompts.py b/src/backend/prompts.py
index ee95a18..a49956c 100644
--- a/src/backend/prompts.py
+++ b/src/backend/prompts.py
@@ -27,10 +27,10 @@
 RELATED_QUESTION_PROMPT = """\
 Given a question and search result context, generate 3 follow-up questions the user might ask. Use the original question and context.
 
-There must be EXACTLY 3 questions. Keep the questions concise, and simple. This should return an object with the following fields:
-
-questions: A list of 3 concise, simple questions
-
+Instructions:
+- Generate exactly 3 questions.
+- These questions should be concise and simple.
+- Ensure the follow-up questions are relevant to the original question and context. Make sure to match the language of the user's question.
 
 Original Question: {query}
 
@@ -38,7 +38,8 @@
 {context}
 
-Your EXACTLY 3 (three) follow-up questions:
+Output:
+related_questions: A list of EXACTLY three concise, simple follow-up questions
 """
 
 HISTORY_QUERY_REPHRASE = """
diff --git a/src/backend/related_queries.py b/src/backend/related_queries.py
index 09f590e..3312391 100644
--- a/src/backend/related_queries.py
+++ b/src/backend/related_queries.py
@@ -15,8 +15,27 @@
 OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
 
 
+def get_openai_client() -> openai.AsyncOpenAI:
+    openai_mode = os.environ.get("OPENAI_MODE", "openai")
+    if openai_mode == "openai":
+        return openai.AsyncOpenAI()
+    elif openai_mode == "azure":
+        return openai.AsyncAzureOpenAI(
+            azure_deployment=os.environ.get("AZURE_DEPLOYMENT_NAME"),
+            azure_endpoint=os.environ["AZURE_CHAT_ENDPOINT"],
+            api_key=os.environ.get("AZURE_API_KEY"),
+            api_version="2024-04-01-preview",
+        )
+    else:
+        raise ValueError(f"Unknown openai mode: {openai_mode}")
+
+
 def instructor_client(model: ChatModel) -> instructor.AsyncInstructor:
-    if model in [
+    if model == ChatModel.GPT_3_5_TURBO:
+        return instructor.from_openai(
+            get_openai_client(),
+        )
+    elif model in [
         ChatModel.GPT_3_5_TURBO,
         ChatModel.GPT_4o,
     ]:
@@ -35,7 +54,7 @@ def instructor_client(model: ChatModel) -> instructor.AsyncInstructor:
             mode=instructor.Mode.JSON,
         )
     elif model == ChatModel.LLAMA_3_70B:
-        return instructor.from_groq(groq.AsyncGroq(), mode=instructor.Mode.JSON)
+        return instructor.from_groq(groq.AsyncGroq(), mode=instructor.Mode.JSON)  # type: ignore
     else:
         raise ValueError(f"Unknown model: {model}")
 
@@ -50,6 +69,8 @@ async def generate_related_queries(
     client = instructor_client(model)
     model_name = model_mappings[model]
 
+    print(RELATED_QUESTION_PROMPT.format(query=query, context=context))
+
     related = await client.chat.completions.create(
         model=model_name,
         response_model=RelatedQueries,
@@ -61,4 +82,4 @@ async def generate_related_queries(
         ],
     )
 
-    return [query.lower().replace("?", "") for query in related.questions]
+    return [query.lower().replace("?", "") for query in related.related_questions]
diff --git a/src/backend/schemas.py b/src/backend/schemas.py
index 7b3f7ef..89d9406 100644
--- a/src/backend/schemas.py
+++ b/src/backend/schemas.py
@@ -37,7 +37,7 @@ class ChatRequest(BaseModel, plugin_settings=record_all):
 
 
 class RelatedQueries(BaseModel):
-    questions: List[str] = Field(..., min_length=3, max_length=3)
+    related_questions: List[str] = Field(..., min_length=3, max_length=3)
 
 
 class SearchResult(BaseModel):
diff --git a/src/frontend/src/components/model-selection.tsx b/src/frontend/src/components/model-selection.tsx
index d0d1da3..bb17f42 100644
--- a/src/frontend/src/components/model-selection.tsx
+++ b/src/frontend/src/components/model-selection.tsx
@@ -111,7 +111,7 @@ const ModelItem: React.FC<{ model: Model }> = ({ model }) => (
 );
 
 export function ModelSelection() {
-  const { model, setModel, localMode } = useConfigStore();
+  const { localMode, model, setModel } = useConfigStore();
 
   return (
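
The new code paths in chat.py and related_queries.py are driven entirely by environment variables. Below is a minimal sketch (not part of the patch) of how the OPENAI_MODE switch is expected to behave; the import paths and placeholder values are assumptions, not taken from the patch:

```python
# Sketch only: exercising the new OPENAI_MODE switch added in chat.py.
import os

os.environ["OPENAI_MODE"] = "azure"  # the default, "openai", keeps stock OpenAI
os.environ["AZURE_API_KEY"] = "<azure-api-key>"                             # placeholder
os.environ["AZURE_CHAT_ENDPOINT"] = "https://<resource>.openai.azure.com"   # placeholder
os.environ["AZURE_DEPLOYMENT_NAME"] = "<deployment-name>"                   # placeholder

from backend.chat import get_llm         # assumed import path; match your layout
from backend.constants import ChatModel  # assumed import path; match your layout

# GPT_3_5_TURBO now routes through get_openai_model(), which builds an
# AzureOpenAI LLM in azure mode; GPT_4o still constructs OpenAI directly.
llm = get_llm(ChatModel.GPT_3_5_TURBO)
print(type(llm).__name__)  # "AzureOpenAI" when OPENAI_MODE=azure
```

Note that only GPT_3_5_TURBO goes through the new switch; GPT_4o bypasses it, so azure mode affects one chat model plus the instructor client used for related questions.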
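
The schemas.py rename only works because three places agree: the prompt's Output section, the RelatedQueries response model that instructor parses against, and the comprehension in generate_related_queries. A self-contained sketch of that contract (Pydantic v2 semantics; the example strings are illustrative):

```python
# Sketch only: the renamed response model and the post-processing that consumes it.
from typing import List

from pydantic import BaseModel, Field


class RelatedQueries(BaseModel):
    # Pydantic v2: min_length/max_length pin the list to exactly three items.
    related_questions: List[str] = Field(..., min_length=3, max_length=3)


related = RelatedQueries(
    related_questions=["What is X?", "How does Y work?", "Why use Z?"]
)

# Mirrors the updated return line in generate_related_queries().
print([q.lower().replace("?", "") for q in related.related_questions])
# -> ['what is x', 'how does y work', 'why use z']
```

Renaming the field to related_questions also changes the JSON schema sent to the model, so the prompt's "related_questions: A list of EXACTLY three ..." line and the Pydantic field name now reinforce each other.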