forked from lorenzbaum/Hackathon-Pubquiz
-
Notifications
You must be signed in to change notification settings - Fork 0
/
db_source.py
29 lines (23 loc) · 861 Bytes
/
db_source.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import logging
import os

from dotenv import load_dotenv
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate, PromptTemplate

from prompts import pubquiz_prompt
from llm import llm
from db import db
# Load settings from a .env file; override=True lets values in the file take
# precedence over variables already set in the process environment.
load_dotenv(override=True)
# NOTE(review): neither value is used in this module; they are kept as module
# attributes in case another module imports them — confirm and drop if unused.
azure_api_key = os.getenv('AZURE_OPENAI_API_KEY')
azure_endpoint = os.getenv('AZURE_OPENAI_ENDPOINT')

# Template applied to each retrieved document before the documents are stuffed
# into ``pubquiz_prompt``. A plain PromptTemplate is used deliberately: a
# ChatPromptTemplate formats through its message buffer, which would prefix
# every document with "Human: ".
document_prompt = PromptTemplate.from_template("""Content: {page_content}""")

# Chain that renders the retrieved documents with ``document_prompt`` and sends
# them to ``llm`` via ``pubquiz_prompt``.
document_chain = create_stuff_documents_chain(
    llm=llm,
    prompt=pubquiz_prompt,
    document_prompt=document_prompt,
)

# ``db`` is presumably the project's vector store (it exposes as_retriever());
# the retrieval chain looks up documents for the incoming "input" and hands
# them to ``document_chain``.
retriever = db.as_retriever()
retrieval_chain = create_retrieval_chain(retriever, document_chain)
_logger = logging.getLogger(__name__)


def invoke_db(prompt: str) -> str:
    """Answer *prompt* with retrieval-augmented generation over ``db``.

    Runs ``retrieval_chain`` with the prompt as its ``"input"`` and returns
    the generated answer.

    Args:
        prompt: The user's question.

    Returns:
        The ``"answer"`` entry of the retrieval chain's result.

    Raises:
        KeyError: If the chain result has no ``"answer"`` key.
    """
    result = retrieval_chain.invoke({"input": prompt})
    # The full result includes every retrieved context document, which is
    # too noisy for stdout on each call. Log it at DEBUG level instead.
    _logger.debug("Retrieval chain result: %s", result)
    return result['answer']