-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchatbot.js
121 lines (99 loc) · 3.45 KB
/
chatbot.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import { OpenAI } from "langchain/llms/openai";
import { ConversationalRetrievalQAChain } from "langchain/chains";
import { HNSWLib } from "langchain/vectorstores/hnswlib";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { PDFLoader } from "langchain/document_loaders/fs/pdf";
import { CSVLoader } from "langchain/document_loaders/fs/csv";
// import { TextLoader } from "langchain/document_loaders/fs/text";
import { DocxLoader } from "langchain/document_loaders/fs/docx";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { HumanMessage, AIMessage } from "langchain/schema";
import { getFileExtension } from "./helper.js";
import * as fs from "fs";
export class AiModel {
  /**
   * Conversational document Q&A model. Load a document with
   * `loadDocument` (builds an HNSWLib vector store + retrieval chain),
   * then query it with `ask`; chat history persists across turns.
   */
  constructor() {
    this.chain = null;
    this.chatHistory = new ChatMessageHistory([]);
    this.totalTokens = 0;
  }

  /**
   * Checks whether a file exists at the given path.
   * @param {string} filepath
   * @returns {boolean} true when the file is present on disk.
   */
  fileUploaded = (filepath) => {
    return fs.existsSync(filepath);
  };

  /**
   * Appends a question/answer turn to the shared chat history.
   * @param {string} query - The user's question.
   * @param {string} result - The model's answer text.
   */
  addHistory = (query, result) => {
    this.chatHistory.addMessage(new HumanMessage(query));
    this.chatHistory.addMessage(new AIMessage(result));
  };

  /**
   * Builds the conversational retrieval chain from a document.
   * @param {string} filePath - Path to the uploaded document; removed
   *   after a successful load (unless `texts` is supplied).
   * @param {Array|null} [texts=null] - Optional pre-split documents;
   *   when given, `filePath` is not read or deleted.
   * @returns {Promise<boolean>} resolves true once the chain is ready.
   */
  loadDocument = async (filePath, texts = null) => {
    const model = new OpenAI({
      apiKey: process.env.OPENAI_API_KEY,
      callbacks: [
        {
          // Must be an arrow function: method shorthand here rebinds
          // `this` to the handler object, so `totalTokens` would never
          // be recorded on the AiModel instance (the original bug —
          // `ask` always reported tokenUsage 0).
          handleLLMEnd: (output) => {
            this.totalTokens = output?.llmOutput?.tokenUsage?.totalTokens ?? 0;
          },
        },
      ],
    });
    const docs = texts ?? (await this.getDocument(filePath));
    const vectorStore = await HNSWLib.fromDocuments(
      docs,
      new OpenAIEmbeddings()
    );
    const retriever = await vectorStore.asRetriever();
    const memory = new BufferMemory({
      chatHistory: this.chatHistory,
      memoryKey: "chat_history",
      inputKey: "question",
    });
    this.chain = ConversationalRetrievalQAChain.fromLLM(model, retriever, {
      memory,
    });
    if (!texts) {
      // force: true — don't throw if the file was already cleaned up.
      fs.rmSync(filePath, { force: true });
    }
    return true;
  };

  /**
   * Loads a document from disk and splits it into chunks for embedding.
   * Supported extensions: csv, pdf, docx, txt.
   * @param {string} filePath
   * @returns {Promise<Array>} the split LangChain documents.
   * @throws {Error} for unsupported file extensions. (The original
   *   wrapped an async executor in `new Promise` and simply never
   *   settled for unknown types, hanging the caller.)
   */
  getDocument = async (filePath) => {
    const type = getFileExtension(filePath);
    const textSplitter = new RecursiveCharacterTextSplitter({
      chunkSize: 1000,
      chunkOverlap: 0,
    });
    if (type === "csv" || type === "pdf") {
      const loader =
        type === "csv" ? new CSVLoader(filePath) : new PDFLoader(filePath);
      return textSplitter.splitDocuments(await loader.load());
    }
    if (type === "docx") {
      const loader = new DocxLoader(filePath);
      return textSplitter.splitDocuments(await loader.load());
    }
    if (type === "txt") {
      const text = fs.readFileSync(filePath, "utf8");
      return textSplitter.createDocuments([text]);
    }
    throw new Error(`Unsupported file type: ${type}`);
  };

  /**
   * Answers a question against the loaded document.
   * @param {string} query - The user's question.
   * @returns {Promise<{text: string, tokenUsage: number}|string>} the
   *   answer plus token usage, or a plain guidance string when the
   *   query is blank or no document has been loaded yet.
   */
  ask = async (query) => {
    if (!query) {
      return "Your Question is Blank";
    }
    if (!this.chain) {
      return "Please upload a document first or wait for the document to process";
    }
    const res = await this.chain.call({
      question: query,
    });
    this.addHistory(query, res.text);
    return { text: res.text, tokenUsage: this.totalTokens };
  };
}