Skip to content

Commit

Permalink
Chat UI (#50)
Browse files Browse the repository at this point in the history
* Update app, add sidebar, relocate methods to utils

* Update poetry files.

* Add const variable and move methods from app.

* Made minimal requirements for deployment.

* Add comments.

* Change method names.
  • Loading branch information
milistu authored Jun 1, 2024
1 parent 1eba628 commit 82ce077
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 256 deletions.
134 changes: 44 additions & 90 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,119 +1,73 @@
import json
import os
from pathlib import Path

import streamlit as st
import yaml
from dotenv import find_dotenv, load_dotenv
from langfuse.decorators import observe
from loguru import logger
from openai import OpenAI
from qdrant_client import QdrantClient

from database.utils import embed_text, get_context, search
from llm.prompts import DEFAULT_CONTEXT, INTRODUCTION_MESSAGE
from llm.utils import get_answer, get_messages
from router.query_router import semantic_query_router
from router.router_prompt import DEFAULT_ROUTER_RESPONSE, ROUTER_PROMPT

load_dotenv(find_dotenv())

st.title("Legal ChatBot")

qdrant_client = QdrantClient(
url=os.environ["QDRANT_CLUSTER_URL"],
api_key=os.environ["QDRANT_API_KEY"],
from llm.prompts import INTRODUCTION_MESSAGE
from utils import (
generate_response,
initialize_clients,
load_config,
WARNING_MESSAGE,
QUERY_SUGGESTIONS,
AUTHORS,
LOGO_URL,
)

openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

config_path = Path("./config.yaml")
centroid_path = Path("./router/collection_centroids.json")

with config_path.open("r") as file:
config = yaml.safe_load(file)

# Delete this
with open(centroid_path, "r", encoding="utf-8") as file:
centroids = json.loads(file.read())


@observe()
def response_generator(query: str):
st.session_state.messages = st.session_state.messages[
-1 * config["openai"]["gpt_model"]["max_conversation"] :
]

# st.session_state.messages.append({"role": "user", "content": query})
# Load environment variables from the .env file.
load_dotenv(find_dotenv())

embedding_response = embed_text(
client=openai_client,
text=query,
model=config["openai"]["embedding_model"]["name"],
)
embedding = embedding_response.data[0].embedding

# Route query
collections = semantic_query_router(
client=openai_client,
model=config["openai"]["gpt_model"]["router"],
query=query,
prompt=ROUTER_PROMPT,
temperature=config["openai"]["gpt_model"]["temperature"],
)
logger.info(f"Query routed to collections: {collections}")
# Set Streamlit page configuration with custom title and icon.
st.set_page_config(page_title="Your Lawyer Assistant", page_icon=LOGO_URL)
st.title("LegaBot")
st.divider()

if collections[0] == DEFAULT_ROUTER_RESPONSE:
context = DEFAULT_CONTEXT
else:
search_results = []
for collection_name in collections:
search_results.extend(
search(
client=qdrant_client,
collection=collection_name,
query_vector=embedding,
limit=10,
with_vectors=True,
)
)
# Initialize API clients for OpenAI and Qdrant and load configuration settings.
openai_client, qdrant_client = initialize_clients()
config = load_config()

top_k = 15 if len(collections) > 1 else 10
context = get_context(search_results=search_results, top_k=top_k)
# Display the logo and set up the sidebar with useful information and links.
st.logo(LOGO_URL, icon_image=LOGO_URL)
with st.sidebar:
st.subheader("💡 Query Suggestions")
with st.container(border=True, height=200):
st.markdown(QUERY_SUGGESTIONS)

stream = get_answer(
client=openai_client,
model=config["openai"]["gpt_model"]["llm"],
temperature=config["openai"]["gpt_model"]["temperature"],
messages=get_messages(
context=context, query=query, conversation=st.session_state.messages
),
stream=True,
)
st.subheader("⚠️ Warning")
with st.container(border=True):
st.markdown(WARNING_MESSAGE)

for chunk in stream:
part = chunk.choices[0].delta.content
if part is not None:
yield part
st.subheader("✍️ Authors")
st.markdown(AUTHORS)


# Initialize or update the session state for storing chat messages.
if "messages" not in st.session_state:
st.session_state.messages = [{"role": "assistant", "content": INTRODUCTION_MESSAGE}]

# Display all chat messages stored in the session state.
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])


# Handle user input and generate responses.
if prompt := st.chat_input("Postavi pitanje iz prava..."):
# Generate and display the response
# Append user message to session state.
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message in chat message container

# Display user message in chat container.
with st.chat_message("user"):
st.markdown(prompt)

with st.chat_message("assistant"):
stream = response_generator(prompt)
# Generate a response using the LLM and display it as a stream.
stream = generate_response(
query=prompt,
openai_client=openai_client,
qdrant_client=qdrant_client,
config=config,
)
# Write the response stream to the chat.
response = st.write_stream(stream)

# Append assistant's response to session state.
st.session_state.messages.append({"role": "assistant", "content": response})
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.poetry]
name = "sais-hackathon"
name = "LegaBot"
version = "0.1.0"
description = ""
authors = ["Milutin-S <studen.m98@gmail.com>"]
Expand Down
166 changes: 9 additions & 157 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,157 +1,9 @@
aiohttp==3.9.3 ; python_version >= "3.10" and python_version < "4.0"
aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "4.0"
altair==5.3.0 ; python_version >= "3.10" and python_version < "4.0"
annotated-types==0.6.0 ; python_version >= "3.10" and python_version < "4.0"
anyio==4.3.0 ; python_version >= "3.10" and python_version < "4.0"
appnope==0.1.4 ; python_version >= "3.10" and python_version < "4.0" and platform_system == "Darwin"
argon2-cffi-bindings==21.2.0 ; python_version >= "3.10" and python_version < "4.0"
argon2-cffi==23.1.0 ; python_version >= "3.10" and python_version < "4.0"
arrow==1.3.0 ; python_version >= "3.10" and python_version < "4.0"
asttokens==2.4.1 ; python_version >= "3.10" and python_version < "4.0"
async-lru==2.0.4 ; python_version >= "3.10" and python_version < "4.0"
async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11"
attrs==23.2.0 ; python_version >= "3.10" and python_version < "4.0"
babel==2.14.0 ; python_version >= "3.10" and python_version < "4.0"
backoff==2.2.1 ; python_version >= "3.10" and python_version < "4.0"
beautifulsoup4==4.12.3 ; python_version >= "3.10" and python_version < "4.0"
bleach==6.1.0 ; python_version >= "3.10" and python_version < "4.0"
blinker==1.7.0 ; python_version >= "3.10" and python_version < "4.0"
cachetools==5.3.3 ; python_version >= "3.10" and python_version < "4.0"
certifi==2024.2.2 ; python_version >= "3.10" and python_version < "4.0"
cffi==1.16.0 ; python_version >= "3.10" and python_version < "4.0"
charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "4.0"
click==8.1.7 ; python_version >= "3.10" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
comm==0.2.2 ; python_version >= "3.10" and python_version < "4.0"
debugpy==1.8.1 ; python_version >= "3.10" and python_version < "4.0"
decorator==5.1.1 ; python_version >= "3.10" and python_version < "4.0"
defusedxml==0.7.1 ; python_version >= "3.10" and python_version < "4.0"
distro==1.9.0 ; python_version >= "3.10" and python_version < "4.0"
exceptiongroup==1.2.0 ; python_version >= "3.10" and python_version < "3.11"
executing==2.0.1 ; python_version >= "3.10" and python_version < "4.0"
fastjsonschema==2.19.1 ; python_version >= "3.10" and python_version < "4.0"
fqdn==1.5.1 ; python_version >= "3.10" and python_version < "4"
frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "4.0"
gitdb==4.0.11 ; python_version >= "3.10" and python_version < "4.0"
gitpython==3.1.43 ; python_version >= "3.10" and python_version < "4.0"
grpcio-tools==1.62.1 ; python_version >= "3.10" and python_version < "4.0"
grpcio==1.62.1 ; python_version >= "3.10" and python_version < "4.0"
h11==0.14.0 ; python_version >= "3.10" and python_version < "4.0"
h2==4.1.0 ; python_version >= "3.10" and python_version < "4.0"
hpack==4.0.0 ; python_version >= "3.10" and python_version < "4.0"
httpcore==1.0.4 ; python_version >= "3.10" and python_version < "4.0"
httpx==0.27.0 ; python_version >= "3.10" and python_version < "4.0"
httpx[http2]==0.27.0 ; python_version >= "3.10" and python_version < "4.0"
hyperframe==6.0.1 ; python_version >= "3.10" and python_version < "4.0"
idna==3.7 ; python_version >= "3.10" and python_version < "4.0"
ipykernel==6.29.3 ; python_version >= "3.10" and python_version < "4.0"
ipython==8.22.2 ; python_version >= "3.10" and python_version < "4.0"
ipywidgets==8.1.2 ; python_version >= "3.10" and python_version < "4.0"
isoduration==20.11.0 ; python_version >= "3.10" and python_version < "4.0"
jedi==0.19.1 ; python_version >= "3.10" and python_version < "4.0"
jinja2==3.1.3 ; python_version >= "3.10" and python_version < "4.0"
json5==0.9.24 ; python_version >= "3.10" and python_version < "4.0"
jsonpointer==2.4 ; python_version >= "3.10" and python_version < "4.0"
jsonschema-specifications==2023.12.1 ; python_version >= "3.10" and python_version < "4.0"
jsonschema==4.21.1 ; python_version >= "3.10" and python_version < "4.0"
jsonschema[format-nongpl]==4.21.1 ; python_version >= "3.10" and python_version < "4.0"
jupyter-client==8.6.1 ; python_version >= "3.10" and python_version < "4.0"
jupyter-console==6.6.3 ; python_version >= "3.10" and python_version < "4.0"
jupyter-core==5.7.2 ; python_version >= "3.10" and python_version < "4.0"
jupyter-events==0.10.0 ; python_version >= "3.10" and python_version < "4.0"
jupyter-lsp==2.2.4 ; python_version >= "3.10" and python_version < "4.0"
jupyter-server-terminals==0.5.3 ; python_version >= "3.10" and python_version < "4.0"
jupyter-server==2.13.0 ; python_version >= "3.10" and python_version < "4.0"
jupyter==1.0.0 ; python_version >= "3.10" and python_version < "4.0"
jupyterlab-pygments==0.3.0 ; python_version >= "3.10" and python_version < "4.0"
jupyterlab-server==2.25.4 ; python_version >= "3.10" and python_version < "4.0"
jupyterlab-widgets==3.0.10 ; python_version >= "3.10" and python_version < "4.0"
jupyterlab==4.1.5 ; python_version >= "3.10" and python_version < "4.0"
langfuse==2.27.2 ; python_version >= "3.10" and python_version < "4.0"
loguru==0.7.2 ; python_version >= "3.10" and python_version < "4.0"
lxml==5.1.0 ; python_version >= "3.10" and python_version < "4.0"
markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "4.0"
markupsafe==2.1.5 ; python_version >= "3.10" and python_version < "4.0"
matplotlib-inline==0.1.6 ; python_version >= "3.10" and python_version < "4.0"
mdurl==0.1.2 ; python_version >= "3.10" and python_version < "4.0"
mistune==3.0.2 ; python_version >= "3.10" and python_version < "4.0"
multidict==6.0.5 ; python_version >= "3.10" and python_version < "4.0"
nbclient==0.10.0 ; python_version >= "3.10" and python_version < "4.0"
nbconvert==7.16.2 ; python_version >= "3.10" and python_version < "4.0"
nbformat==5.10.3 ; python_version >= "3.10" and python_version < "4.0"
nest-asyncio==1.6.0 ; python_version >= "3.10" and python_version < "4.0"
notebook-shim==0.2.4 ; python_version >= "3.10" and python_version < "4.0"
notebook==7.1.2 ; python_version >= "3.10" and python_version < "4.0"
numpy==1.26.4 ; python_version >= "3.10" and python_version < "4.0"
openai==1.14.2 ; python_version >= "3.10" and python_version < "4.0"
overrides==7.7.0 ; python_version >= "3.10" and python_version < "4.0"
packaging==23.2 ; python_version >= "3.10" and python_version < "4.0"
pandas==2.2.0 ; python_version >= "3.10" and python_version < "4.0"
pandocfilters==1.5.1 ; python_version >= "3.10" and python_version < "4.0"
parso==0.8.3 ; python_version >= "3.10" and python_version < "4.0"
pexpect==4.9.0 ; python_version >= "3.10" and python_version < "4.0" and (sys_platform != "win32" and sys_platform != "emscripten")
pillow==10.2.0 ; python_version >= "3.10" and python_version < "4.0"
platformdirs==4.2.0 ; python_version >= "3.10" and python_version < "4.0"
portalocker==2.8.2 ; python_version >= "3.10" and python_version < "4.0"
prometheus-client==0.20.0 ; python_version >= "3.10" and python_version < "4.0"
prompt-toolkit==3.0.43 ; python_version >= "3.10" and python_version < "4.0"
protobuf==4.25.3 ; python_version >= "3.10" and python_version < "4.0"
psutil==5.9.8 ; python_version >= "3.10" and python_version < "4.0"
ptyprocess==0.7.0 ; python_version >= "3.10" and python_version < "4.0" and (sys_platform != "win32" and sys_platform != "emscripten" or os_name != "nt")
pure-eval==0.2.2 ; python_version >= "3.10" and python_version < "4.0"
pyarrow==15.0.2 ; python_version >= "3.10" and python_version < "4.0"
pycparser==2.21 ; python_version >= "3.10" and python_version < "4.0"
pydantic-core==2.16.3 ; python_version >= "3.10" and python_version < "4.0"
pydantic==2.6.4 ; python_version >= "3.10" and python_version < "4.0"
pydeck==0.8.0 ; python_version >= "3.10" and python_version < "4.0"
pygments==2.17.2 ; python_version >= "3.10" and python_version < "4.0"
python-dateutil==2.9.0.post0 ; python_version >= "3.10" and python_version < "4.0"
python-dotenv==1.0.1 ; python_version >= "3.10" and python_version < "4.0"
python-json-logger==2.0.7 ; python_version >= "3.10" and python_version < "4.0"
pytz==2023.3.post1 ; python_version >= "3.10" and python_version < "4.0"
pywin32==306 ; python_version >= "3.10" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32") and (platform_system == "Windows" or platform_python_implementation != "PyPy")
pywinpty==2.0.13 ; python_version >= "3.10" and python_version < "4.0" and os_name == "nt"
pyyaml==6.0.1 ; python_version >= "3.10" and python_version < "4.0"
pyzmq==25.1.2 ; python_version >= "3.10" and python_version < "4.0"
qdrant-client==1.8.0 ; python_version >= "3.10" and python_version < "4.0"
qtconsole==5.5.1 ; python_version >= "3.10" and python_version < "4.0"
qtpy==2.4.1 ; python_version >= "3.10" and python_version < "4.0"
referencing==0.34.0 ; python_version >= "3.10" and python_version < "4.0"
regex==2023.12.25 ; python_version >= "3.10" and python_version < "4.0"
requests==2.31.0 ; python_version >= "3.10" and python_version < "4.0"
rfc3339-validator==0.1.4 ; python_version >= "3.10" and python_version < "4.0"
rfc3986-validator==0.1.1 ; python_version >= "3.10" and python_version < "4.0"
rich==13.7.1 ; python_version >= "3.10" and python_version < "4.0"
rpds-py==0.18.0 ; python_version >= "3.10" and python_version < "4.0"
send2trash==1.8.2 ; python_version >= "3.10" and python_version < "4.0"
setuptools==69.2.0 ; python_version >= "3.10" and python_version < "4.0"
six==1.16.0 ; python_version >= "3.10" and python_version < "4.0"
smmap==5.0.1 ; python_version >= "3.10" and python_version < "4.0"
sniffio==1.3.1 ; python_version >= "3.10" and python_version < "4.0"
soupsieve==2.5 ; python_version >= "3.10" and python_version < "4.0"
stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0"
streamlit==1.32.2 ; python_version >= "3.10" and python_version < "4.0"
tenacity==8.2.3 ; python_version >= "3.10" and python_version < "4.0"
terminado==0.18.1 ; python_version >= "3.10" and python_version < "4.0"
tiktoken==0.6.0 ; python_version >= "3.10" and python_version < "4.0"
tinycss2==1.2.1 ; python_version >= "3.10" and python_version < "4.0"
toml==0.10.2 ; python_version >= "3.10" and python_version < "4.0"
tomli==2.0.1 ; python_version >= "3.10" and python_version < "3.11"
toolz==0.12.1 ; python_version >= "3.10" and python_version < "4.0"
tornado==6.4 ; python_version >= "3.10" and python_version < "4.0"
tqdm==4.66.2 ; python_version >= "3.10" and python_version < "4.0"
traitlets==5.14.2 ; python_version >= "3.10" and python_version < "4.0"
types-python-dateutil==2.9.0.20240316 ; python_version >= "3.10" and python_version < "4.0"
typing-extensions==4.10.0 ; python_version >= "3.10" and python_version < "4.0"
tzdata==2024.1 ; python_version >= "3.10" and python_version < "4.0"
uri-template==1.3.0 ; python_version >= "3.10" and python_version < "4.0"
urllib3==2.2.1 ; python_version >= "3.10" and python_version < "4.0"
watchdog==4.0.0 ; python_version >= "3.10" and python_version < "4.0" and platform_system != "Darwin"
wcwidth==0.2.13 ; python_version >= "3.10" and python_version < "4.0"
webcolors==1.13 ; python_version >= "3.10" and python_version < "4.0"
webencodings==0.5.1 ; python_version >= "3.10" and python_version < "4.0"
websocket-client==1.7.0 ; python_version >= "3.10" and python_version < "4.0"
widgetsnbextension==4.0.10 ; python_version >= "3.10" and python_version < "4.0"
win32-setctime==1.1.0 ; python_version >= "3.10" and python_version < "4.0" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.10" and python_version < "4.0"
yarl==1.9.4 ; python_version >= "3.10" and python_version < "4.0"
qdrant-client==1.8.0
python-dotenv==1.0.1
tiktoken==0.6.0
openai==1.14.2
loguru==0.7.2
pyyaml==6.0.1
streamlit==1.32.2
backoff==2.2.1
langfuse==2.27.2
4 changes: 2 additions & 2 deletions tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@

from router.query_router import semantic_query_router
from router.router_prompt import ROUTER_PROMPT
from utils import load_config_from_yaml
from utils import load_config


class RouterTest(unittest.TestCase):

def setUp(self) -> None:
# Load configuration
config_path = Path("./config.yaml")
self.config = load_config_from_yaml(yaml_file_path=config_path)
self.config = load_config(yaml_file_path=config_path)

# Initialize OpenAI client
self.openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
Expand Down
Loading

0 comments on commit 82ce077

Please sign in to comment.