diff --git a/app.py b/app.py
index 5b489e7..7dd85c9 100644
--- a/app.py
+++ b/app.py
@@ -23,7 +23,7 @@
 st.divider()
 
 # Initialize API clients for OpenAI and Qdrant and load configuration settings.
-openai_client, qdrant_client = initialize_clients()
+qdrant_client = initialize_clients()
 config = load_config()
 
 # Display the logo and set up the sidebar with useful information and links.
@@ -63,7 +63,6 @@
         # Generate a response using the LLM and display it as a stream.
         stream = generate_response(
             query=prompt,
-            openai_client=openai_client,
             qdrant_client=qdrant_client,
             config=config,
         )
diff --git a/config.yaml b/config.yaml
index 309b332..ec9ad4f 100644
--- a/config.yaml
+++ b/config.yaml
@@ -3,7 +3,7 @@ openai:
     model: "text-embedding-3-small"
     dimensions: 1536
   chat:
-    model: "gpt-4-turbo-preview"
+    model: "gpt-4o"
     temperature: 0
     max_conversation: 100
   router:
diff --git a/database/utils.py b/database/utils.py
index b6f8e6d..4227196 100644
--- a/database/utils.py
+++ b/database/utils.py
@@ -5,8 +5,9 @@
 import numpy as np
 import tiktoken
 
+from langfuse.decorators import observe
+from langfuse.openai import openai
 from loguru import logger
-from openai import OpenAI
 from openai.types import CreateEmbeddingResponse
 from qdrant_client import QdrantClient
 from qdrant_client.http.models import (
@@ -86,13 +87,12 @@ def search(
     )
 
 
-def embed_text(
-    client: OpenAI, text: Union[str, list], model: str
-) -> CreateEmbeddingResponse:
+@observe()
+def embed_text(text: Union[str, list], model: str) -> CreateEmbeddingResponse:
     """
     Create embeddings using OpenAI API.
     """
-    response = client.embeddings.create(input=text, model=model)
+    response = openai.embeddings.create(input=text, model=model)
 
     return response
 
diff --git a/llm/utils.py b/llm/utils.py
index 4a2ea54..aabc76b 100644
--- a/llm/utils.py
+++ b/llm/utils.py
@@ -1,9 +1,5 @@
 from typing import Dict, List
 
-from langfuse.decorators import observe
-from openai import OpenAI
-from openai.types.chat import ChatCompletion
-
 from llm.prompts import (
     CONTEXT_PROMPT,
     CONVERSATION_PROMPT,
@@ -12,33 +8,9 @@
 )
 
 
-@observe()
-def get_answer(
-    client: OpenAI,
-    model: str,
-    temperature: float,
-    messages: List[Dict],
-    stream: bool = False,
-) -> ChatCompletion:
-    """
-    Get an answer from the OpenAI chat model.
-
-    Args:
-        client (OpenAI): The OpenAI client instance.
-        model (str): The model name to use.
-        temperature (float): The temperature setting for the model.
-        messages (List[Dict]): The list of messages to send to the model.
-        stream (bool, optional): Whether to stream the response. Defaults to False.
-
-    Returns:
-        ChatCompletion: The chat completion response from OpenAI.
-    """
-    return client.chat.completions.create(
-        model=model, temperature=temperature, messages=messages, stream=stream
-    )
-
-
-def get_messages(context: str, query: str, conversation: List[str]) -> List[Dict]:
+def formate_messages_chat(
+    context: str, query: str, conversation: List[str]
+) -> List[Dict]:
     """
     Prepare the list of messages for the chat model.
 
diff --git a/router/query_router.py b/router/query_router.py
index df9a86a..7942547 100644
--- a/router/query_router.py
+++ b/router/query_router.py
@@ -1,38 +1,21 @@
-import json
-from typing import List
+from typing import Dict, List
 
-from langfuse.decorators import observe
-from openai import OpenAI
+from router.router_prompt import ROUTER_PROMPT, USER_QUERY
 
 
-@observe()
-def semantic_query_router(
-    client: OpenAI,
+def formate_messages_router(
     query: str,
-    prompt: str,
-    temperature: float,
-    model: str = "gpt-3.5-turbo",
-) -> List[str]:
+) -> List[Dict]:
     """
-    Routes a semantic query to the appropriate collections using OpenAI's API.
+    Prepare the list of messages for the llm model.
 
     Args:
-        client (OpenAI): The OpenAI client instance.
-        query (str): The query string to be routed.
-        prompt (str): The prompt template to be used for the query.
-        temperature (float): The temperature setting for the model's response.
-        model (str, optional): The model to be used. Defaults to "gpt-3.5-turbo".
+        query (str): The user's query.
 
     Returns:
-        List[str]: A list of collections that are relevant to the query.
+        List[Dict]: The list of messages formatted for the llm model.
     """
-    # Create the completion request to the OpenAI API
-    response = client.chat.completions.create(
-        model=model,
-        response_format={"type": "json_object"},
-        messages=[{"role": "system", "content": prompt.format(query=query)}],
-        temperature=temperature,
-    )
-    # Parse the response to extract the collections
-    collections = json.loads(response.choices[0].message.content)["response"]
-    return collections
+    return [
+        {"role": "system", "content": ROUTER_PROMPT},
+        {"role": "user", "content": USER_QUERY.format(query=query)},
+    ]
diff --git a/router/router_prompt.py b/router/router_prompt.py
index 0a43b53..3be9318 100644
--- a/router/router_prompt.py
+++ b/router/router_prompt.py
@@ -1,4 +1,5 @@
 ROUTER_PROMPT = """
+**INSTRUKCIJE:**
 Tvoj zadatak je da na osnovu datog pitanja korisnika odlucis koji zakon ili zakoni su potrebni da bi se odgovorilo na korisnikovo pitanje.
 Ponudjeni zakoni i njihova objasnjenja su sledeci:
 - zakon_o_radu
@@ -20,12 +21,14 @@
 - Jedno pitanje korisnika moze da se odnosi na vise zakona.
 - Vrati zakone koji mogu da pomognu prilikom generisanja odgovora.
 - Ukoliko korisnikovo pitanje ne odgovara ni jednom zakonu vrati listu sa generickim stringom: ["nema_zakona"].
-- Primer JSON odgovora:
+**PRIMER ODGOVORA:**
 {{
 response: ["ime_zakona"]
 }}
+"""
 
+USER_QUERY = """
 **PITANJE KORISINKA:**
 {query}
 """
 
diff --git a/utils.py b/utils.py
index c012413..e521a61 100644
--- a/utils.py
+++ b/utils.py
@@ -1,19 +1,21 @@
+import json
 import os
-from typing import Generator, List, Tuple
+from typing import Dict, Generator, List
 
 import streamlit as st
 import yaml
-from langfuse.decorators import observe
+from langfuse.decorators import langfuse_context, observe
+from langfuse.openai import openai
 from loguru import logger
-from openai import OpenAI
+from openai.types.chat import ChatCompletion
 from pydantic import BaseModel
 from qdrant_client import QdrantClient
 
 from database.utils import embed_text, get_context, search
 from llm.prompts import DEFAULT_CONTEXT
-from llm.utils import get_answer, get_messages
-from router.query_router import semantic_query_router
-from router.router_prompt import DEFAULT_ROUTER_RESPONSE, ROUTER_PROMPT
+from llm.utils import formate_messages_chat
+from router.query_router import formate_messages_router
+from router.router_prompt import DEFAULT_ROUTER_RESPONSE
 
 LOGO_URL = "assets/Legabot-Logomark.svg"
 LOGO_TEXT_URL = "assets/Legabot-Light-Horizontal.svg"
@@ -71,7 +73,7 @@ def load_config(yaml_file_path: str = "./config.yaml") -> Config:
 
 
 @st.cache_resource
-def initialize_clients() -> Tuple[OpenAI, QdrantClient]:
+def initialize_clients() -> QdrantClient:
     """
     Initializes and returns the clients for OpenAI and Qdrant services.
 
@@ -87,27 +89,51 @@ def initialize_clients() -> Tuple[OpenAI, QdrantClient]:
         qdrant_api_key = os.environ["QDRANT_API_KEY"]
         qdrant_client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
 
-        # Retrieve OpenAI client configuration from environment variables
-        openai_api_key = os.environ["OPENAI_API_KEY"]
-        openai_client = OpenAI(api_key=openai_api_key)
-
-        return openai_client, qdrant_client
+        return qdrant_client
 
     except KeyError as e:
         error_msg = f"Missing environment variable: {str(e)}"
         logger.error(error_msg)
         raise EnvironmentError(error_msg)
 
 
+@observe(as_type="generation")
+def call_llm(
+    model: str,
+    temperature: float,
+    messages: List[Dict],
+    json_response: bool = False,
+    stream: bool = False,
+) -> ChatCompletion:
+    """
+    Get an answer from the OpenAI chat model.
+
+    Args:
+        model (str): The model name to use.
+        temperature (float): The temperature setting for the model.
+        messages (List[Dict]): The list of messages to send to the model.
+        stream (bool, optional): Whether to stream the response. Defaults to False.
+
+    Returns:
+        ChatCompletion: The chat completion response from OpenAI.
+    """
+    return openai.chat.completions.create(
+        model=model,
+        response_format={"type": "json_object"} if json_response else None,
+        temperature=temperature,
+        messages=messages,
+        stream=stream,
+    )
+
+
 @observe()
 def generate_response(
-    query: str, openai_client: OpenAI, qdrant_client: QdrantClient, config: Config
+    query: str, qdrant_client: QdrantClient, config: Config
 ) -> Generator[str, None, None]:
     """
     Generates a response for a given user query using a combination of semantic search and a chat model.
 
     Args:
     - query (str): The user's query string.
-    - openai_client (OpenAI): Client to interact with OpenAI's API.
     - qdrant_client (QdrantClient): Client to interact with Qdrant's API.
     - config (Config): Configuration settings for API interaction and response handling.
@@ -120,35 +146,36 @@ def generate_response(
             -config.openai.chat.max_conversation :
         ]
 
+        # Determine the relevant collections to route the query to
+        messages = formate_messages_router(query)
+        response = call_llm(
+            model=config.openai.router.model,
+            temperature=config.openai.router.temperature,
+            messages=messages,
+            json_response=True,
+        )
+        collections = json.loads(response.choices[0].message.content)["response"]
+        logger.info(f"Query routed to collections: {collections}")
+        langfuse_context.update_current_trace(tags=collections)
+
         # Embed the user query using the specified model in the configuration
         embedding_response = embed_text(
-            client=openai_client,
             text=query,
             model=config.openai.embeddings.model,
         )
         embedding = embedding_response.data[0].embedding
 
-        # Determine the relevant collections to route the query to
-        collections = semantic_query_router(
-            client=openai_client,
-            model=config.openai.router.model,
-            query=query,
-            prompt=ROUTER_PROMPT,
-            temperature=config.openai.router.temperature,
-        )
-        logger.info(f"Query routed to collections: {collections}")
-
         # Determine the context for the chat model based on the routed collections
         context = determine_context(collections, embedding, qdrant_client)
 
         # Generate the response stream from the chat model
-        stream = get_answer(
-            client=openai_client,
+        messages = formate_messages_chat(
+            context=context, query=query, conversation=st.session_state.messages
+        )
+        stream = call_llm(
             model=config.openai.chat.model,
             temperature=config.openai.chat.temperature,
-            messages=get_messages(
-                context=context, query=query, conversation=st.session_state.messages
-            ),
+            messages=messages,
             stream=True,
         )
 
@@ -158,6 +185,8 @@ def generate_response(
             if part is not None:
                 yield part
 
+        langfuse_context.flush()
+
     except Exception as e:
         logger.error(f"An error occurred while generating the response: {str(e)}")
         yield "Sorry, an error occurred while processing your request."
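
For quick reference, below is a minimal standalone sketch of how the routing step introduced by this patch could be exercised outside Streamlit. It assumes the OpenAI and Langfuse credentials expected by langfuse.openai are set in the environment and that utils.py is importable from the project root; the model name, temperature, and query are illustrative placeholders, not values taken from the repository.

import json

from router.query_router import formate_messages_router
from utils import call_llm

# Build the router messages and request a JSON-object response from the model.
messages = formate_messages_router("Koliko dana godisnjeg odmora imam?")  # placeholder query
response = call_llm(
    model="gpt-4o",    # placeholder; the app reads config.openai.router.model from config.yaml
    temperature=0,     # placeholder; the app uses config.openai.router.temperature
    messages=messages,
    json_response=True,
)

# ROUTER_PROMPT instructs the model to answer as {"response": ["ime_zakona", ...]}.
collections = json.loads(response.choices[0].message.content)["response"]
print(collections)  # e.g. ["zakon_o_radu"] or ["nema_zakona"]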