-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
136 lines (104 loc) · 4.32 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import io
from pypdf import PdfReader
from dotenv import load_dotenv
from pydantic import BaseModel
from fastapi import FastAPI, UploadFile, File, Response, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import FAISS
from langchain.chains.question_answering import load_qa_chain
from langchain.callbacks import get_openai_callback
# Pull variables from a local .env file into the environment first
# (presumably OPENAI_API_KEY for the OpenAI clients used by the endpoints).
load_dotenv()

app = FastAPI()

# Let browser front-ends served from any origin call this API (avoids CORS
# errors in the browser).
# NOTE(review): allow_origins=["*"] together with allow_credentials=True may let
# any site make credentialed requests; consider listing the real front-end
# origin(s) instead — confirm what the front-end needs.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    allow_credentials=True,
)

# PostgreSQL persistence of upload metadata is currently disabled. When it is
# re-enabled, file sizes are stored as kilobyte strings (e.g. "420.05 KB"):
# def bytes_to_kilobytes(bytes_value):
#     return f"{bytes_value / 1024:.2f} KB"

# Shared between endpoints: /upload replaces it, /chat reads it.
# None until a PDF has been uploaded.
knowledge_base = None
@app.get('/')
def index():
    """Root endpoint: return a fixed message showing the API is running."""
    status_message = {"msg": 'It is working...'}
    return status_message
def _extract_pdf_text(pdf_bytes):
    """Return the text of every page of the PDF in *pdf_bytes*, pages separated by newlines."""
    reader = PdfReader(io.BytesIO(pdf_bytes))
    # Join pages with '\n' so the last word of one page does not fuse with the
    # first word of the next, and so the '\n'-separated splitter can break there.
    # str.join also avoids quadratic string concatenation.
    return '\n'.join(page.extract_text() or '' for page in reader.pages)


def _build_knowledge_base(text):
    """Split *text* into overlapping chunks, embed them, and index them in FAISS."""
    # Chunking lets the similarity search return focused passages.
    text_splitter = CharacterTextSplitter(
        separator='\n',
        chunk_size=2000,
        chunk_overlap=200,
        length_function=len,
    )
    chunks = text_splitter.split_text(text)
    return FAISS.from_texts(chunks, OpenAIEmbeddings())


@app.post('/upload')
async def pdf_upload(file: UploadFile = File(...)):
    """Build the in-memory FAISS knowledge base from an uploaded PDF.

    The PDF's text is extracted page by page, split into overlapping chunks,
    embedded with ``OpenAIEmbeddings`` and indexed with FAISS. Only after this
    succeeds is the module-level ``knowledge_base`` replaced, so ``/chat``
    keeps using the previous document if an upload fails.

    Args:
        file: The uploaded file; its name must end in ``.pdf`` (any case).

    Returns:
        ``{'msg': 'Knowledge base set up successfully'}`` on success.

    Raises:
        HTTPException: 400 if no file or a non-PDF file is given, or if no
            text could be extracted from the PDF; 500 if reading, embedding
            or indexing fails.
    """
    global knowledge_base
    # Validate cheaply before doing any work. The exception must be *raised*:
    # a returned HTTPException is serialized as a normal 200 response.
    if not file or not file.filename:
        raise HTTPException(status_code=400, detail='Please upload a file')
    if not file.filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail='Uploaded file must be a PDF')
    try:
        pdf_binary_data = await file.read()
        text = _extract_pdf_text(pdf_binary_data)
        # A scanned/image-only PDF yields no text; FAISS cannot index zero chunks.
        if not text.strip():
            raise HTTPException(status_code=400, detail='No extractable text found in the PDF.')
        new_knowledge_base = _build_knowledge_base(text)
    except HTTPException:
        # Let our own client errors through unchanged instead of turning them into 500s.
        raise
    except Exception as error:
        raise HTTPException(status_code=500, detail=f'There was an error in server side {error}') from error
    if new_knowledge_base is None:
        raise HTTPException(status_code=500, detail='Something went wrong settting up knowledge base.')
    knowledge_base = new_knowledge_base
    # When PostgreSQL is re-enabled, store the file metadata here:
    # bytes_value = file.size
    # kilobytes_value = bytes_to_kilobytes(bytes_value)
    # cursor.execute(
    #     "INSERT INTO file(file_name, file_size) VALUES (%s, %s)",
    #     (str(file.filename), kilobytes_value)
    # )
    # connection.commit()
    return {'msg': 'Knowledge base set up successfully'}
class Question(BaseModel):
    """Request body for ``POST /chat``: the user's question about the uploaded PDF."""

    # Annotated with the ``str`` type itself; the original author reported that
    # the quoted form ``question: 'str'`` caused an error.
    question: str
@app.post('/chat')
async def question_and_answer(question: Question):
    """Answer a question about the most recently uploaded PDF.

    Runs a similarity search over the module-level FAISS ``knowledge_base``
    and feeds the matching chunks plus the question text to a "stuff" QA
    chain backed by the ``gpt-3.5-turbo-instruct`` model. Token usage from
    the OpenAI callback is printed to stdout.

    Args:
        question: Request body whose ``question`` field holds the user's text.

    Returns:
        The answer produced by the QA chain.

    Raises:
        HTTPException: 400 if the question is blank or no PDF has been
            uploaded yet; 500 if the search or the LLM call fails.
    """
    # knowledge_base is only read here, so no ``global`` declaration is needed.
    query = question.question
    # Validate up front and *raise*: a returned HTTPException becomes a 200 response.
    if not query or not query.strip():
        raise HTTPException(status_code=400, detail='Question must not be empty.')
    if knowledge_base is None:
        raise HTTPException(status_code=400, detail='Please upload a PDF before you ask questions.')
    try:
        # Retrieve the chunks most similar to the user's question.
        docs = knowledge_base.similarity_search(query)
        # "stuff" chain: all retrieved chunks are placed into a single prompt.
        llm = OpenAI(model="gpt-3.5-turbo-instruct")
        chain = load_qa_chain(llm, chain_type="stuff")
        with get_openai_callback() as cb:
            # Pass the question text itself, not the Pydantic request model.
            response = chain.run(input_documents=docs, question=query)
            print(cb)
    except Exception as error:
        raise HTTPException(status_code=500, detail=f'There was error on server side {error}') from error
    return response
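# Disabled endpoint: list stored file metadata. Re-enable together with the
# PostgreSQL connection (`cursor` is not defined while that is disabled).
# @app.get('/pdfs')
# def get_pdfs():
#     try:
#         cursor.execute('SELECT * FROM file')
#         files = cursor.fetchall()
#         return {'files': files}
#     except Exception as error:
#         return HTTPException(status_code=500, detail=error)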