-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update app, add sidebar, relovace methods to utils * Update poetry files. * Add const variable and move methods from app. * Made minimal requirements for deployment. * Add comments. * Change method names.
- Loading branch information
Showing
6 changed files
with
214 additions
and
256 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,119 +1,73 @@ | ||
import json | ||
import os | ||
from pathlib import Path | ||
|
||
import streamlit as st | ||
import yaml | ||
from dotenv import find_dotenv, load_dotenv | ||
from langfuse.decorators import observe | ||
from loguru import logger | ||
from openai import OpenAI | ||
from qdrant_client import QdrantClient | ||
|
||
from database.utils import embed_text, get_context, search | ||
from llm.prompts import DEFAULT_CONTEXT, INTRODUCTION_MESSAGE | ||
from llm.utils import get_answer, get_messages | ||
from router.query_router import semantic_query_router | ||
from router.router_prompt import DEFAULT_ROUTER_RESPONSE, ROUTER_PROMPT | ||
|
||
load_dotenv(find_dotenv()) | ||
|
||
st.title("Legal ChatBot") | ||
|
||
qdrant_client = QdrantClient( | ||
url=os.environ["QDRANT_CLUSTER_URL"], | ||
api_key=os.environ["QDRANT_API_KEY"], | ||
from llm.prompts import INTRODUCTION_MESSAGE | ||
from utils import ( | ||
generate_response, | ||
initialize_clients, | ||
load_config, | ||
WARNING_MESSAGE, | ||
QUERY_SUGGESTIONS, | ||
AUTHORS, | ||
LOGO_URL, | ||
) | ||
|
||
openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) | ||
|
||
config_path = Path("./config.yaml") | ||
centroid_path = Path("./router/collection_centroids.json") | ||
|
||
with config_path.open("r") as file: | ||
config = yaml.safe_load(file) | ||
|
||
# Delete this | ||
with open(centroid_path, "r", encoding="utf-8") as file: | ||
centroids = json.loads(file.read()) | ||
|
||
|
||
@observe() | ||
def response_generator(query: str): | ||
st.session_state.messages = st.session_state.messages[ | ||
-1 * config["openai"]["gpt_model"]["max_conversation"] : | ||
] | ||
|
||
# st.session_state.messages.append({"role": "user", "content": query}) | ||
# Load environment variables from the .env file. | ||
load_dotenv(find_dotenv()) | ||
|
||
embedding_response = embed_text( | ||
client=openai_client, | ||
text=query, | ||
model=config["openai"]["embedding_model"]["name"], | ||
) | ||
embedding = embedding_response.data[0].embedding | ||
|
||
# Rout query | ||
collections = semantic_query_router( | ||
client=openai_client, | ||
model=config["openai"]["gpt_model"]["router"], | ||
query=query, | ||
prompt=ROUTER_PROMPT, | ||
temperature=config["openai"]["gpt_model"]["temperature"], | ||
) | ||
logger.info(f"Query routed to collections: {collections}") | ||
# Set Streamlit page configuration with custom title and icon. | ||
st.set_page_config(page_title="Your Lawyer Assistant", page_icon=LOGO_URL) | ||
st.title("LegaBot") | ||
st.divider() | ||
|
||
if collections[0] == DEFAULT_ROUTER_RESPONSE: | ||
context = DEFAULT_CONTEXT | ||
else: | ||
search_results = [] | ||
for collection_name in collections: | ||
search_results.extend( | ||
search( | ||
client=qdrant_client, | ||
collection=collection_name, | ||
query_vector=embedding, | ||
limit=10, | ||
with_vectors=True, | ||
) | ||
) | ||
# Initialize API clients for OpenAI and Qdrant and load configuration settings. | ||
openai_client, qdrant_client = initialize_clients() | ||
config = load_config() | ||
|
||
top_k = 15 if len(collections) > 1 else 10 | ||
context = get_context(search_results=search_results, top_k=top_k) | ||
# Display the logo and set up the sidebar with useful information and links. | ||
st.logo(LOGO_URL, icon_image=LOGO_URL) | ||
with st.sidebar: | ||
st.subheader("💡 Query Suggestions") | ||
with st.container(border=True, height=200): | ||
st.markdown(QUERY_SUGGESTIONS) | ||
|
||
stream = get_answer( | ||
client=openai_client, | ||
model=config["openai"]["gpt_model"]["llm"], | ||
temperature=config["openai"]["gpt_model"]["temperature"], | ||
messages=get_messages( | ||
context=context, query=query, conversation=st.session_state.messages | ||
), | ||
stream=True, | ||
) | ||
st.subheader("⚠️ Warning") | ||
with st.container(border=True): | ||
st.markdown(WARNING_MESSAGE) | ||
|
||
for chunk in stream: | ||
part = chunk.choices[0].delta.content | ||
if part is not None: | ||
yield part | ||
st.subheader("✍️ Authors") | ||
st.markdown(AUTHORS) | ||
|
||
|
||
# Initialize or update the session state for storing chat messages. | ||
if "messages" not in st.session_state: | ||
st.session_state.messages = [{"role": "assistant", "content": INTRODUCTION_MESSAGE}] | ||
|
||
# Display all chat messages stored in the session state. | ||
for message in st.session_state.messages: | ||
with st.chat_message(message["role"]): | ||
st.markdown(message["content"]) | ||
|
||
|
||
# Handle user input and generate responses. | ||
if prompt := st.chat_input("Postavi pitanje iz prava..."): | ||
# Generate and display the response | ||
# Append user message to session state. | ||
st.session_state.messages.append({"role": "user", "content": prompt}) | ||
# Display user message in chat message container | ||
|
||
# Display user message in chat container. | ||
with st.chat_message("user"): | ||
st.markdown(prompt) | ||
|
||
with st.chat_message("assistant"): | ||
stream = response_generator(prompt) | ||
# Generate a response using the LLM and display it as a stream. | ||
stream = generate_response( | ||
query=prompt, | ||
openai_client=openai_client, | ||
qdrant_client=qdrant_client, | ||
config=config, | ||
) | ||
# Write the response stream to the chat. | ||
response = st.write_stream(stream) | ||
|
||
# Append assistant's response to session state. | ||
st.session_state.messages.append({"role": "assistant", "content": response}) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,157 +1,9 @@ | ||
aiohttp==3.9.3 ; python_version >= "3.10" and python_version < "4.0" | ||
aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "4.0" | ||
altair==5.3.0 ; python_version >= "3.10" and python_version < "4.0" | ||
annotated-types==0.6.0 ; python_version >= "3.10" and python_version < "4.0" | ||
anyio==4.3.0 ; python_version >= "3.10" and python_version < "4.0" | ||
appnope==0.1.4 ; python_version >= "3.10" and python_version < "4.0" and platform_system == "Darwin" | ||
argon2-cffi-bindings==21.2.0 ; python_version >= "3.10" and python_version < "4.0" | ||
argon2-cffi==23.1.0 ; python_version >= "3.10" and python_version < "4.0" | ||
arrow==1.3.0 ; python_version >= "3.10" and python_version < "4.0" | ||
asttokens==2.4.1 ; python_version >= "3.10" and python_version < "4.0" | ||
async-lru==2.0.4 ; python_version >= "3.10" and python_version < "4.0" | ||
async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11" | ||
attrs==23.2.0 ; python_version >= "3.10" and python_version < "4.0" | ||
babel==2.14.0 ; python_version >= "3.10" and python_version < "4.0" | ||
backoff==2.2.1 ; python_version >= "3.10" and python_version < "4.0" | ||
beautifulsoup4==4.12.3 ; python_version >= "3.10" and python_version < "4.0" | ||
bleach==6.1.0 ; python_version >= "3.10" and python_version < "4.0" | ||
blinker==1.7.0 ; python_version >= "3.10" and python_version < "4.0" | ||
cachetools==5.3.3 ; python_version >= "3.10" and python_version < "4.0" | ||
certifi==2024.2.2 ; python_version >= "3.10" and python_version < "4.0" | ||
cffi==1.16.0 ; python_version >= "3.10" and python_version < "4.0" | ||
charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "4.0" | ||
click==8.1.7 ; python_version >= "3.10" and python_version < "4.0" | ||
colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32") | ||
comm==0.2.2 ; python_version >= "3.10" and python_version < "4.0" | ||
debugpy==1.8.1 ; python_version >= "3.10" and python_version < "4.0" | ||
decorator==5.1.1 ; python_version >= "3.10" and python_version < "4.0" | ||
defusedxml==0.7.1 ; python_version >= "3.10" and python_version < "4.0" | ||
distro==1.9.0 ; python_version >= "3.10" and python_version < "4.0" | ||
exceptiongroup==1.2.0 ; python_version >= "3.10" and python_version < "3.11" | ||
executing==2.0.1 ; python_version >= "3.10" and python_version < "4.0" | ||
fastjsonschema==2.19.1 ; python_version >= "3.10" and python_version < "4.0" | ||
fqdn==1.5.1 ; python_version >= "3.10" and python_version < "4" | ||
frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "4.0" | ||
gitdb==4.0.11 ; python_version >= "3.10" and python_version < "4.0" | ||
gitpython==3.1.43 ; python_version >= "3.10" and python_version < "4.0" | ||
grpcio-tools==1.62.1 ; python_version >= "3.10" and python_version < "4.0" | ||
grpcio==1.62.1 ; python_version >= "3.10" and python_version < "4.0" | ||
h11==0.14.0 ; python_version >= "3.10" and python_version < "4.0" | ||
h2==4.1.0 ; python_version >= "3.10" and python_version < "4.0" | ||
hpack==4.0.0 ; python_version >= "3.10" and python_version < "4.0" | ||
httpcore==1.0.4 ; python_version >= "3.10" and python_version < "4.0" | ||
httpx==0.27.0 ; python_version >= "3.10" and python_version < "4.0" | ||
httpx[http2]==0.27.0 ; python_version >= "3.10" and python_version < "4.0" | ||
hyperframe==6.0.1 ; python_version >= "3.10" and python_version < "4.0" | ||
idna==3.7 ; python_version >= "3.10" and python_version < "4.0" | ||
ipykernel==6.29.3 ; python_version >= "3.10" and python_version < "4.0" | ||
ipython==8.22.2 ; python_version >= "3.10" and python_version < "4.0" | ||
ipywidgets==8.1.2 ; python_version >= "3.10" and python_version < "4.0" | ||
isoduration==20.11.0 ; python_version >= "3.10" and python_version < "4.0" | ||
jedi==0.19.1 ; python_version >= "3.10" and python_version < "4.0" | ||
jinja2==3.1.3 ; python_version >= "3.10" and python_version < "4.0" | ||
json5==0.9.24 ; python_version >= "3.10" and python_version < "4.0" | ||
jsonpointer==2.4 ; python_version >= "3.10" and python_version < "4.0" | ||
jsonschema-specifications==2023.12.1 ; python_version >= "3.10" and python_version < "4.0" | ||
jsonschema==4.21.1 ; python_version >= "3.10" and python_version < "4.0" | ||
jsonschema[format-nongpl]==4.21.1 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyter-client==8.6.1 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyter-console==6.6.3 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyter-core==5.7.2 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyter-events==0.10.0 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyter-lsp==2.2.4 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyter-server-terminals==0.5.3 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyter-server==2.13.0 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyter==1.0.0 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyterlab-pygments==0.3.0 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyterlab-server==2.25.4 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyterlab-widgets==3.0.10 ; python_version >= "3.10" and python_version < "4.0" | ||
jupyterlab==4.1.5 ; python_version >= "3.10" and python_version < "4.0" | ||
langfuse==2.27.2 ; python_version >= "3.10" and python_version < "4.0" | ||
loguru==0.7.2 ; python_version >= "3.10" and python_version < "4.0" | ||
lxml==5.1.0 ; python_version >= "3.10" and python_version < "4.0" | ||
markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "4.0" | ||
markupsafe==2.1.5 ; python_version >= "3.10" and python_version < "4.0" | ||
matplotlib-inline==0.1.6 ; python_version >= "3.10" and python_version < "4.0" | ||
mdurl==0.1.2 ; python_version >= "3.10" and python_version < "4.0" | ||
mistune==3.0.2 ; python_version >= "3.10" and python_version < "4.0" | ||
multidict==6.0.5 ; python_version >= "3.10" and python_version < "4.0" | ||
nbclient==0.10.0 ; python_version >= "3.10" and python_version < "4.0" | ||
nbconvert==7.16.2 ; python_version >= "3.10" and python_version < "4.0" | ||
nbformat==5.10.3 ; python_version >= "3.10" and python_version < "4.0" | ||
nest-asyncio==1.6.0 ; python_version >= "3.10" and python_version < "4.0" | ||
notebook-shim==0.2.4 ; python_version >= "3.10" and python_version < "4.0" | ||
notebook==7.1.2 ; python_version >= "3.10" and python_version < "4.0" | ||
numpy==1.26.4 ; python_version >= "3.10" and python_version < "4.0" | ||
openai==1.14.2 ; python_version >= "3.10" and python_version < "4.0" | ||
overrides==7.7.0 ; python_version >= "3.10" and python_version < "4.0" | ||
packaging==23.2 ; python_version >= "3.10" and python_version < "4.0" | ||
pandas==2.2.0 ; python_version >= "3.10" and python_version < "4.0" | ||
pandocfilters==1.5.1 ; python_version >= "3.10" and python_version < "4.0" | ||
parso==0.8.3 ; python_version >= "3.10" and python_version < "4.0" | ||
pexpect==4.9.0 ; python_version >= "3.10" and python_version < "4.0" and (sys_platform != "win32" and sys_platform != "emscripten") | ||
pillow==10.2.0 ; python_version >= "3.10" and python_version < "4.0" | ||
platformdirs==4.2.0 ; python_version >= "3.10" and python_version < "4.0" | ||
portalocker==2.8.2 ; python_version >= "3.10" and python_version < "4.0" | ||
prometheus-client==0.20.0 ; python_version >= "3.10" and python_version < "4.0" | ||
prompt-toolkit==3.0.43 ; python_version >= "3.10" and python_version < "4.0" | ||
protobuf==4.25.3 ; python_version >= "3.10" and python_version < "4.0" | ||
psutil==5.9.8 ; python_version >= "3.10" and python_version < "4.0" | ||
ptyprocess==0.7.0 ; python_version >= "3.10" and python_version < "4.0" and (sys_platform != "win32" and sys_platform != "emscripten" or os_name != "nt") | ||
pure-eval==0.2.2 ; python_version >= "3.10" and python_version < "4.0" | ||
pyarrow==15.0.2 ; python_version >= "3.10" and python_version < "4.0" | ||
pycparser==2.21 ; python_version >= "3.10" and python_version < "4.0" | ||
pydantic-core==2.16.3 ; python_version >= "3.10" and python_version < "4.0" | ||
pydantic==2.6.4 ; python_version >= "3.10" and python_version < "4.0" | ||
pydeck==0.8.0 ; python_version >= "3.10" and python_version < "4.0" | ||
pygments==2.17.2 ; python_version >= "3.10" and python_version < "4.0" | ||
python-dateutil==2.9.0.post0 ; python_version >= "3.10" and python_version < "4.0" | ||
python-dotenv==1.0.1 ; python_version >= "3.10" and python_version < "4.0" | ||
python-json-logger==2.0.7 ; python_version >= "3.10" and python_version < "4.0" | ||
pytz==2023.3.post1 ; python_version >= "3.10" and python_version < "4.0" | ||
pywin32==306 ; python_version >= "3.10" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32") and (platform_system == "Windows" or platform_python_implementation != "PyPy") | ||
pywinpty==2.0.13 ; python_version >= "3.10" and python_version < "4.0" and os_name == "nt" | ||
pyyaml==6.0.1 ; python_version >= "3.10" and python_version < "4.0" | ||
pyzmq==25.1.2 ; python_version >= "3.10" and python_version < "4.0" | ||
qdrant-client==1.8.0 ; python_version >= "3.10" and python_version < "4.0" | ||
qtconsole==5.5.1 ; python_version >= "3.10" and python_version < "4.0" | ||
qtpy==2.4.1 ; python_version >= "3.10" and python_version < "4.0" | ||
referencing==0.34.0 ; python_version >= "3.10" and python_version < "4.0" | ||
regex==2023.12.25 ; python_version >= "3.10" and python_version < "4.0" | ||
requests==2.31.0 ; python_version >= "3.10" and python_version < "4.0" | ||
rfc3339-validator==0.1.4 ; python_version >= "3.10" and python_version < "4.0" | ||
rfc3986-validator==0.1.1 ; python_version >= "3.10" and python_version < "4.0" | ||
rich==13.7.1 ; python_version >= "3.10" and python_version < "4.0" | ||
rpds-py==0.18.0 ; python_version >= "3.10" and python_version < "4.0" | ||
send2trash==1.8.2 ; python_version >= "3.10" and python_version < "4.0" | ||
setuptools==69.2.0 ; python_version >= "3.10" and python_version < "4.0" | ||
six==1.16.0 ; python_version >= "3.10" and python_version < "4.0" | ||
smmap==5.0.1 ; python_version >= "3.10" and python_version < "4.0" | ||
sniffio==1.3.1 ; python_version >= "3.10" and python_version < "4.0" | ||
soupsieve==2.5 ; python_version >= "3.10" and python_version < "4.0" | ||
stack-data==0.6.3 ; python_version >= "3.10" and python_version < "4.0" | ||
streamlit==1.32.2 ; python_version >= "3.10" and python_version < "4.0" | ||
tenacity==8.2.3 ; python_version >= "3.10" and python_version < "4.0" | ||
terminado==0.18.1 ; python_version >= "3.10" and python_version < "4.0" | ||
tiktoken==0.6.0 ; python_version >= "3.10" and python_version < "4.0" | ||
tinycss2==1.2.1 ; python_version >= "3.10" and python_version < "4.0" | ||
toml==0.10.2 ; python_version >= "3.10" and python_version < "4.0" | ||
tomli==2.0.1 ; python_version >= "3.10" and python_version < "3.11" | ||
toolz==0.12.1 ; python_version >= "3.10" and python_version < "4.0" | ||
tornado==6.4 ; python_version >= "3.10" and python_version < "4.0" | ||
tqdm==4.66.2 ; python_version >= "3.10" and python_version < "4.0" | ||
traitlets==5.14.2 ; python_version >= "3.10" and python_version < "4.0" | ||
types-python-dateutil==2.9.0.20240316 ; python_version >= "3.10" and python_version < "4.0" | ||
typing-extensions==4.10.0 ; python_version >= "3.10" and python_version < "4.0" | ||
tzdata==2024.1 ; python_version >= "3.10" and python_version < "4.0" | ||
uri-template==1.3.0 ; python_version >= "3.10" and python_version < "4.0" | ||
urllib3==2.2.1 ; python_version >= "3.10" and python_version < "4.0" | ||
watchdog==4.0.0 ; python_version >= "3.10" and python_version < "4.0" and platform_system != "Darwin" | ||
wcwidth==0.2.13 ; python_version >= "3.10" and python_version < "4.0" | ||
webcolors==1.13 ; python_version >= "3.10" and python_version < "4.0" | ||
webencodings==0.5.1 ; python_version >= "3.10" and python_version < "4.0" | ||
websocket-client==1.7.0 ; python_version >= "3.10" and python_version < "4.0" | ||
widgetsnbextension==4.0.10 ; python_version >= "3.10" and python_version < "4.0" | ||
win32-setctime==1.1.0 ; python_version >= "3.10" and python_version < "4.0" and sys_platform == "win32" | ||
wrapt==1.16.0 ; python_version >= "3.10" and python_version < "4.0" | ||
yarl==1.9.4 ; python_version >= "3.10" and python_version < "4.0" | ||
qdrant-client==1.8.0 | ||
python-dotenv==1.0.1 | ||
tiktoken==0.6.0 | ||
openai==1.14.2 | ||
loguru==0.7.2 | ||
pyyaml==6.0.1 | ||
streamlit==1.32.2 | ||
backoff==2.2.1 | ||
langfuse==2.27.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.