Migrate to Llama Index 0.5.X (#9)
* Migrate to llama_index 0.5.X, update apps

* Update README with spaces links
logan-markewich authored Apr 2, 2023
1 parent 5627a3d commit 3820c92
Showing 9 changed files with 69 additions and 54 deletions.
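The recurring change across these files is llama-index 0.5.x's consolidation of per-index settings (`llm_predictor`, `chunk_size_limit`, ...) into a single `ServiceContext`. A minimal before/after sketch distilled from the diffs below (the documents folder, model name, and temperature are illustrative):

```python
from langchain.chat_models import ChatOpenAI
from llama_index import GPTSimpleVectorIndex, LLMPredictor, ServiceContext, SimpleDirectoryReader

documents = SimpleDirectoryReader("./documents").load_data()
llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo"))

# 0.4.x style (what this commit removes): settings passed per-index
# index = GPTSimpleVectorIndex(documents, llm_predictor=llm_predictor)

# 0.5.x style (what this commit adds): settings bundled in a ServiceContext
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, chunk_size_limit=512)
index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)
index.save_to_disk("./index.json")
```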
4 changes: 3 additions & 1 deletion README.md
@@ -40,14 +40,16 @@ There are two main example folders
- loads text from the documents folder (using `st.cache_resource`, so it only loads once)
- provides an input text-box and a button to run the query
- the string response is displayed after it finishes
- - want to see this example in action? Check it out [here](https://llama-index.streamlit.app/)
+ - want to see this example in action? Check it out [here](https://huggingface.co/spaces/llamaindex/llama_index_vector_demo)

- streamlit_sql_sandbox (runs on localhost:8501)
- `streamlit run streamlit_demo.py`
- creates a streamlit app using a local SQL database about restaurant inspections in San Francisco ([data sample](https://docs.google.com/spreadsheets/d/1Ag5DBIviYiuRrt2yr3nXmbPFV-FOg5fDH5SM3ZEDnpw/edit#gid=780513932))
- The "Setup" tab allows you to configure various LLM and LLama Index settings
- The "Llama Index" tab demos some basic Text2SQL capabilities using only Llama Index
- The "Langchain+Llama Index" tab uses a custom langchain agent, and uses the SQL index from Llama Index as a tool during conversations.
+ - Check out the huggingface space [here!](https://huggingface.co/spaces/llamaindex/llama_index_sql_sandbox)


## Docker
Each example contains a `Dockerfile`. You can run `docker build -t my_tag_name .` inside the desired folder to build a python3.11-slim docker image. The resulting image is about 600-900MB, depending on the example.
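For instance, to build and serve the vector demo (a sketch; the tag name is arbitrary, and the port mapping assumes Streamlit's default 8501 mentioned above):

```bash
cd streamlit_vector
docker build -t llama_vector_demo .
docker run -p 8501:8501 llama_vector_demo
```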
15 changes: 8 additions & 7 deletions flask_react/index_server.py
@@ -6,7 +6,7 @@

from multiprocessing import Lock
from multiprocessing.managers import BaseManager
-from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex, Document
+from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex, Document, ServiceContext

index = None
stored_docs = {}
@@ -19,12 +19,13 @@
def initialize_index():
    """Create a new global index, or load one from the pre-set path."""
    global index, stored_docs

+    service_context = ServiceContext.from_defaults(chunk_size_limit=512)
    with lock:
        if os.path.exists(index_name):
-            index = GPTSimpleVectorIndex.load_from_disk(index_name)
+            index = GPTSimpleVectorIndex.load_from_disk(index_name, service_context=service_context)
        else:
-            index = GPTSimpleVectorIndex([])
+            index = GPTSimpleVectorIndex([], service_context=service_context)
            index.save_to_disk(index_name)
        if os.path.exists(pkl_name):
            with open(pkl_name, "rb") as f:
@@ -44,11 +45,11 @@ def insert_into_index(doc_file_path, doc_id=None):
    document = SimpleDirectoryReader(input_files=[doc_file_path]).load_data()[0]
    if doc_id is not None:
        document.doc_id = doc_id

-    # Keep track of stored docs -- llama_index doesn't make this easy
-    stored_docs[document.doc_id] = document.text[0:200]  # only take the first 200 chars
-
    with lock:
+        # Keep track of stored docs -- llama_index doesn't make this easy
+        stored_docs[document.doc_id] = document.text[0:200]  # only take the first 200 chars
+
        index.insert(document)
        index.save_to_disk(index_name)

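Since `index_server.py` serves these functions over a `multiprocessing.managers.BaseManager`, the Flask side talks to it roughly like this (a hypothetical client sketch; the address, authkey, and registered names are assumptions, not shown in this diff):

```python
from multiprocessing.managers import BaseManager

# Hypothetical client for index_server.py; address and authkey are illustrative.
manager = BaseManager(("127.0.0.1", 5602), b"password")
manager.register("query_index")
manager.register("insert_into_index")
manager.connect()

# Proxied calls run inside the index server process, guarded by its lock.
manager.insert_into_index("./documents/example.txt", doc_id="example")
response = manager.query_index("What is this document about?")._getvalue()
print(response)
```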
1 change: 1 addition & 0 deletions flask_react/launch_app.sh
@@ -10,4 +10,5 @@ sleep 60
# start the flask server
python ./flask_demo.py &

+# assumes you've already run npm install (the Dockerfile does this during build)
cd react_frontend && npm run build && serve -s build
4 changes: 2 additions & 2 deletions flask_react/requirements.txt
@@ -1,5 +1,5 @@
Flask==2.2.3
Flask-Cors==3.0.10
-langchain==0.0.123
-llama-index==0.4.39
+langchain==0.0.128
+llama-index==0.5.4
PyPDF2==3.0.1
4 changes: 2 additions & 2 deletions streamlit_sql_sandbox/requirements.txt
@@ -1,4 +1,4 @@
-langchain==0.0.123
-llama-index==0.4.39
+langchain==0.0.128
+llama-index==0.5.4
streamlit==1.19.0
streamlit-chat==0.0.2.2
44 changes: 22 additions & 22 deletions streamlit_sql_sandbox/streamlit_demo.py
@@ -6,7 +6,7 @@
from langchain.agents import Tool, initialize_agent
from langchain.chains.conversation.memory import ConversationBufferMemory

-from llama_index import GPTSQLStructStoreIndex, LLMPredictor
+from llama_index import GPTSQLStructStoreIndex, LLMPredictor, ServiceContext
from llama_index import SQLDatabase as llama_SQLDatabase
from llama_index.indices.struct_store import SQLContextContainerBuilder

@@ -19,16 +19,11 @@
)
from utils import get_sql_index_tool, get_llm

-# NOTE: for local testing only, do NOT deploy with your key hardcoded
-# to use this for yourself, create a file called .streamlit/secrets.toml with your api key
-# Learn more about Streamlit on the docs: https://docs.streamlit.io/
-os.environ["OPENAI_API_KEY"] = st.secrets["openai_api_key"]


@st.cache_resource
-def initialize_index(llm_name, model_temperature, table_context_dict, sql_path=DEFAULT_SQL_PATH):
+def initialize_index(llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH):
    """Create the GPTSQLStructStoreIndex object."""
-    llm = get_llm(llm_name, model_temperature)
+    llm = get_llm(llm_name, model_temperature, api_key)

engine = create_engine(sql_path)
sql_database = llama_SQLDatabase(engine)
@@ -37,23 +32,24 @@ def initialize_index(llm_name, model_temperature, table_context_dict, sql_path=D
    if table_context_dict is not None:
        context_builder = SQLContextContainerBuilder(sql_database, context_dict=table_context_dict)
        context_container = context_builder.build_context_container()

+    service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
    index = GPTSQLStructStoreIndex([],
                                   sql_database=sql_database,
                                   sql_context_container=context_container,
-                                  llm_predictor=LLMPredictor(llm=llm))
+                                  service_context=service_context)

    return index


@st.cache_resource
-def initialize_chain(llm_name, model_temperature, lc_descrp, _sql_index):
+def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
    """Create a (rather hacky) custom agent and sql_index tool."""
    sql_tool = Tool(name="SQL Index",
                    func=get_sql_index_tool(_sql_index, _sql_index.sql_context_container.context_dict),
                    description=lc_descrp)

-    llm = get_llm(llm_name, model_temperature)
+    llm = get_llm(llm_name, model_temperature, api_key=api_key)

    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

@@ -64,19 +60,23 @@ def initialize_chain(llm_name, model_temperature, lc_descrp, _sql_index):

st.title("🦙 Llama Index SQL Sandbox 🦙")
st.markdown((
-    "This sandbox uses a sqlite database by default, containing information on health violations and inspections at restaurants in San Francisco.\n\n"
-    "The database contains three tables - businesses, inspections, and violations.\n\n"
+    "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html), ChatGPT, and LangChain.\n\n"
+    "The database contains information on health violations and inspections at restaurants in San Francisco. "
+    "This data is spread across three tables - businesses, inspections, and violations.\n\n"
    "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain. "
    "The other tabs will perform chatbot and text2sql operations.\n\n"
    "Read more about LlamaIndex's structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
))


setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"])

with setup_tab:
st.subheader("LLM Setup")
model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1)
api_key = st.text_input("Enter your OpenAI API key here", type="password")
llm_name = st.selectbox('Which LLM?', ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"])

model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1)

st.subheader("Table Setup")
business_table_descrp = st.text_area("Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP)
violations_table_descrp = st.text_area("Business table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP)
@@ -92,10 +92,10 @@ def initialize_chain(llm_name, model_temperature, lc_descrp, _sql_index):
with llama_tab:
    st.subheader("Text2SQL with Llama Index")
    if st.button("Initialize Index", key="init_index_1"):
-        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None)
+        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None, api_key)

    if "llama_index" in st.session_state:
-        query_text = st.text_input("Query:")
+        query_text = st.text_input("Query:", value="Which restaurant has the most violations?")
        if st.button("Run Query") and query_text:
            with st.spinner("Getting response..."):
                try:
@@ -119,11 +119,11 @@ def initialize_chain(llm_name, model_temperature, lc_descrp, _sql_index):
st.subheader("Langchain + Llama Index SQL Demo")

if st.button("Initialize Agent"):
st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None)
st.session_state['lc_agent'] = initialize_chain(llm_name, model_temperature, lc_descrp, st.session_state['llama_index'])
st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None, api_key)
st.session_state['lc_agent'] = initialize_chain(llm_name, model_temperature, lc_descrp, api_key, st.session_state['llama_index'])
st.session_state['chat_history'] = []

model_input = st.text_input("Message:")
model_input = st.text_input("Message:", value="Which restaurant has the most violations?")
if 'lc_agent' in st.session_state and st.button("Send"):
model_input = "User: " + model_input
st.session_state['chat_history'].append(model_input)
@@ -134,4 +134,4 @@ def initialize_chain(llm_name, model_temperature, lc_descrp, _sql_index):
    if 'chat_history' in st.session_state:
        for msg in st.session_state['chat_history']:
            st_message(msg.split("User: ")[-1], is_user="User: " in msg)


7 changes: 4 additions & 3 deletions streamlit_sql_sandbox/utils.py
@@ -1,3 +1,4 @@
+import os
from langchain import OpenAI
from langchain.chat_models import ChatOpenAI

@@ -16,9 +17,9 @@ def run_sql_index_query(query_text):
    return run_sql_index_query


-def get_llm(llm_name, model_temperature):
+def get_llm(llm_name, model_temperature, api_key):
+    os.environ['OPENAI_API_KEY'] = api_key
    if llm_name == "text-davinci-003":
        return OpenAI(temperature=model_temperature, model_name=llm_name)
    else:
-        return ChatOpenAI(temperature=model_temperature, model_name=llm_name)
+        return ChatOpenAI(temperature=model_temperature, model_name=llm_name)
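A hypothetical usage sketch tying these helpers together (assumes `index` is the `GPTSQLStructStoreIndex` built by `initialize_index` in streamlit_demo.py; the key and query strings are illustrative):

```python
# Assumes `index` was built by initialize_index() in streamlit_demo.py.
llm = get_llm("gpt-3.5-turbo", model_temperature=0.0, api_key="sk-...")

query_fn = get_sql_index_tool(index, index.sql_context_container.context_dict)
print(query_fn("Which restaurant has the most violations?"))
```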
4 changes: 2 additions & 2 deletions streamlit_vector/requirements.txt
@@ -1,3 +1,3 @@
-langchain==0.0.123
-llama-index==0.4.39
+langchain==0.0.128
+llama-index==0.5.4
streamlit==1.19.0
40 changes: 25 additions & 15 deletions streamlit_vector/streamlit_demo.py
@@ -1,12 +1,7 @@
import os

import streamlit as st
-from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader

-# NOTE: for local testing only, do NOT deploy with your key hardcoded
-# to use this for yourself, create a file called .streamlit/secrets.toml with your api key
-# Learn more about Streamlit on the docs: https://docs.streamlit.io/
-os.environ["OPENAI_API_KEY"] = st.secrets["openai_api_key"]
+from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader, ServiceContext
+from llama_index.llm_predictor.chatgpt import ChatGPTLLMPredictor


index_name = "./index.json"
@@ -15,11 +10,13 @@

@st.cache_resource
def initialize_index(index_name, documents_folder):
+    llm_predictor = ChatGPTLLMPredictor()
+    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)
    if os.path.exists(index_name):
-        index = GPTSimpleVectorIndex.load_from_disk(index_name)
+        index = GPTSimpleVectorIndex.load_from_disk(index_name, service_context=service_context)
    else:
        documents = SimpleDirectoryReader(documents_folder).load_data()
-        index = GPTSimpleVectorIndex(documents)
+        index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)
        index.save_to_disk(index_name)

    return index
@@ -31,16 +28,29 @@ def query_index(_index, query_text):
    return str(response)


-# This should be cached and only fully runs once
-index = initialize_index(index_name, documents_folder)


st.title("🦙 Llama Index Demo 🦙")
st.header("Welcome to the Llama Index Streamlit Demo")
-st.text("Please enter a query about Paul Graham's essays")
+st.write("Enter a query about Paul Graham's essays. You can check out the original essay [here](https://raw.githubusercontent.com/jerryjliu/llama_index/main/examples/paul_graham_essay/data/paul_graham_essay.txt). Your query will be answered using the essay as context, using embeddings from text-embedding-ada-002 and LLM completions from ChatGPT. You can read more about Llama Index and how this works in [our docs!](https://gpt-index.readthedocs.io/en/latest/index.html)")

+index = None
+api_key = st.text_input("Enter your OpenAI API key here:", type="password")
+if api_key:
+    os.environ['OPENAI_API_KEY'] = api_key
+    index = initialize_index(index_name, documents_folder)


+if index is None:
+    st.warning("Please enter your API key first.")

-text = st.text_input("Query text:")
+text = st.text_input("Query text:", value="What did the author do growing up?")

if st.button("Run Query") and text is not None:
    response = query_index(index, text)
    st.markdown(response)

+llm_col, embed_col = st.columns(2)
+with llm_col:
+    st.markdown(f"LLM Tokens Used: {index.service_context.llm_predictor._last_token_usage}")
+
+with embed_col:
+    st.markdown(f"Embedding Tokens Used: {index.service_context.embed_model._last_token_usage}")
