-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathst_app.py
67 lines (56 loc) · 1.85 KB
/
st_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import streamlit as st
from langchain.prompts.chat import ChatPromptTemplate
from langchain.vectorstores.chroma import Chroma
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_openai.llms import OpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from pydantic.v1 import SecretStr
# Page heading for the StrainDB RAG app.
st.title(body="StrainDB RAG")
@st.cache_resource(max_entries=16)
def _load_retriever(key):
    """Build a Chroma retriever over the "strains" collection for one API key.

    ``st.cache_resource`` is shared across every session of the app, so the
    cache is keyed on the API key. Otherwise the first key ever entered (or a
    bad first key) would be baked into the embeddings and reused for all
    users. ``max_entries`` bounds the cache, because keys are arbitrary user
    input.

    Args:
        key: Plain-text OpenAI API key.

    Returns:
        A retriever backed by the persisted Chroma store in ``./chroma/``.
    """
    embeddings = OpenAIEmbeddings(
        model="text-embedding-3-large",
        api_key=SecretStr(key),
    )
    return Chroma(
        collection_name="strains",
        embedding_function=embeddings,
        persist_directory="./chroma/",
    ).as_retriever()


def get_retriever(key=None):
    """Return a Chroma retriever that uses OpenAI embeddings.

    Args:
        key: Plain-text OpenAI API key. When omitted, the module-level
            ``api_key`` (a ``SecretStr`` set by the Search handler) is used,
            so existing zero-argument calls keep working.

    Returns:
        A retriever cached per distinct API key.
    """
    if key is None:
        key = api_key.get_secret_value()
    return _load_retriever(key)
# No max_chars limit here: current OpenAI keys (e.g. "sk-proj-...") are much
# longer than 51 characters and would be rejected. Surrounding whitespace from
# copy/paste is stripped so it does not corrupt the key.
openai_api_key = st.text_input(
    label="Enter your OpenAI API key",
    help="Get an API key from https://platform.openai.com/account/api-keys",
    type="password",
    placeholder="sk-...",
).strip()
query = st.text_input(
    label="Enter your query",
    help="Search for strains",
    placeholder="What are the medical uses of Stunna",
).strip()
prompt = ChatPromptTemplate.from_template("""
Answer the question based only on the following context:
{context}
Question: {question}
""")


def _format_docs(docs):
    """Join the retrieved documents' text into one context string.

    Without this, the prompt would receive the Python repr of the document
    list (including object wrappers and metadata) instead of just its text.
    """
    return "\n\n".join(doc.page_content for doc in docs)


if st.button("Search"):
    # st.stop() ends this script run, so the code below only executes when
    # both inputs are non-empty. No second query check is needed.
    if not openai_api_key:
        st.warning("Please enter your OpenAI API key")
        st.stop()
    if not query:
        st.warning("Please enter a query")
        st.stop()
    # Module-level name: get_retriever() reads it when called without a key.
    api_key = SecretStr(openai_api_key)
    retriever = get_retriever()
    model = OpenAI(temperature=0, api_key=api_key)
    # The query fans out to the retriever (formatted context) and passes
    # through unchanged as the question.
    chain = (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    with st.spinner("Searching..."):
        st.write_stream(chain.stream(query))