feat: improve doc removal (#130)
* feat: improve doc removal

* refactor: change return type
comoser authored Aug 9, 2023
1 parent 515cfff commit d2d7e05
Showing 10 changed files with 98 additions and 29 deletions.
2 changes: 1 addition & 1 deletion packages/api/config/default.json
@@ -1,7 +1,7 @@
{
"ai": {
"defaultTemperature": 0.2,
"defaultChatContextTTL": 604800,
"defaultChatContextTTL": 0,
"defaultTokenLimitForSummarization": 14500,
"defaultAiModel": "gpt-3.5-turbo-16k"
},
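A hedged reading of this change, not something the diff states: dropping defaultChatContextTTL from 604800 seconds (seven days) to 0 suggests chat contexts should no longer expire on their own, which would line up with the RedisKeepAliveService registered in ai.module.ts below. A cache write honoring such a convention might look like this sketch; the ioredis client and the setChatContext helper are hypothetical, not code from this repository:

import Redis from 'ioredis';

// Hypothetical helper: only attach an expiry when a positive TTL is configured.
async function setChatContext(redis: Redis, key: string, value: string, ttlSeconds: number) {
  if (ttlSeconds > 0) {
    await redis.set(key, value, 'EX', ttlSeconds);
  } else {
    await redis.set(key, value); // a TTL of 0 is treated as "keep indefinitely"
  }
}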
6 changes: 5 additions & 1 deletion packages/api/src/ai/ai.module.ts
@@ -1,9 +1,10 @@
import { AiService } from '@/ai/facades/ai.service';
import { AgentConversationService } from '@/ai/services/agent-conversation.service';
import { AiService } from '@/ai/services/ai.service';
import { MemoryService } from '@/ai/services/memory.service';
import { RedisKeepAliveService } from '@/ai/services/redis-keep-alive.service';
import { SimpleConversationChainService } from '@/ai/services/simple-conversation-chain.service';
import { ToolService } from '@/ai/services/tool.service';
import { VectorDbService } from '@/ai/services/vector-db.service';
import { AppConfigModule } from '@/app-config/app-config.module';
import { CacheModule } from '@/cache/cache.module';
import { Module } from '@nestjs/common';
@@ -12,12 +13,15 @@ import { ScheduleModule } from '@nestjs/schedule';
@Module({
imports: [CacheModule, AppConfigModule, ScheduleModule.forRoot()],
providers: [
// Publicly exposed facades
AiService,
// Private services
MemoryService,
ToolService,
SimpleConversationChainService,
AgentConversationService,
RedisKeepAliveService,
VectorDbService,
],
exports: [AiService],
})
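Since this diff splits AiService out as the module's single exported facade, here is a minimal consumer sketch under stated assumptions: SomeChatFeatureService, SomeChatFeatureModule, and the '@/ai/ai.module' import path are hypothetical, while AiModule, the AiService facade, and removeVectorDBCollection come from this commit. Other modules import AiModule and inject only the facade; MemoryService, ToolService, VectorDbService and the rest stay internal.

import { AiModule } from '@/ai/ai.module';
import { AiService } from '@/ai/facades/ai.service';
import { Injectable, Module } from '@nestjs/common';

@Injectable()
export class SomeChatFeatureService {
  constructor(private readonly aiService: AiService) {}

  async onDocumentRemoved(roomId: string, filename: string): Promise<void> {
    // Only the facade is reachable from outside AiModule; VectorDbService is not exported.
    await this.aiService.removeVectorDBCollection(roomId, filename);
  }
}

@Module({
  imports: [AiModule], // AiModule exports AiService only
  providers: [SomeChatFeatureService],
})
export class SomeChatFeatureModule {}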
@@ -1,11 +1,12 @@
import { AgentConversationService } from '@/ai/services/agent-conversation.service';
import { MemoryService } from '@/ai/services/memory.service';
import { SimpleConversationChainService } from '@/ai/services/simple-conversation-chain.service';
import { VectorDbService } from '@/ai/services/vector-db.service';
import { AppConfigService } from '@/app-config/app-config.service';
import { RedisChatMemoryNotFoundException } from '@/chats/exceptions/redis-chat-memory-not-found.exception';
import { ChatDocument } from '@/common/types/chat';
import { Injectable } from '@nestjs/common';
import { PromptTemplate } from 'langchain';
import { ConfigService } from '@nestjs/config';
import { AgentExecutor } from 'langchain/agents';
import {
ConversationChain,
@@ -14,6 +15,7 @@ import {
} from 'langchain/chains';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { Document } from 'langchain/document';
import { PromptTemplate } from 'langchain/prompts';
import { BaseMessage, ChainValues } from 'langchain/schema';

type AIExecutor = AgentExecutor | ConversationChain;
@@ -26,7 +28,9 @@ export class AiService {
private readonly simpleConversationChainService: SimpleConversationChainService,
private readonly agentConversationService: AgentConversationService,
private readonly appConfigService: AppConfigService,
private readonly memoryService: MemoryService
private readonly memoryService: MemoryService,
private readonly configService: ConfigService,
private readonly vectorDbService: VectorDbService
) {
this.llmModel = new ChatOpenAI({
temperature: this.appConfigService.getAiAppConfig().defaultTemperature,
@@ -117,6 +121,34 @@ Helpful answer:`
throw new RedisChatMemoryNotFoundException();
}

async removeVectorDBCollection(
roomId: string,
filename: string
): Promise<void> {
const vectorStore =
await this.vectorDbService.getVectorDbClientForExistingCollection(
roomId,
filename
);

const documentList = await vectorStore.collection.get();

await vectorStore.delete({ ids: documentList.ids });
}

async addDocumentsToVectorDBCollection(
roomId: string,
filename: string,
lcDocuments: Document[]
): Promise<void> {
const vectorStore = this.vectorDbService.getVectorDbClientForNewCollection(
roomId,
filename
);

await vectorStore.addDocuments(lcDocuments);
}

private async askAiToSummarize(roomId: string): Promise<ChainValues> {
const chain = await this.simpleConversationChainService.getChain(
roomId,
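Taken together, the two new facade methods give callers a symmetric add/remove lifecycle for a chat document's vectors. A short usage sketch, assuming an AiService instance and pre-split documents are already in hand; reindexThenRemove and 'report.pdf' are illustrative names, not taken from this repository:

import { AiService } from '@/ai/facades/ai.service';
import { Document } from 'langchain/document';

// Hypothetical helper showing the intended lifecycle of one chat document's vectors.
async function reindexThenRemove(aiService: AiService, roomId: string, splitDocs: Document[]) {
  // Index an uploaded file's chunks into the room/file collection...
  await aiService.addDocumentsToVectorDBCollection(roomId, 'report.pdf', splitDocs);

  // ...and later wipe that collection when the document is removed from the chat.
  // Per the diff above, this fetches every id in the collection and deletes them in one call.
  await aiService.removeVectorDBCollection(roomId, 'report.pdf');
}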
1 change: 1 addition & 0 deletions packages/api/src/ai/services/agent-conversation.service.ts
@@ -45,6 +45,7 @@ export class AgentConversationService {
summary?: string
) {
const agentDocumentTools = await this.toolService.getDocumentQATools(
roomId,
llmModel,
documents
);
20 changes: 8 additions & 12 deletions packages/api/src/ai/services/tool.service.ts
@@ -1,31 +1,27 @@
import { sanitizeFilename } from '@/common/constants/files';
import { VectorDbService } from '@/ai/services/vector-db.service';
import { ChatDocument } from '@/common/types/chat';
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { VectorDBQAChain } from 'langchain/chains';
import { BaseChatModel } from 'langchain/chat_models';
import { OpenAIEmbeddings } from 'langchain/embeddings/openai';
import { ChainTool } from 'langchain/tools';
import { Chroma } from 'langchain/vectorstores/chroma';

@Injectable()
export class ToolService {
constructor(private readonly configService: ConfigService) {}
constructor(private readonly vectorDbService: VectorDbService) {}

async getDocumentQATools(
roomId: string,
llmModel: BaseChatModel,
documents: ChatDocument[]
): Promise<ChainTool[]> {
const documentQATools = [];

for (const document of documents) {
const vectorStore = await Chroma.fromExistingCollection(
new OpenAIEmbeddings(),
{
url: this.configService.get('CHROMADB_CONNECTION_URL'),
collectionName: sanitizeFilename(document.meta.filename),
}
);
const vectorStore =
await this.vectorDbService.getVectorDbClientForExistingCollection(
roomId,
document.meta.filename
);

const chain = VectorDBQAChain.fromLLM(llmModel, vectorStore);

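The tail of getDocumentQATools is collapsed above, so the following is only a hypothetical sketch of how a per-document VectorDBQAChain is commonly wrapped as an agent tool; the toDocumentTool helper and the tool name and description are illustrative, not taken from this repository:

import { VectorDBQAChain } from 'langchain/chains';
import { BaseChatModel } from 'langchain/chat_models';
import { ChainTool } from 'langchain/tools';
import { Chroma } from 'langchain/vectorstores/chroma';

// Wrap one document's QA chain as a named tool that an agent can invoke.
function toDocumentTool(llmModel: BaseChatModel, vectorStore: Chroma, filename: string): ChainTool {
  const chain = VectorDBQAChain.fromLLM(llmModel, vectorStore);
  return new ChainTool({
    name: `qa-${filename}`, // illustrative naming convention
    description: `Answers questions about the uploaded document ${filename}.`,
    chain,
  });
}

AgentConversationService, which now passes roomId through as shown above, presumably hands tools like these to its agent executor.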
31 changes: 31 additions & 0 deletions packages/api/src/ai/services/vector-db.service.ts
@@ -0,0 +1,31 @@
import { sanitizeFilename } from '@/common/constants/files';
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { OpenAIEmbeddings } from 'langchain/embeddings/openai';
import { Chroma } from 'langchain/vectorstores/chroma';

@Injectable()
export class VectorDbService {
constructor(private readonly configService: ConfigService) {}

getVectorDbClientForNewCollection(roomId: string, filename: string): Chroma {
return new Chroma(new OpenAIEmbeddings(), {
url: this.configService.get('CHROMADB_CONNECTION_URL'),
collectionName: this.getCollectionName(roomId, filename),
});
}

async getVectorDbClientForExistingCollection(
roomId: string,
filename: string
): Promise<Chroma> {
return Chroma.fromExistingCollection(new OpenAIEmbeddings(), {
url: this.configService.get('CHROMADB_CONNECTION_URL'),
collectionName: this.getCollectionName(roomId, filename),
});
}

private getCollectionName(roomId: string, filename: string) {
return `${roomId}_${sanitizeFilename(filename)}`;
}
}
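Since this file is new, a quick usage sketch may help; indexAndReattach, the room id and the filename are hypothetical, while both service methods and Chroma's addDocuments are as used in this commit (the exact collection name depends on what sanitizeFilename returns):

import { VectorDbService } from '@/ai/services/vector-db.service';
import { Document } from 'langchain/document';

// Hypothetical caller; in this commit the real callers are AiService and ToolService.
async function indexAndReattach(vectorDbService: VectorDbService, lcDocuments: Document[]) {
  // First upload for this room/file pair: build a fresh client and add the chunk vectors.
  const newStore = vectorDbService.getVectorDbClientForNewCollection('room-123', 'Q3 Report.pdf');
  await newStore.addDocuments(lcDocuments);

  // Later lookups (QA tools, removal) reattach to the same collection, which is keyed
  // per room and per file, i.e. 'room-123_' followed by the sanitized filename.
  return vectorDbService.getVectorDbClientForExistingCollection('room-123', 'Q3 Report.pdf');
}

Keying collections by roomId as well as filename appears to be the core of this change: previously ToolService derived the collection name from the filename alone, so two rooms uploading equally named files could collide in a single Chroma collection.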
2 changes: 1 addition & 1 deletion packages/api/src/chats/chat-socket.gateway.ts
@@ -1,4 +1,4 @@
import { AiService } from '@/ai/services/ai.service';
import { AiService } from '@/ai/facades/ai.service';
import { AppConfigService } from '@/app-config/app-config.service';
import { ChatsRepository } from '@/chats/chats.repository';
import { createSocketMessageResponseFactory } from '@/chats/factory/create-socket-message.factory';
@@ -1,11 +1,10 @@
import { AiService } from '@/ai/services/ai.service';
import { AiService } from '@/ai/facades/ai.service';
import { ChatsRepository } from '@/chats/chats.repository';
import {
CSV_MIMETYPE,
DOCX_MIMETYPE,
PDF_MIMETYPE,
TEXT_MIMETYPE,
sanitizeFilename,
} from '@/common/constants/files';
import { CHAT_DOCUMENT_UPLOAD_QUEUE } from '@/common/constants/queues';
import { ChatDocUploadJob } from '@/common/jobs/chat-doc-upload.job';
@@ -19,9 +18,7 @@ import { CSVLoader } from 'langchain/document_loaders/fs/csv';
import { DocxLoader } from 'langchain/document_loaders/fs/docx';
import { PDFLoader } from 'langchain/document_loaders/fs/pdf';
import { TextLoader } from 'langchain/document_loaders/fs/text';
import { OpenAIEmbeddings } from 'langchain/embeddings/openai';
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { Chroma } from 'langchain/vectorstores/chroma';

@Processor(CHAT_DOCUMENT_UPLOAD_QUEUE)
export class TransformDocToVectorJobConsumer {
@@ -109,12 +106,11 @@ export class TransformDocToVectorJobConsumer {
document: ChatDocument,
lcDocuments: Document[]
) {
const vectorStore = new Chroma(new OpenAIEmbeddings(), {
url: this.configService.get('CHROMADB_CONNECTION_URL'),
collectionName: sanitizeFilename(document.meta.filename),
});

await vectorStore.addDocuments(lcDocuments);
await this.aiService.addDocumentsToVectorDBCollection(
roomId,
document.meta.filename,
lcDocuments
);

const vectorDBDocumentMetadata =
await this.aiService.askAiToDescribeDocument(lcDocuments);
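For reference, a minimal standalone sketch of the load, split, and index flow this consumer now funnels through the facade; the PDF loader, the path handling and the chunk sizes are illustrative assumptions (the real consumer also supports CSV, DOCX and plain text per its imports), and indexUploadedPdf is a hypothetical helper:

import { AiService } from '@/ai/facades/ai.service';
import { PDFLoader } from 'langchain/document_loaders/fs/pdf';
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';

// Hypothetical standalone version of the load -> split -> index pipeline.
async function indexUploadedPdf(aiService: AiService, roomId: string, filename: string, filePath: string) {
  const rawDocs = await new PDFLoader(filePath).load();

  // Chunk sizes here are illustrative, not values taken from this repository.
  const splitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000, chunkOverlap: 100 });
  const lcDocuments = await splitter.splitDocuments(rawDocs);

  // The consumer no longer talks to Chroma or OpenAIEmbeddings directly;
  // the AiService facade owns the vector store details.
  await aiService.addDocumentsToVectorDBCollection(roomId, filename, lcDocuments);
}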
@@ -1,4 +1,4 @@
import { AiService } from '@/ai/services/ai.service';
import { AiService } from '@/ai/facades/ai.service';
import { AppConfigService } from '@/app-config/app-config.service';
import { ChatsRepository } from '@/chats/chats.repository';
import { AddMessageToChatRequestDto } from '@/chats/dtos/add-message-to-chat.request.dto';
@@ -1,3 +1,4 @@
import { AiService } from '@/ai/facades/ai.service';
import { ChatsRepository } from '@/chats/chats.repository';
import { ChatResponseDto } from '@/chats/dtos/chat.response.dto';
import { RemoveDocumentFromChatRequestDto } from '@/chats/dtos/remove-document-from-chat.request.dto';
@@ -9,7 +10,10 @@ import { Injectable } from '@nestjs/common';

@Injectable()
export class RemoveDocumentFromChatUsecase implements Usecase {
constructor(private readonly chatsRepository: ChatsRepository) {}
constructor(
private readonly chatsRepository: ChatsRepository,
private readonly aiService: AiService
) {}

async execute(
userId: string,
@@ -33,6 +37,11 @@ export class RemoveDocumentFromChatUsecase implements Usecase {
removeDocumentFromChatRequestDto
);

await this.aiService.removeVectorDBCollection(
roomId,
removeDocumentFromChatRequestDto.filename
);

return ChatResponseSchema.parse(chat);
} catch (e) {
throw new InternalServerErrorException(e.message);
