"""Main entrypoint for the app."""
# Load environment variables (e.g. OPENAI_API_KEY) before the langchain
# imports below so the OpenAI clients can pick them up.
from dotenv import load_dotenv

load_dotenv()

import logging
from typing import Optional
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.templating import Jinja2Templates
from langchain.vectorstores.base import VectorStore
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import DeepLake
from langchain.chains import RetrievalQA
from callback import QuestionGenCallbackHandler, StreamingLLMCallbackHandler, MyCustomAsyncHandler
# from query_data import get_chain
from schemas import ChatResponse

app = FastAPI()
templates = Jinja2Templates(directory="templates")
vectorstore: Optional[VectorStore] = None

# Embeddings for the DeepLake vectorstore; disallowed_special=() lets the
# tokenizer embed text containing special tokens (common in scraped code).
embeddings = OpenAIEmbeddings(disallowed_special=())
retriever = None

@app.on_event("startup")
async def startup_event():
    # Load the vectorstore and build a retriever once at startup.
    global retriever
    db = DeepLake(
        dataset_path="hub://lucasmanea/aptos-extended-new",
        read_only=True,
        embedding_function=embeddings,
    )
    retriever = db.as_retriever()
    # Cosine distance with MMR re-ranking: fetch 100 candidates, keep 10.
    retriever.search_kwargs['distance_metric'] = 'cos'
    retriever.search_kwargs['fetch_k'] = 100
    retriever.search_kwargs['maximal_marginal_relevance'] = True
    retriever.search_kwargs['k'] = 10
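
# The retriever can be exercised on its own for debugging; a minimal sketch
# (the query string here is just an illustration):
#
#   docs = retriever.get_relevant_documents("How do Move modules work?")
#   print(len(docs), docs[0].page_content[:200])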

@app.get("/")
async def get(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})

class MyCustomSyncHandler(AsyncCallbackHandler):
    # Must be an AsyncCallbackHandler (not BaseCallbackHandler): the async
    # on_llm_new_token coroutine below is only awaited for async handlers.
    def __init__(self, websocket):
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        resp = ChatResponse(sender="bot", message=token, type="stream")
        await self.websocket.send_json(resp.dict())

@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    question_handler = QuestionGenCallbackHandler(websocket)
    stream_handler = StreamingLLMCallbackHandler(websocket)
    async_handler = MyCustomAsyncHandler()
    sync_handler = MyCustomSyncHandler(websocket)
    chat_history = []
    model = ChatOpenAI(model_name='gpt-4', streaming=True, callbacks=[sync_handler], verbose=True)
    # Build the QA chain once per connection instead of once per message.
    chain = RetrievalQA.from_chain_type(llm=model, chain_type="stuff", retriever=retriever)
    # qa_chain = ConversationalRetrievalChain.from_llm(model, retriever=retriever)
    # Use the line below instead to enable tracing; ensure `langchain-server` is running:
    # qa_chain = get_chain(vectorstore, question_handler, stream_handler, tracing=True)

    while True:
        try:
            # Receive the client message and echo it back.
            question = await websocket.receive_text()
            print(question)
            resp = ChatResponse(sender="you", message=question, type="stream")
            await websocket.send_json(resp.dict())
            # chat_history.append([HumanMessage(content=question)])

            # Signal the start of a streamed bot response.
            start_resp = ChatResponse(sender="bot", message="", type="start")
            await websocket.send_json(start_resp.dict())
            # result = await model.agenerate(chat_history)

            # Run retrieval + generation; the callback handler streams
            # tokens to the client while the chain runs.
            result = await chain.acall(inputs=question)
            chat_history.append((question, result))
            print(result)
            end_resp = ChatResponse(sender="bot", message=result['result'], type="end")
            print(end_resp.dict())
            await websocket.send_json(end_resp.dict())
        except WebSocketDisconnect:
            logging.info("websocket disconnect")
            break
        except Exception as e:
            logging.error(e)
            resp = ChatResponse(
                sender="bot",
                message="Sorry, something went wrong. Try again.",
                type="error",
            )
            await websocket.send_json(resp.dict())

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=9000)
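
# A minimal client sketch for manually exercising the /chat endpoint.
# This is an illustration, not part of the app; it assumes the third-party
# `websockets` package and the server running on localhost:9000.
#
#   import asyncio, json, websockets
#
#   async def ask(question: str):
#       async with websockets.connect("ws://localhost:9000/chat") as ws:
#           await ws.send(question)
#           while True:
#               msg = json.loads(await ws.recv())
#               if msg["type"] == "stream" and msg["sender"] == "bot":
#                   print(msg["message"], end="", flush=True)
#               elif msg["type"] in ("end", "error"):
#                   break
#
#   asyncio.run(ask("What is Aptos?"))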