-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
150 lines (123 loc) · 4.77 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import streamlit as st
from langchain.document_loaders import PyPDFLoader, CSVLoader
from langchain.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from dotenv import load_dotenv
import os
import uuid
from pathlib import Path
# Load environment variables (via python-dotenv)
load_dotenv()

# OpenAI credentials and the directory handed to Chroma for persistence
openai_api_key = os.getenv("OPENAI_API_KEY")
persist_directory = ".chromadb"

# LangChain components: embeddings -> vector store -> retriever -> QA chain
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
vector_store = Chroma(
    persist_directory=persist_directory,
    embedding_function=embeddings,
)
retriever = vector_store.as_retriever()
qa_chain = RetrievalQA.from_chain_type(
    llm=OpenAI(
        model_kwargs={"model": "gpt-3.5-turbo-instruct"},
        openai_api_key=openai_api_key,
    ),
    retriever=retriever,
    return_source_documents=True,  # Ensure source documents are included
)

# Directory to store uploaded files (created on first run if missing)
UPLOAD_DIR = Path.cwd() / "uploaded_files"
UPLOAD_DIR.mkdir(exist_ok=True)

# Initialize session state: each key gets a fresh empty container only when
# it is not already present in the session.
for _state_key, _make_default in (("uploaded_files", list), ("cited_files", set)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Helper functions
def process_pdf(file, file_name):
    """Save an uploaded PDF under UPLOAD_DIR and load it with PyPDFLoader.

    Args:
        file: Binary file-like object whose ``read()`` returns the PDF bytes.
        file_name: Client-supplied name of the upload. Only its final path
            component is used when building the destination path.

    Returns:
        The documents returned by ``PyPDFLoader.load()``, each with
        ``metadata["source"]`` set to the saved file's path.

    Raises:
        ValueError: If ``file_name`` has no usable final component
            (empty, ``"."`` or ``".."``).
    """
    # The name originates from the uploader, so drop any directory parts:
    # a name such as "../../app.py" must not write outside UPLOAD_DIR.
    safe_name = Path(file_name).name
    if safe_name in {"", ".", ".."}:
        raise ValueError(f"Invalid upload file name: {file_name!r}")

    file_path = UPLOAD_DIR / safe_name
    file_path.write_bytes(file.read())

    documents = PyPDFLoader(str(file_path)).load()
    for doc in documents:
        # Record where the file was saved so answers can cite it later.
        doc.metadata["source"] = str(file_path)
    return documents
def process_csv(file, file_name):
    """Save an uploaded CSV under UPLOAD_DIR and load it with CSVLoader.

    Args:
        file: Binary file-like object whose ``read()`` returns the CSV bytes.
        file_name: Client-supplied name of the upload. Only its final path
            component is used when building the destination path.

    Returns:
        The documents returned by ``CSVLoader.load()``, each with
        ``metadata["source"]`` set to the saved file's path.

    Raises:
        ValueError: If ``file_name`` has no usable final component
            (empty, ``"."`` or ``".."``).
    """
    # The name originates from the uploader, so drop any directory parts:
    # a name such as "../../app.py" must not write outside UPLOAD_DIR.
    safe_name = Path(file_name).name
    if safe_name in {"", ".", ".."}:
        raise ValueError(f"Invalid upload file name: {file_name!r}")

    file_path = UPLOAD_DIR / safe_name
    file_path.write_bytes(file.read())

    documents = CSVLoader(file_path=str(file_path)).load()
    for doc in documents:
        # Record where the file was saved so answers can cite it later.
        doc.metadata["source"] = str(file_path)
    return documents
def save_to_vector_store(documents):
    """Add documents to the module-level vector store and persist it.

    Args:
        documents: Iterable of documents to store. It is materialised into a
            list once, so one-shot iterables (e.g. generators) are accepted.

    Returns:
        None. Empty input is a no-op: nothing is added, persisted, or
        reported to the UI.
    """
    documents = list(documents)
    if not documents:
        # Nothing to index; skip the store round-trip and the success message.
        return

    # Ensure unique IDs for each document
    ids = [str(uuid.uuid4()) for _ in documents]
    vector_store.add_documents(documents, ids=ids)
    vector_store.persist()
    st.write("Documents successfully added to the vector store.")
def get_file_download_button(file_path, unique_key):
    """Render a Streamlit download button for the file at ``file_path``.

    Args:
        file_path: Path (str or ``Path``) of the file to offer for download.
            Its bytes are read eagerly and passed to the button.
        unique_key: Value that distinguishes this button from the others
            rendered in the same run; it becomes part of the widget key.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    target = Path(file_path)
    file_name = target.name
    file_data = target.read_bytes()
    st.download_button(
        label=f"Download {file_name}",
        data=file_data,
        file_name=file_name,
        mime="application/octet-stream",
        # Derive the key from the caller-supplied value instead of a fresh
        # uuid: the key then stays the same from one rerun to the next, and
        # the `unique_key` argument is actually honoured.
        key=f"download-{unique_key}",
    )
# Streamlit App UI
st.title("QA with Your Documents")

# File Upload
uploaded_files = st.file_uploader(
    "Upload PDF or CSV files", type=["pdf", "csv"], accept_multiple_files=True)

if uploaded_files:
    all_documents = []
    newly_processed = []
    for uploaded_file in uploaded_files:
        file_name = uploaded_file.name

        # Streamlit re-executes this script on every interaction, and the
        # uploader keeps returning the same files. Without this check each
        # rerun would add the same documents to the vector store again under
        # new IDs. Files are identified by (name, size) for this session.
        fingerprint = (file_name, uploaded_file.size)
        if (fingerprint in st.session_state.uploaded_files
                or fingerprint in newly_processed):
            continue

        # Compare the extension case-insensitively so "REPORT.PDF" is accepted.
        file_extension = file_name.rsplit(".", 1)[-1].lower()

        # Process the file based on extension
        if file_extension == "pdf":
            documents = process_pdf(uploaded_file, file_name)
        elif file_extension == "csv":
            documents = process_csv(uploaded_file, file_name)
        else:
            st.error(f"Unsupported file format: {file_extension}")
            continue

        all_documents.extend(documents)
        newly_processed.append(fingerprint)

    # Save documents to vector store
    if all_documents:
        save_to_vector_store(all_documents)
        st.success(
            "All files have been uploaded and stored in the vector database.")

    # Only mark files as handled once saving did not raise, so a failed save
    # is retried on the next rerun.
    st.session_state.uploaded_files.extend(newly_processed)

# Chat Section
st.header("Chat with Your Documents")
query = st.text_input("Ask a question about the uploaded documents:")

if st.button("Submit Query") and query:
    try:
        result = qa_chain({"query": query})
        answer = result["result"]
        sources = result.get("source_documents", [])

        st.write(f"**Answer:** {answer}")

        if sources:
            # Update session state for cited files
            cited_files = {
                Path(doc.metadata.get("source") or "").name for doc in sources
            }
            # A source without a "source" entry yields an empty name, which
            # would resolve to UPLOAD_DIR itself; drop it.
            cited_files.discard("")
            st.session_state.cited_files.clear()
            st.session_state.cited_files.update(cited_files)

            # Display cited files for download (sorted for a stable order)
            st.write("### Cited Files:")
            for file_name in sorted(st.session_state.cited_files):
                file_path = UPLOAD_DIR / file_name
                # is_file() rather than exists(): a directory cannot be
                # opened for download.
                if file_path.is_file():
                    get_file_download_button(file_path, file_name)
        else:
            st.write("No sources available.")
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        st.error(f"Error: {e}")