diff --git a/.gitignore b/.gitignore index df41a616..65a3b108 100755 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,8 @@ .envrc .direnv -.streamlit/secrets.toml +*/.streamlit/secrets.toml .mypy_cache +__pycache__ node_modules build index.json diff --git a/README.md b/README.md index 674defb4..1bac931c 100644 --- a/README.md +++ b/README.md @@ -34,14 +34,20 @@ There are two main example folders ![react_frontend screenshot](./flask_react/react_frontend.png) -- streamlit (runs on localhost:8501) +- streamlit_vector (runs on localhost:8501) - `streamlit run streamlit_demo.py` - - creates a simple UI using streamlit + - creates a simple UI using streamlit and GPTSimpleVectorIndex - loads text from the documents folder (using `st.cache_resource`, so it only loads once) - provides an input text-box and a button to run the query - the string response is displayed after it finishes - want to see this example in action? Check it out [here](https://llama-index.streamlit.app/) +- streamlit_sql_sandbox (runs on localhost:8501) + - `streamlit run streamlit_demo.py` + - creates a streamlit app using a local SQL database about restaurant inspections in San Francisco ([data sample](https://docs.google.com/spreadsheets/d/1Ag5DBIviYiuRrt2yr3nXmbPFV-FOg5fDH5SM3ZEDnpw/edit#gid=780513932)) + - The "Setup" tab allows you to configure various LLM and Llama Index settings + - The "Llama Index" tab demos some basic Text2SQL capabilities using only Llama Index + - The "Langchain+Llama Index" tab uses a custom langchain agent, with the SQL index from Llama Index available as a tool during conversations. ## Docker Each example contains a `Dockerfile`. You can run `docker build -t my_tag_name .` to build a python3.11-slim docker image inside your desired folder. It ends up being about 600MB-900MB depending on the example. diff --git a/flask_react/requirements.txt b/flask_react/requirements.txt index a81bb81d..fbc29046 100644 --- a/flask_react/requirements.txt +++ b/flask_react/requirements.txt @@ -1,5 +1,5 @@ Flask==2.2.3 Flask-Cors==3.0.10 -langchain==0.0.115 -llama-index==0.4.30 +langchain==0.0.123 +llama-index==0.4.39 PyPDF2==3.0.1 diff --git a/streamlit/requirements.txt b/streamlit/requirements.txt deleted file mode 100644 index 621be223..00000000 --- a/streamlit/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -langchain==0.0.115 -llama-index==0.4.30 -streamlit==1.19.0 diff --git a/streamlit/Dockerfile b/streamlit_sql_sandbox/Dockerfile similarity index 100% rename from streamlit/Dockerfile rename to streamlit_sql_sandbox/Dockerfile diff --git a/streamlit_sql_sandbox/constants.py b/streamlit_sql_sandbox/constants.py new file mode 100644 index 00000000..0381a3f8 --- /dev/null +++ b/streamlit_sql_sandbox/constants.py @@ -0,0 +1,24 @@ +DEFAULT_SQL_PATH = "sqlite:///sfscores.sqlite" +DEFAULT_BUSINESS_TABLE_DESCRP = ( + "This table gives information on the IDs, addresses, and other location " + "information for several restaurants in San Francisco. This table will " + "need to be referenced when users ask about specific businesses." +) +DEFAULT_VIOLATIONS_TABLE_DESCRP = ( + "This table gives information on which business IDs have recorded health violations, " + "including the date, risk, and description of each violation. The user may query " + "about specific businesses, whose names can be found by mapping the business_id " + "to the 'businesses' table." +) +DEFAULT_INSPECTIONS_TABLE_DESCRP = ( + "This table gives information on when each business ID was inspected, including " + "the score, date, and type of inspection. 
The user may query about specific " "businesses, whose names can be found by mapping the business_id to the 'businesses' table." ) +DEFAULT_LC_TOOL_DESCRP = "Useful for when you want to answer queries about violations and inspections of businesses." + +DEFAULT_INGEST_DOCUMENT = ( + "The restaurant KING-KONG had a routine unscheduled inspection on 2023/12/31. " + "The business achieved a score of 50. There were two violations: a high risk " + "vermin infestation, as well as high risk food holding temperatures." +) \ No newline at end of file diff --git a/streamlit_sql_sandbox/requirements.txt b/streamlit_sql_sandbox/requirements.txt new file mode 100644 index 00000000..70dccfc0 --- /dev/null +++ b/streamlit_sql_sandbox/requirements.txt @@ -0,0 +1,4 @@ +langchain==0.0.123 +llama-index==0.4.39 +streamlit==1.19.0 +streamlit-chat==0.0.2.2 diff --git a/streamlit_sql_sandbox/sfscores.sqlite b/streamlit_sql_sandbox/sfscores.sqlite new file mode 100755 index 00000000..2a7b15e1 Binary files /dev/null and b/streamlit_sql_sandbox/sfscores.sqlite differ diff --git a/streamlit_sql_sandbox/sql_index.json b/streamlit_sql_sandbox/sql_index.json new file mode 100644 index 00000000..3e8071d5 --- /dev/null +++ b/streamlit_sql_sandbox/sql_index.json @@ -0,0 +1 @@ +{"index_struct_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "docstore": {"docs": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"text": null, "doc_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "embedding": null, "doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4", "extra_info": null, "context_dict": {}, "__type__": "sql"}}, "ref_doc_info": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4"}}}, "sql_context_container": {"context_dict": {"violations": "Schema of table violations:\nTable 'violations' has columns: business_id (TEXT), date (TEXT), ViolationTypeID (TEXT), risk_category (TEXT), description (TEXT) and foreign keys: .\nContext of table violations:\nThis table gives information on which business IDs have recorded health violations, including the date, risk, and description of each violation. The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table.", "businesses": "Schema of table businesses:\nTable 'businesses' has columns: business_id (INTEGER), name (VARCHAR(64)), address (VARCHAR(50)), city (VARCHAR(23)), postal_code (VARCHAR(9)), latitude (FLOAT), longitude (FLOAT), phone_number (BIGINT), TaxCode (VARCHAR(4)), business_certificate (INTEGER), application_date (DATE), owner_name (VARCHAR(99)), owner_address (VARCHAR(74)), owner_city (VARCHAR(22)), owner_state (VARCHAR(14)), owner_zip (VARCHAR(15)) and foreign keys: .\nContext of table businesses:\nThis table gives information on the IDs, addresses, and other location information for several restaruants in San Fransisco. This table will need to be referenced when users ask about specific bussinesses.", "inspections": "Schema of table inspections:\nTable 'inspections' has columns: business_id (TEXT), Score (INTEGER), date (TEXT), type (VARCHAR(33)) and foreign keys: .\nContext of table inspections:\nThis table gives information on when each bussiness ID was inspected, including the score, date, and type of inspection. 
The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table."}, "context_str": null}} \ No newline at end of file diff --git a/streamlit_sql_sandbox/streamlit_demo.py b/streamlit_sql_sandbox/streamlit_demo.py new file mode 100644 index 00000000..23d4e8d4 --- /dev/null +++ b/streamlit_sql_sandbox/streamlit_demo.py @@ -0,0 +1,137 @@ +import os +import streamlit as st +from streamlit_chat import message as st_message +from sqlalchemy import create_engine + +from langchain.agents import Tool, initialize_agent +from langchain.chains.conversation.memory import ConversationBufferMemory + +from llama_index import GPTSQLStructStoreIndex, LLMPredictor +from llama_index import SQLDatabase as llama_SQLDatabase +from llama_index.indices.struct_store import SQLContextContainerBuilder + +from constants import ( + DEFAULT_SQL_PATH, + DEFAULT_BUSINESS_TABLE_DESCRP, + DEFAULT_VIOLATIONS_TABLE_DESCRP, + DEFAULT_INSPECTIONS_TABLE_DESCRP, + DEFAULT_LC_TOOL_DESCRP +) +from utils import get_sql_index_tool, get_llm + +# NOTE: for local testing only, do NOT deploy with your key hardcoded + # to use this for yourself, create a file called .streamlit/secrets.toml with your api key + # Learn more about Streamlit on the docs: https://docs.streamlit.io/ +os.environ["OPENAI_API_KEY"] = st.secrets["openai_api_key"] + + +@st.cache_resource +def initialize_index(llm_name, model_temperature, table_context_dict, sql_path=DEFAULT_SQL_PATH): + """Create the GPTSQLStructStoreIndex object.""" + llm = get_llm(llm_name, model_temperature) + + engine = create_engine(sql_path) + sql_database = llama_SQLDatabase(engine) + + context_container = None + if table_context_dict is not None: + context_builder = SQLContextContainerBuilder(sql_database, context_dict=table_context_dict) + context_container = context_builder.build_context_container() + + index = GPTSQLStructStoreIndex([], + sql_database=sql_database, + sql_context_container=context_container, + llm_predictor=LLMPredictor(llm=llm)) + + return index + + +@st.cache_resource +def initialize_chain(llm_name, model_temperature, lc_descrp, _sql_index): + """Create a (rather hacky) custom agent and sql_index tool.""" + sql_tool = Tool(name="SQL Index", + func=get_sql_index_tool(_sql_index, _sql_index.sql_context_container.context_dict), + description=lc_descrp) + + llm = get_llm(llm_name, model_temperature) + + memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) + + agent_chain = initialize_agent([sql_tool], llm, agent="chat-conversational-react-description", verbose=True, memory=memory) + + return agent_chain + + +st.title("🦙 Llama Index SQL Sandbox 🦙") +st.markdown(( + "This sandbox uses a sqlite database by default, containing information on health violations and inspections at restaurants in San Francisco.\n\n" + "The database contains three tables - businesses, inspections, and violations.\n\n" + "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
+)) + + +setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"]) + +with setup_tab: + st.subheader("LLM Setup") + model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1) + llm_name = st.selectbox('Which LLM?', ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]) + + st.subheader("Table Setup") + business_table_descrp = st.text_area("Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP) + violations_table_descrp = st.text_area("Violations table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP) + inspections_table_descrp = st.text_area("Inspections table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP) + + table_context_dict = {"businesses": business_table_descrp, + "inspections": inspections_table_descrp, + "violations": violations_table_descrp} + + use_table_descrp = st.checkbox("Use table descriptions?", value=True) + lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP) + +with llama_tab: + st.subheader("Text2SQL with Llama Index") + if st.button("Initialize Index", key="init_index_1"): + st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None) + + if "llama_index" in st.session_state: + query_text = st.text_input("Query:") + if st.button("Run Query") and query_text: + with st.spinner("Getting response..."): + try: + response = st.session_state['llama_index'].query(query_text) + response_text = str(response) + response_sql = response.extra_info['sql_query'] + except Exception as e: + response_text = "Error running SQL Query." + response_sql = str(e) + + col1, col2 = st.columns(2) + with col1: + st.text("SQL Result:") + st.markdown(response_text) + + with col2: + st.text("SQL Query:") + st.markdown(response_sql) + +with lc_tab: + st.subheader("Langchain + Llama Index SQL Demo") + + if st.button("Initialize Agent"): + st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None) + st.session_state['lc_agent'] = initialize_chain(llm_name, model_temperature, lc_descrp, st.session_state['llama_index']) + st.session_state['chat_history'] = [] + + model_input = st.text_input("Message:") + if 'lc_agent' in st.session_state and st.button("Send"): + model_input = "User: " + model_input + st.session_state['chat_history'].append(model_input) + with st.spinner("Getting response..."): + response = st.session_state['lc_agent'].run(input=model_input) + st.session_state['chat_history'].append(response) + + if 'chat_history' in st.session_state: + for msg in st.session_state['chat_history']: + st_message(msg.split("User: ")[-1], is_user="User: " in msg) + \ No newline at end of file diff --git a/streamlit_sql_sandbox/utils.py b/streamlit_sql_sandbox/utils.py new file mode 100644 index 00000000..744730b2 --- /dev/null +++ b/streamlit_sql_sandbox/utils.py @@ -0,0 +1,24 @@ +from langchain import OpenAI +from langchain.chat_models import ChatOpenAI + + +def get_sql_index_tool(sql_index, table_context_dict): + table_context_str = "\n".join(table_context_dict.values()) + def run_sql_index_query(query_text): + try: + response = sql_index.query(query_text) + except Exception as e: + return f"Error running SQL {e}.\nNot able to retrieve answer."
+ text = str(response) + sql = response.extra_info['sql_query'] + return f"Here are the details on the SQL table: {table_context_str}\nSQL Query Used: {sql}\nSQL Result: {text}\n" + #return f"SQL Query Used: {sql}\nSQL Result: {text}\n" + return run_sql_index_query + + + +def get_llm(llm_name, model_temperature): + if llm_name == "text-davinci-003": + return OpenAI(temperature=model_temperature, model_name=llm_name) + else: + return ChatOpenAI(temperature=model_temperature, model_name=llm_name) \ No newline at end of file diff --git a/streamlit_vector/Dockerfile b/streamlit_vector/Dockerfile new file mode 100644 index 00000000..5f184a40 --- /dev/null +++ b/streamlit_vector/Dockerfile @@ -0,0 +1,11 @@ +FROM python:3.11.0-slim + +WORKDIR /app + +COPY . . + +RUN pip install -r requirements.txt && pip cache purge + +# Streamlit +CMD ["streamlit", "run", "streamlit_demo.py"] +EXPOSE 8501 diff --git a/streamlit/documents/paul_graham_essay.txt b/streamlit_vector/documents/paul_graham_essay.txt similarity index 100% rename from streamlit/documents/paul_graham_essay.txt rename to streamlit_vector/documents/paul_graham_essay.txt diff --git a/streamlit_vector/requirements.txt b/streamlit_vector/requirements.txt new file mode 100644 index 00000000..2ffa048e --- /dev/null +++ b/streamlit_vector/requirements.txt @@ -0,0 +1,3 @@ +langchain==0.0.123 +llama-index==0.4.39 +streamlit==1.19.0 diff --git a/streamlit/streamlit_demo.py b/streamlit_vector/streamlit_demo.py similarity index 100% rename from streamlit/streamlit_demo.py rename to streamlit_vector/streamlit_demo.py
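For reference, below is a minimal non-Streamlit sketch of the Text2SQL flow the new streamlit_sql_sandbox files implement: wrap sfscores.sqlite, attach the table context, build the GPTSQLStructStoreIndex, run a query, and read back the generated SQL. It only reuses calls that already appear in streamlit_demo.py and utils.py above (with langchain==0.0.123 and llama-index==0.4.39 as pinned); the query string, the choice of text-davinci-003, and the assumption that OPENAI_API_KEY is already exported are illustrative, not part of the diff.

# Sketch only: run from the streamlit_sql_sandbox folder with OPENAI_API_KEY set.
from sqlalchemy import create_engine
from langchain import OpenAI
from llama_index import GPTSQLStructStoreIndex, LLMPredictor
from llama_index import SQLDatabase as llama_SQLDatabase
from llama_index.indices.struct_store import SQLContextContainerBuilder

from constants import (
    DEFAULT_SQL_PATH,
    DEFAULT_BUSINESS_TABLE_DESCRP,
    DEFAULT_VIOLATIONS_TABLE_DESCRP,
    DEFAULT_INSPECTIONS_TABLE_DESCRP,
)

# Wrap the sqlite database, exactly as initialize_index() does in streamlit_demo.py
engine = create_engine(DEFAULT_SQL_PATH)
sql_database = llama_SQLDatabase(engine)

# Optional per-table context, the same dict the "Setup" tab assembles
context_builder = SQLContextContainerBuilder(
    sql_database,
    context_dict={
        "businesses": DEFAULT_BUSINESS_TABLE_DESCRP,
        "violations": DEFAULT_VIOLATIONS_TABLE_DESCRP,
        "inspections": DEFAULT_INSPECTIONS_TABLE_DESCRP,
    },
)
context_container = context_builder.build_context_container()

# Text2SQL index over the existing tables (no documents are ingested here)
index = GPTSQLStructStoreIndex(
    [],
    sql_database=sql_database,
    sql_context_container=context_container,
    llm_predictor=LLMPredictor(llm=OpenAI(temperature=0, model_name="text-davinci-003")),
)

# Example query; extra_info['sql_query'] holds the SQL that was generated,
# mirroring how run_sql_index_query() in utils.py reports its results.
response = index.query("Which businesses had high risk vermin infestations?")
print(str(response))
print(response.extra_info["sql_query"])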