Add streamlit sql sandbox demo (#5)
* Add streamlit SQL sandbox

* Update gitignore, remove files
logan-markewich authored Mar 27, 2023
1 parent c705e9b commit f6cb57a
Showing 15 changed files with 216 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
.envrc
.direnv
-.streamlit/secrets.toml
+*/.streamlit/secrets.toml
.mypy_cache
__pycache__
node_modules
build
index.json
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -34,14 +34,20 @@ There are two main example folders
![react_frontend screenshot](./flask_react/react_frontend.png)


-- streamlit (runs on localhost:8501)
+- streamlit_vector (runs on localhost:8501)
   - `streamlit run streamlit_demo.py`
-  - creates a simple UI using streamlit
+  - creates a simple UI using streamlit and GPTSimpleVectorIndex
   - loads text from the documents folder (using `st.cache_resource`, so it only loads once)
   - provides an input text-box and a button to run the query
   - the string response is displayed after it finishes
   - want to see this example in action? Check it out [here](https://llama-index.streamlit.app/)
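The `st.cache_resource` behavior noted above (load once, reuse on every rerun) is essentially memoization; a minimal stdlib sketch of the same idea, with a hypothetical `load_documents` standing in for the demo's loader:

```python
import functools

load_count = 0  # tracks how many times the loader actually runs

@functools.lru_cache(maxsize=None)
def load_documents(folder: str) -> list[str]:
    """Hypothetical loader: does the expensive work only once per argument."""
    global load_count
    load_count += 1
    return [f"doc from {folder}"]

load_documents("documents")  # does the work
load_documents("documents")  # served from the cache; load_count stays at 1
```

Unlike `lru_cache`, `st.cache_resource` also shares the cached object across user sessions, which is why the demo uses it for the loaded index.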

+- streamlit_sql_sandbox (runs on localhost:8501)
+  - `streamlit run streamlit_demo.py`
+  - creates a streamlit app using a local SQL database about restaurant inspections in San Francisco ([data sample](https://docs.google.com/spreadsheets/d/1Ag5DBIviYiuRrt2yr3nXmbPFV-FOg5fDH5SM3ZEDnpw/edit#gid=780513932))
+  - the "Setup" tab allows you to configure various LLM and Llama Index settings
+  - the "Llama Index" tab demos some basic Text2SQL capabilities using only Llama Index
+  - the "Langchain+Llama Index" tab uses a custom langchain agent, and uses the SQL index from Llama Index as a tool during conversations
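To make the sandbox concrete, here is a stdlib `sqlite3` sketch of the three-table layout the demo queries; the table and column names follow the demo's schema, but the sample rows are invented for illustration:

```python
import sqlite3

# In-memory stand-in for sfscores.sqlite; rows are made up.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE businesses (business_id INTEGER, name TEXT, address TEXT);
CREATE TABLE inspections (business_id TEXT, Score INTEGER, date TEXT, type TEXT);
CREATE TABLE violations (business_id TEXT, date TEXT, risk_category TEXT, description TEXT);
INSERT INTO businesses VALUES (1, 'KING-KONG', '1 Market St');
INSERT INTO inspections VALUES ('1', 50, '2023-12-31', 'routine');
INSERT INTO violations VALUES ('1', '2023-12-31', 'High Risk', 'vermin infestation');
""")

# The kind of query Text2SQL would generate for "what did KING-KONG score?"
# (note the cast: business_id is INTEGER in businesses but TEXT elsewhere)
row = conn.execute("""
    SELECT b.name, i.Score FROM businesses b
    JOIN inspections i ON i.business_id = CAST(b.business_id AS TEXT)
    WHERE b.name = 'KING-KONG'
""").fetchone()
print(row)  # ('KING-KONG', 50)
```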

## Docker
Each example contains a `Dockerfile`. You can run `docker build -t my_tag_name .` to build a python3.11-slim docker image inside your desired folder. It ends up being about 600MB-900MB depending on the example.
4 changes: 2 additions & 2 deletions flask_react/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Flask==2.2.3
Flask-Cors==3.0.10
-langchain==0.0.115
-llama-index==0.4.30
+langchain==0.0.123
+llama-index==0.4.39
PyPDF2==3.0.1
3 changes: 0 additions & 3 deletions streamlit/requirements.txt

This file was deleted.

File renamed without changes.
24 changes: 24 additions & 0 deletions streamlit_sql_sandbox/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
DEFAULT_SQL_PATH = "sqlite:///sfscores.sqlite"
DEFAULT_BUSINESS_TABLE_DESCRP = (
    "This table gives information on the IDs, addresses, and other location "
    "information for several restaurants in San Francisco. This table will "
    "need to be referenced when users ask about specific businesses."
)
DEFAULT_VIOLATIONS_TABLE_DESCRP = (
    "This table gives information on which business IDs have recorded health violations, "
    "including the date, risk, and description of each violation. The user may query "
    "about specific businesses, whose names can be found by mapping the business_id "
    "to the 'businesses' table."
)
DEFAULT_INSPECTIONS_TABLE_DESCRP = (
    "This table gives information on when each business ID was inspected, including "
    "the score, date, and type of inspection. The user may query about specific "
    "businesses, whose names can be found by mapping the business_id to the 'businesses' table."
)
DEFAULT_LC_TOOL_DESCRP = "Useful for when you want to answer queries about violations and inspections of businesses."

DEFAULT_INGEST_DOCUMENT = (
    "The restaurant KING-KONG had a routine unscheduled inspection on 2023/12/31. "
    "The business achieved a score of 50. There were two violations: a high risk "
    "vermin infestation as well as high risk food holding temperatures."
)
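Descriptions like these give the LLM context about each table; before prompting, they are joined into a single context string (the same assembly appears in `utils.get_sql_index_tool`). A minimal sketch, using short hypothetical stand-ins for the full constants above:

```python
# Hypothetical short stand-ins for the DEFAULT_*_DESCRP constants.
table_context_dict = {
    "businesses": "Locations of restaurants in San Francisco.",
    "violations": "Recorded health violations per business ID.",
    "inspections": "Inspection scores and dates per business ID.",
}

# One description per line, ready to prepend to the LLM prompt.
table_context_str = "\n".join(table_context_dict.values())
print(table_context_str)
```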
4 changes: 4 additions & 0 deletions streamlit_sql_sandbox/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
langchain==0.0.123
llama-index==0.4.39
streamlit==1.19.0
streamlit-chat==0.0.2.2
Binary file added streamlit_sql_sandbox/sfscores.sqlite
Binary file not shown.
1 change: 1 addition & 0 deletions streamlit_sql_sandbox/sql_index.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"index_struct_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "docstore": {"docs": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"text": null, "doc_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "embedding": null, "doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4", "extra_info": null, "context_dict": {}, "__type__": "sql"}}, "ref_doc_info": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4"}}}, "sql_context_container": {"context_dict": {"violations": "Schema of table violations:\nTable 'violations' has columns: business_id (TEXT), date (TEXT), ViolationTypeID (TEXT), risk_category (TEXT), description (TEXT) and foreign keys: .\nContext of table violations:\nThis table gives information on which business IDs have recorded health violations, including the date, risk, and description of each violation. The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table.", "businesses": "Schema of table businesses:\nTable 'businesses' has columns: business_id (INTEGER), name (VARCHAR(64)), address (VARCHAR(50)), city (VARCHAR(23)), postal_code (VARCHAR(9)), latitude (FLOAT), longitude (FLOAT), phone_number (BIGINT), TaxCode (VARCHAR(4)), business_certificate (INTEGER), application_date (DATE), owner_name (VARCHAR(99)), owner_address (VARCHAR(74)), owner_city (VARCHAR(22)), owner_state (VARCHAR(14)), owner_zip (VARCHAR(15)) and foreign keys: .\nContext of table businesses:\nThis table gives information on the IDs, addresses, and other location information for several restaurants in San Francisco. This table will need to be referenced when users ask about specific businesses.", "inspections": "Schema of table inspections:\nTable 'inspections' has columns: business_id (TEXT), Score (INTEGER), date (TEXT), type (VARCHAR(33)) and foreign keys: .\nContext of table inspections:\nThis table gives information on when each business ID was inspected, including the score, date, and type of inspection. The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table."}, "context_str": null}}
137 changes: 137 additions & 0 deletions streamlit_sql_sandbox/streamlit_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
import streamlit as st
from streamlit_chat import message as st_message
from sqlalchemy import create_engine

from langchain.agents import Tool, initialize_agent
from langchain.chains.conversation.memory import ConversationBufferMemory

from llama_index import GPTSQLStructStoreIndex, LLMPredictor
from llama_index import SQLDatabase as llama_SQLDatabase
from llama_index.indices.struct_store import SQLContextContainerBuilder

from constants import (
    DEFAULT_SQL_PATH,
    DEFAULT_BUSINESS_TABLE_DESCRP,
    DEFAULT_VIOLATIONS_TABLE_DESCRP,
    DEFAULT_INSPECTIONS_TABLE_DESCRP,
    DEFAULT_LC_TOOL_DESCRP,
)
from utils import get_sql_index_tool, get_llm

# NOTE: for local testing only, do NOT deploy with your key hardcoded
# to use this for yourself, create a file called .streamlit/secrets.toml with your api key
# Learn more about Streamlit on the docs: https://docs.streamlit.io/
os.environ["OPENAI_API_KEY"] = st.secrets["openai_api_key"]


@st.cache_resource
def initialize_index(llm_name, model_temperature, table_context_dict, sql_path=DEFAULT_SQL_PATH):
    """Create the GPTSQLStructStoreIndex object."""
    llm = get_llm(llm_name, model_temperature)

    engine = create_engine(sql_path)
    sql_database = llama_SQLDatabase(engine)

    context_container = None
    if table_context_dict is not None:
        context_builder = SQLContextContainerBuilder(sql_database, context_dict=table_context_dict)
        context_container = context_builder.build_context_container()

    index = GPTSQLStructStoreIndex(
        [],
        sql_database=sql_database,
        sql_context_container=context_container,
        llm_predictor=LLMPredictor(llm=llm),
    )

    return index


@st.cache_resource
def initialize_chain(llm_name, model_temperature, lc_descrp, _sql_index):
    """Create a (rather hacky) custom agent and sql_index tool."""
    sql_tool = Tool(
        name="SQL Index",
        func=get_sql_index_tool(_sql_index, _sql_index.sql_context_container.context_dict),
        description=lc_descrp,
    )

    llm = get_llm(llm_name, model_temperature)

    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    agent_chain = initialize_agent([sql_tool], llm, agent="chat-conversational-react-description", verbose=True, memory=memory)

    return agent_chain
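`initialize_agent` gives the LLM a list of named tools it may invoke during a conversation; a rough stdlib analogy of that name-to-function dispatch (an illustration of the pattern, not langchain's actual API):

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class Tool:
    """Minimal stand-in for a langchain-style tool: a named callable."""
    name: str
    func: Callable[[str], str]
    description: str

def dispatch(tools: list[Tool], name: str, tool_input: str) -> str:
    """Route a tool call the way an agent would: look up by name, then invoke."""
    registry = {t.name: t for t in tools}
    return registry[name].func(tool_input)

sql_tool = Tool(
    name="SQL Index",
    func=lambda q: f"ran query: {q}",  # stub for the real SQL index tool
    description="Useful for queries about violations and inspections.",
)
result = dispatch([sql_tool], "SQL Index", "worst inspection score?")
```

In the real agent, the LLM chooses which tool name to emit based on each tool's `description`, which is why the sandbox lets you edit that description in the Setup tab.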


st.title("🦙 Llama Index SQL Sandbox 🦙")
st.markdown((
    "This sandbox uses a sqlite database by default, containing information on health violations and inspections at restaurants in San Francisco.\n\n"
    "The database contains three tables - businesses, inspections, and violations.\n\n"
    "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
))


setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"])

with setup_tab:
    st.subheader("LLM Setup")
    model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1)
    llm_name = st.selectbox('Which LLM?', ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"])

    st.subheader("Table Setup")
    business_table_descrp = st.text_area("Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP)
    violations_table_descrp = st.text_area("Violations table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP)
    inspections_table_descrp = st.text_area("Inspections table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP)

    table_context_dict = {
        "businesses": business_table_descrp,
        "inspections": inspections_table_descrp,
        "violations": violations_table_descrp,
    }

    use_table_descrp = st.checkbox("Use table descriptions?", value=True)
    lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)

with llama_tab:
    st.subheader("Text2SQL with Llama Index")
    if st.button("Initialize Index", key="init_index_1"):
        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None)

    if "llama_index" in st.session_state:
        query_text = st.text_input("Query:")
        if st.button("Run Query") and query_text:
            with st.spinner("Getting response..."):
                try:
                    response = st.session_state['llama_index'].query(query_text)
                    response_text = str(response)
                    response_sql = response.extra_info['sql_query']
                except Exception as e:
                    response_text = "Error running SQL Query."
                    response_sql = str(e)

            col1, col2 = st.columns(2)
            with col1:
                st.text("SQL Result:")
                st.markdown(response_text)

            with col2:
                st.text("SQL Query:")
                st.markdown(response_sql)

with lc_tab:
    st.subheader("Langchain + Llama Index SQL Demo")

    if st.button("Initialize Agent"):
        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None)
        st.session_state['lc_agent'] = initialize_chain(llm_name, model_temperature, lc_descrp, st.session_state['llama_index'])
        st.session_state['chat_history'] = []

    model_input = st.text_input("Message:")
    if 'lc_agent' in st.session_state and st.button("Send"):
        model_input = "User: " + model_input
        st.session_state['chat_history'].append(model_input)
        with st.spinner("Getting response..."):
            response = st.session_state['lc_agent'].run(input=model_input)
        st.session_state['chat_history'].append(response)

    if 'chat_history' in st.session_state:
        for msg in st.session_state['chat_history']:
            st_message(msg.split("User: ")[-1], is_user="User: " in msg)
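The chat loop tags user turns with a `"User: "` prefix and strips it again when rendering; the prefix logic in isolation:

```python
def render_args(msg: str) -> tuple[str, bool]:
    """Mirror the st_message call: strip the 'User: ' tag, flag user turns."""
    return msg.split("User: ")[-1], "User: " in msg

print(render_args("User: hello"))          # ('hello', True)
print(render_args("Hi, how can I help?"))  # ('Hi, how can I help?', False)
```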

24 changes: 24 additions & 0 deletions streamlit_sql_sandbox/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from langchain import OpenAI
from langchain.chat_models import ChatOpenAI


def get_sql_index_tool(sql_index, table_context_dict):
    table_context_str = "\n".join(table_context_dict.values())

    def run_sql_index_query(query_text):
        try:
            response = sql_index.query(query_text)
        except Exception as e:
            return f"Error running SQL {e}.\nNot able to retrieve answer."
        text = str(response)
        sql = response.extra_info['sql_query']
        return f"Here are the details on the SQL table: {table_context_str}\nSQL Query Used: {sql}\nSQL Result: {text}\n"
        # return f"SQL Query Used: {sql}\nSQL Result: {text}\n"

    return run_sql_index_query



def get_llm(llm_name, model_temperature):
    if llm_name == "text-davinci-003":
        return OpenAI(temperature=model_temperature, model_name=llm_name)
    else:
        return ChatOpenAI(temperature=model_temperature, model_name=llm_name)
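`get_sql_index_tool` is a closure factory: it captures the index and table context once and hands back a plain function for the agent to call. A self-contained sketch of that factory run against a stub index (the stub is invented; in the demo the index comes from llama-index):

```python
class StubSQLIndex:
    """Stand-in for GPTSQLStructStoreIndex: .query() returns an object whose
    str() is the result text and whose extra_info carries the generated SQL."""
    class _Response:
        extra_info = {"sql_query": "SELECT 1"}
        def __str__(self):
            return "one row"

    def query(self, query_text):
        return self._Response()

def get_sql_index_tool(sql_index, table_context_dict):
    table_context_str = "\n".join(table_context_dict.values())

    def run_sql_index_query(query_text):
        try:
            response = sql_index.query(query_text)
        except Exception as e:
            return f"Error running SQL {e}.\nNot able to retrieve answer."
        return (f"Here are the details on the SQL table: {table_context_str}\n"
                f"SQL Query Used: {response.extra_info['sql_query']}\n"
                f"SQL Result: {response}\n")

    return run_sql_index_query

tool_fn = get_sql_index_tool(StubSQLIndex(), {"t": "a table"})
output = tool_fn("anything")
```

Because the returned function takes and returns plain strings, it slots directly into a langchain `Tool` without the agent needing to know anything about llama-index.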
11 changes: 11 additions & 0 deletions streamlit_vector/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FROM python:3.11.0-slim

WORKDIR /app

COPY . .

RUN pip install -r requirements.txt && pip cache purge

# Streamlit
CMD ["streamlit", "run", "streamlit_demo.py"]
EXPOSE 8501
File renamed without changes.
3 changes: 3 additions & 0 deletions streamlit_vector/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
langchain==0.0.123
llama-index==0.4.39
streamlit==1.19.0
File renamed without changes.
