From d2d7e0508b9c3fd5bd6aa0a7239874594fcf7ace Mon Sep 17 00:00:00 2001 From: David Alecrim Date: Wed, 9 Aug 2023 12:45:37 +0100 Subject: [PATCH] feat: improve doc removal (#130) * feat: improve doc removal * refactor: change return type --- packages/api/config/default.json | 2 +- packages/api/src/ai/ai.module.ts | 6 +++- .../ai/{services => facades}/ai.service.ts | 36 +++++++++++++++++-- .../ai/services/agent-conversation.service.ts | 1 + packages/api/src/ai/services/tool.service.ts | 20 +++++------ .../api/src/ai/services/vector-db.service.ts | 31 ++++++++++++++++ packages/api/src/chats/chat-socket.gateway.ts | 2 +- .../transform-doc-to-vector.job-consumer.ts | 16 ++++----- .../usecases/add-message-to-chat.usecase.ts | 2 +- .../remove-document-from-chat.usecase.ts | 11 +++++- 10 files changed, 98 insertions(+), 29 deletions(-) rename packages/api/src/ai/{services => facades}/ai.service.ts (79%) create mode 100644 packages/api/src/ai/services/vector-db.service.ts diff --git a/packages/api/config/default.json b/packages/api/config/default.json index 11d8b45..da3eed8 100644 --- a/packages/api/config/default.json +++ b/packages/api/config/default.json @@ -1,7 +1,7 @@ { "ai": { "defaultTemperature": 0.2, - "defaultChatContextTTL": 604800, + "defaultChatContextTTL": 0, "defaultTokenLimitForSummarization": 14500, "defaultAiModel": "gpt-3.5-turbo-16k" }, diff --git a/packages/api/src/ai/ai.module.ts b/packages/api/src/ai/ai.module.ts index 77ebf29..606792a 100644 --- a/packages/api/src/ai/ai.module.ts +++ b/packages/api/src/ai/ai.module.ts @@ -1,9 +1,10 @@ +import { AiService } from '@/ai/facades/ai.service'; import { AgentConversationService } from '@/ai/services/agent-conversation.service'; -import { AiService } from '@/ai/services/ai.service'; import { MemoryService } from '@/ai/services/memory.service'; import { RedisKeepAliveService } from '@/ai/services/redis-keep-alive.service'; import { SimpleConversationChainService } from 
'@/ai/services/simple-conversation-chain.service'; import { ToolService } from '@/ai/services/tool.service'; +import { VectorDbService } from '@/ai/services/vector-db.service'; import { AppConfigModule } from '@/app-config/app-config.module'; import { CacheModule } from '@/cache/cache.module'; import { Module } from '@nestjs/common'; @@ -12,12 +13,15 @@ import { ScheduleModule } from '@nestjs/schedule'; @Module({ imports: [CacheModule, AppConfigModule, ScheduleModule.forRoot()], providers: [ + // Publicly exposed facades AiService, + // Private services MemoryService, ToolService, SimpleConversationChainService, AgentConversationService, RedisKeepAliveService, + VectorDbService, ], exports: [AiService], }) diff --git a/packages/api/src/ai/services/ai.service.ts b/packages/api/src/ai/facades/ai.service.ts similarity index 79% rename from packages/api/src/ai/services/ai.service.ts rename to packages/api/src/ai/facades/ai.service.ts index 9d270f8..4141c38 100644 --- a/packages/api/src/ai/services/ai.service.ts +++ b/packages/api/src/ai/facades/ai.service.ts @@ -1,11 +1,12 @@ import { AgentConversationService } from '@/ai/services/agent-conversation.service'; import { MemoryService } from '@/ai/services/memory.service'; import { SimpleConversationChainService } from '@/ai/services/simple-conversation-chain.service'; +import { VectorDbService } from '@/ai/services/vector-db.service'; import { AppConfigService } from '@/app-config/app-config.service'; import { RedisChatMemoryNotFoundException } from '@/chats/exceptions/redis-chat-memory-not-found.exception'; import { ChatDocument } from '@/common/types/chat'; import { Injectable } from '@nestjs/common'; -import { PromptTemplate } from 'langchain'; +import { ConfigService } from '@nestjs/config'; import { AgentExecutor } from 'langchain/agents'; import { ConversationChain, @@ -14,6 +15,7 @@ import { } from 'langchain/chains'; import { ChatOpenAI } from 'langchain/chat_models/openai'; import { Document } from 
'langchain/document'; +import { PromptTemplate } from 'langchain/prompts'; import { BaseMessage, ChainValues } from 'langchain/schema'; type AIExecutor = AgentExecutor | ConversationChain; @@ -26,7 +28,9 @@ export class AiService { private readonly simpleConversationChainService: SimpleConversationChainService, private readonly agentConversationService: AgentConversationService, private readonly appConfigService: AppConfigService, - private readonly memoryService: MemoryService + private readonly memoryService: MemoryService, + private readonly configService: ConfigService, + private readonly vectorDbService: VectorDbService ) { this.llmModel = new ChatOpenAI({ temperature: this.appConfigService.getAiAppConfig().defaultTemperature, @@ -117,6 +121,34 @@ Helpful answer:` throw new RedisChatMemoryNotFoundException(); } + async removeVectorDBCollection( + roomId: string, + filename: string + ): Promise<void> { + const vectorStore = + await this.vectorDbService.getVectorDbClientForExistingCollection( + roomId, + filename + ); + + const documentList = await vectorStore.collection.get(); + + await vectorStore.delete({ ids: documentList.ids }); + } + + async addDocumentsToVectorDBCollection( + roomId: string, + filename: string, + lcDocuments: Document[] + ): Promise<void> { + const vectorStore = this.vectorDbService.getVectorDbClientForNewCollection( + roomId, + filename + ); + + await vectorStore.addDocuments(lcDocuments); + } + private async askAiToSummarize(roomId: string): Promise { const chain = await this.simpleConversationChainService.getChain( roomId, diff --git a/packages/api/src/ai/services/agent-conversation.service.ts b/packages/api/src/ai/services/agent-conversation.service.ts index 19f3890..fe98d30 100644 --- a/packages/api/src/ai/services/agent-conversation.service.ts +++ b/packages/api/src/ai/services/agent-conversation.service.ts @@ -45,6 +45,7 @@ export class AgentConversationService { summary?: string ) { const agentDocumentTools = await
this.toolService.getDocumentQATools( + roomId, llmModel, documents ); diff --git a/packages/api/src/ai/services/tool.service.ts b/packages/api/src/ai/services/tool.service.ts index 370f61d..53e3d86 100644 --- a/packages/api/src/ai/services/tool.service.ts +++ b/packages/api/src/ai/services/tool.service.ts @@ -1,31 +1,27 @@ -import { sanitizeFilename } from '@/common/constants/files'; +import { VectorDbService } from '@/ai/services/vector-db.service'; import { ChatDocument } from '@/common/types/chat'; import { Injectable } from '@nestjs/common'; -import { ConfigService } from '@nestjs/config'; import { VectorDBQAChain } from 'langchain/chains'; import { BaseChatModel } from 'langchain/chat_models'; -import { OpenAIEmbeddings } from 'langchain/embeddings/openai'; import { ChainTool } from 'langchain/tools'; -import { Chroma } from 'langchain/vectorstores/chroma'; @Injectable() export class ToolService { - constructor(private readonly configService: ConfigService) {} + constructor(private readonly vectorDbService: VectorDbService) {} async getDocumentQATools( + roomId: string, llmModel: BaseChatModel, documents: ChatDocument[] ): Promise<ChainTool[]> { const documentQATools = []; for (const document of documents) { - const vectorStore = await Chroma.fromExistingCollection( - new OpenAIEmbeddings(), - { - url: this.configService.get('CHROMADB_CONNECTION_URL'), - collectionName: sanitizeFilename(document.meta.filename), - } - ); + const vectorStore = + await this.vectorDbService.getVectorDbClientForExistingCollection( + roomId, + document.meta.filename + ); const chain = VectorDBQAChain.fromLLM(llmModel, vectorStore); diff --git a/packages/api/src/ai/services/vector-db.service.ts b/packages/api/src/ai/services/vector-db.service.ts new file mode 100644 index 0000000..8f80fcc --- /dev/null +++ b/packages/api/src/ai/services/vector-db.service.ts @@ -0,0 +1,31 @@ +import { sanitizeFilename } from '@/common/constants/files'; +import { Injectable } from '@nestjs/common'; +import {
ConfigService } from '@nestjs/config'; +import { OpenAIEmbeddings } from 'langchain/embeddings/openai'; +import { Chroma } from 'langchain/vectorstores/chroma'; + +@Injectable() +export class VectorDbService { + constructor(private readonly configService: ConfigService) {} + + getVectorDbClientForNewCollection(roomId: string, filename: string): Chroma { + return new Chroma(new OpenAIEmbeddings(), { + url: this.configService.get('CHROMADB_CONNECTION_URL'), + collectionName: this.getCollectionName(roomId, filename), + }); + } + + async getVectorDbClientForExistingCollection( + roomId: string, + filename: string + ): Promise<Chroma> { + return Chroma.fromExistingCollection(new OpenAIEmbeddings(), { + url: this.configService.get('CHROMADB_CONNECTION_URL'), + collectionName: this.getCollectionName(roomId, filename), + }); + } + + private getCollectionName(roomId: string, filename: string) { + return `${roomId}_${sanitizeFilename(filename)}`; + } +} diff --git a/packages/api/src/chats/chat-socket.gateway.ts b/packages/api/src/chats/chat-socket.gateway.ts index 85fe734..7ab26ce 100644 --- a/packages/api/src/chats/chat-socket.gateway.ts +++ b/packages/api/src/chats/chat-socket.gateway.ts @@ -1,4 +1,4 @@ -import { AiService } from '@/ai/services/ai.service'; +import { AiService } from '@/ai/facades/ai.service'; import { AppConfigService } from '@/app-config/app-config.service'; import { ChatsRepository } from '@/chats/chats.repository'; import { createSocketMessageResponseFactory } from '@/chats/factory/create-socket-message.factory'; diff --git a/packages/api/src/chats/job-consumers/transform-doc-to-vector.job-consumer.ts b/packages/api/src/chats/job-consumers/transform-doc-to-vector.job-consumer.ts index effbf5f..866f5db 100644 --- a/packages/api/src/chats/job-consumers/transform-doc-to-vector.job-consumer.ts +++ b/packages/api/src/chats/job-consumers/transform-doc-to-vector.job-consumer.ts @@ -1,11 +1,10 @@ -import { AiService } from '@/ai/services/ai.service'; +import {
AiService } from '@/ai/facades/ai.service'; import { ChatsRepository } from '@/chats/chats.repository'; import { CSV_MIMETYPE, DOCX_MIMETYPE, PDF_MIMETYPE, TEXT_MIMETYPE, - sanitizeFilename, } from '@/common/constants/files'; import { CHAT_DOCUMENT_UPLOAD_QUEUE } from '@/common/constants/queues'; import { ChatDocUploadJob } from '@/common/jobs/chat-doc-upload.job'; @@ -19,9 +18,7 @@ import { CSVLoader } from 'langchain/document_loaders/fs/csv'; import { DocxLoader } from 'langchain/document_loaders/fs/docx'; import { PDFLoader } from 'langchain/document_loaders/fs/pdf'; import { TextLoader } from 'langchain/document_loaders/fs/text'; -import { OpenAIEmbeddings } from 'langchain/embeddings/openai'; import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'; -import { Chroma } from 'langchain/vectorstores/chroma'; @Processor(CHAT_DOCUMENT_UPLOAD_QUEUE) export class TransformDocToVectorJobConsumer { @@ -109,12 +106,11 @@ export class TransformDocToVectorJobConsumer { document: ChatDocument, lcDocuments: Document[] ) { - const vectorStore = new Chroma(new OpenAIEmbeddings(), { - url: this.configService.get('CHROMADB_CONNECTION_URL'), - collectionName: sanitizeFilename(document.meta.filename), - }); - - await vectorStore.addDocuments(lcDocuments); + await this.aiService.addDocumentsToVectorDBCollection( + roomId, + document.meta.filename, + lcDocuments + ); const vectorDBDocumentMetadata = await this.aiService.askAiToDescribeDocument(lcDocuments); diff --git a/packages/api/src/chats/usecases/add-message-to-chat.usecase.ts b/packages/api/src/chats/usecases/add-message-to-chat.usecase.ts index e809bfd..7a1f09f 100644 --- a/packages/api/src/chats/usecases/add-message-to-chat.usecase.ts +++ b/packages/api/src/chats/usecases/add-message-to-chat.usecase.ts @@ -1,4 +1,4 @@ -import { AiService } from '@/ai/services/ai.service'; +import { AiService } from '@/ai/facades/ai.service'; import { AppConfigService } from '@/app-config/app-config.service'; import { 
ChatsRepository } from '@/chats/chats.repository'; import { AddMessageToChatRequestDto } from '@/chats/dtos/add-message-to-chat.request.dto'; diff --git a/packages/api/src/chats/usecases/remove-document-from-chat.usecase.ts b/packages/api/src/chats/usecases/remove-document-from-chat.usecase.ts index 547c55a..2d6799c 100644 --- a/packages/api/src/chats/usecases/remove-document-from-chat.usecase.ts +++ b/packages/api/src/chats/usecases/remove-document-from-chat.usecase.ts @@ -1,3 +1,4 @@ +import { AiService } from '@/ai/facades/ai.service'; import { ChatsRepository } from '@/chats/chats.repository'; import { ChatResponseDto } from '@/chats/dtos/chat.response.dto'; import { RemoveDocumentFromChatRequestDto } from '@/chats/dtos/remove-document-from-chat.request.dto'; @@ -9,7 +10,10 @@ import { Injectable } from '@nestjs/common'; @Injectable() export class RemoveDocumentFromChatUsecase implements Usecase { - constructor(private readonly chatsRepository: ChatsRepository) {} + constructor( + private readonly chatsRepository: ChatsRepository, + private readonly aiService: AiService + ) {} async execute( userId: string, @@ -33,6 +37,11 @@ export class RemoveDocumentFromChatUsecase implements Usecase { removeDocumentFromChatRequestDto ); + await this.aiService.removeVectorDBCollection( + roomId, + removeDocumentFromChatRequestDto.filename + ); + return ChatResponseSchema.parse(chat); } catch (e) { throw new InternalServerErrorException(e.message);