-
Notifications
You must be signed in to change notification settings - Fork 1
/
streamlit_app_blog.py
86 lines (71 loc) · 3.02 KB
/
streamlit_app_blog.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
from pdf_qa import PdfQA
from pathlib import Path
from tempfile import NamedTemporaryFile
import time
import shutil
from constants import *
# Streamlit app code

# Browser-tab title/icon and overall page layout.
_PAGE_SETTINGS = {
    "page_title": "Q&A Bot for PDF",
    "page_icon": "🔖",
    "layout": "wide",
    "initial_sidebar_state": "auto",
}
st.set_page_config(**_PAGE_SETTINGS)

# Create the PdfQA holder only when the session does not have one yet,
# so later reruns of this script reuse the same object.
if "pdf_qa_model" not in st.session_state:
    st.session_state["pdf_qa_model"] = PdfQA()
## To cache resource across multiple session
@st.cache_resource
def load_llm(llm, load_in_8bit):
    """Build the LLM that matches the selected setting.

    Args:
        llm: One of the ``LLM_*`` constants offered in the sidebar.
        load_in_8bit: Passed through unchanged to the PdfQA factory method.

    Returns:
        Whatever the matching ``PdfQA.create_*`` factory returns.

    Raises:
        ValueError: If ``llm`` is not one of the supported settings.
    """
    if llm == LLM_FLAN_T5_SMALL:
        return PdfQA.create_flan_t5_small(load_in_8bit)
    if llm == LLM_FLAN_T5_BASE:
        return PdfQA.create_flan_t5_base(load_in_8bit)
    if llm == LLM_FLAN_T5_LARGE:
        return PdfQA.create_flan_t5_large(load_in_8bit)
    if llm == LLM_FLAN_T5_XL:
        # The sidebar offers this option, so it must be loadable here;
        # previously it fell through to the ValueError below.
        # NOTE(review): assumes PdfQA exposes create_flan_t5_xl with the
        # same signature as the other flan-t5 factories -- confirm.
        return PdfQA.create_flan_t5_xl(load_in_8bit)
    if llm == LLM_FASTCHAT_T5_XL:
        return PdfQA.create_fastchat_t5_xl(load_in_8bit)
    raise ValueError("Invalid LLM setting")
## To cache resource across multiple session
@st.cache_resource
def load_emb(emb):
    """Build the embedding model that matches the selected setting.

    Args:
        emb: One of the ``EMB_*`` constants offered in the sidebar.

    Returns:
        The object returned by the matching ``PdfQA.create_*`` factory, or
        ``None`` for ``EMB_SBERT_MINILM``.

    Raises:
        ValueError: If ``emb`` is not one of the supported settings.
    """
    if emb == EMB_INSTRUCTOR_XL:
        return PdfQA.create_instructor_xl()
    if emb == EMB_SBERT_MPNET_BASE:
        return PdfQA.create_sbert_mpnet()
    if emb == EMB_SBERT_MINILM:
        # Nothing to build here: ChromaDB takes care of this one.
        return None
    raise ValueError("Invalid embedding setting")
st.title("PDF Q&A (Self hosted LLMs)")

# Sidebar: model selection, PDF upload, and the ingestion trigger.
with st.sidebar:
    emb = st.radio(
        "**Select Embedding Model**",
        [EMB_INSTRUCTOR_XL, EMB_SBERT_MPNET_BASE, EMB_SBERT_MINILM],
        index=1,
    )
    llm = st.radio(
        "**Select LLM Model**",
        [
            LLM_FASTCHAT_T5_XL,
            LLM_FLAN_T5_SMALL,
            LLM_FLAN_T5_BASE,
            LLM_FLAN_T5_LARGE,
            LLM_FLAN_T5_XL,
        ],
        index=2,
    )
    load_in_8bit = st.radio("**Load 8 bit**", [True, False], index=1)
    pdf_file = st.file_uploader("**Upload PDF**", type="pdf")

    if st.button("Submit"):
        if pdf_file is None:
            # Tell the user why nothing happened instead of ignoring the click.
            st.warning("Please upload a PDF before submitting.")
        else:
            with st.spinner(text="Uploading PDF and Generating Embeddings.."):
                # Copy the upload to a real file so it can be handed over by path.
                with NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
                    shutil.copyfileobj(pdf_file, tmp)
                    tmp_path = Path(tmp.name)
                # The temp file is closed (and therefore fully flushed) at this
                # point, before anything tries to open it via its path.
                # NOTE(review): the file is created with delete=False and is
                # never removed; delete it once it is confirmed that nothing
                # needs "pdf_path" after vector_db_pdf() returns.

                qa_model = st.session_state["pdf_qa_model"]
                qa_model.config = {
                    "pdf_path": str(tmp_path),
                    "embedding": emb,
                    "llm": llm,
                    "load_in_8bit": load_in_8bit
                }
                # Cached loaders: models are built once and then reused.
                qa_model.embedding = load_emb(emb)
                qa_model.llm = load_llm(llm, load_in_8bit)
                qa_model.init_embeddings()
                qa_model.init_models()
                qa_model.vector_db_pdf()
                st.sidebar.success("PDF uploaded successfully")

# Main pane: ask a question about the uploaded PDF.
question = st.text_input('Ask a question', 'Who are you?')

if st.button("Answer"):
    # Top-level UI boundary: report any failure to the user rather than
    # letting the script abort.
    try:
        qa_model = st.session_state["pdf_qa_model"]
        qa_model.retreival_qa_chain()
        answer = qa_model.answer_query(question)
        st.write(f"{answer}")
    except Exception as e:
        st.error(f"Error answering the question: {e}")