conversation.py
import os
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from dotenv import load_dotenv
load_dotenv()


def create_conversation() -> ConversationalRetrievalChain:
    # Directory where the Chroma vector store was persisted during ingestion.
    persist_directory = 'db'

    embeddings = OpenAIEmbeddings(
        openai_api_key=os.getenv('OPENAI_API_KEY')
    )

    # Load the existing vector store with the same embedding function
    # that was used when the documents were ingested.
    db = Chroma(
        persist_directory=persist_directory,
        embedding_function=embeddings
    )

    # Keep the running chat history so follow-up questions have context.
    memory = ConversationBufferMemory(
        memory_key='chat_history',
        return_messages=False
    )

    qa = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(),
        chain_type='stuff',
        retriever=db.as_retriever(),
        memory=memory,
        # History is already a formatted string, so pass it through unchanged.
        get_chat_history=lambda h: h,
        verbose=True
    )

    return qa
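

# Minimal usage sketch, assuming the 'db' directory already holds an ingested
# Chroma collection and that OPENAI_API_KEY is available (e.g. via a .env file).
# The question text below is only an illustrative placeholder.
if __name__ == '__main__':
    chain = create_conversation()
    # ConversationalRetrievalChain expects a dict with a 'question' key;
    # the attached memory supplies 'chat_history' automatically.
    result = chain({'question': 'What is this project about?'})
    print(result['answer'])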