Skip to content

Commit

Permalink
feat: optimize chat memory (#124)
Browse files Browse the repository at this point in the history
* feat: optimize chat memory

* feat: pr comments fix

* feat: pr fixes

* feat: add types to ai.service methods
  • Loading branch information
romansharapov19 authored Aug 7, 2023
1 parent 7262a50 commit 2706f49
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 25 deletions.
2 changes: 1 addition & 1 deletion packages/api/config/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"ai": {
"defaultTemperature": 0.2,
"defaultChatContextTTL": 604800,
"defaultTokenLimitForSummarization": 15000,
"defaultTokenLimitForSummarization": 14500,
"defaultAiModel": "gpt-3.5-turbo-16k"
},
"chat": {
Expand Down
5 changes: 3 additions & 2 deletions packages/api/src/ai/services/agent-conversation.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ export class AgentConversationService {
memory: new BufferMemory({
returnMessages: true,
memoryKey: 'chat_history',
chatHistory: this.memoryService.getMemory(roomId, summary)
.chatHistory,
chatHistory: (
await this.memoryService.getMemory(roomId, summary)
).chatHistory,
}),
}
);
Expand Down
24 changes: 20 additions & 4 deletions packages/api/src/ai/services/ai.service.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import { AgentConversationService } from '@/ai/services/agent-conversation.service';
import { MemoryService } from '@/ai/services/memory.service';
import { SimpleConversationChainService } from '@/ai/services/simple-conversation-chain.service';
import { AppConfigService } from '@/app-config/app-config.service';
import { RedisChatMemoryNotFoundException } from '@/chats/exceptions/redis-chat-memory-not-found.exception';
import { ChatDocument } from '@/common/types/chat';
import { Injectable } from '@nestjs/common';
import { AgentExecutor } from 'langchain/agents';
import { ConversationChain, RetrievalQAChain } from 'langchain/chains';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import { BaseChatMessage, ChainValues } from 'langchain/schema';
import { VectorStoreRetriever } from 'langchain/vectorstores/base';

type AIExecutor = AgentExecutor | ConversationChain;
Expand All @@ -17,7 +20,8 @@ export class AiService {
constructor(
private readonly simpleConversationChainService: SimpleConversationChainService,
private readonly agentConversationService: AgentConversationService,
private readonly appConfigService: AppConfigService
private readonly appConfigService: AppConfigService,
private readonly memoryService: MemoryService
) {
this.llmModel = new ChatOpenAI({
temperature: this.appConfigService.getAiAppConfig().defaultTemperature,
Expand Down Expand Up @@ -50,7 +54,7 @@ export class AiService {
summary
);
} else {
aiExecutor = this.simpleConversationChainService.getChain(
aiExecutor = await this.simpleConversationChainService.getChain(
roomId,
this.llmModel,
summary?.response
Expand Down Expand Up @@ -87,8 +91,20 @@ export class AiService {
};
}

private async askAiToSummarize(roomId: string) {
const chain = this.simpleConversationChainService.getChain(
async getChatHistoryMessages(roomId: string): Promise<BaseChatMessage[]> {
const redisChatMemory = await (
await this.memoryService.getMemory(roomId)
).chatHistory.getMessages();

if (redisChatMemory) {
return redisChatMemory;
}

throw new RedisChatMemoryNotFoundException();
}

private async askAiToSummarize(roomId: string): Promise<ChainValues> {
const chain = await this.simpleConversationChainService.getChain(
roomId,
this.llmModel
);
Expand Down
19 changes: 12 additions & 7 deletions packages/api/src/ai/services/memory.service.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { AppConfigService } from '@/app-config/app-config.service';
import { CACHE_CLIENT } from '@/common/constants/cache';
import { Inject, Injectable } from '@nestjs/common';
import { BufferMemory, ChatMessageHistory } from 'langchain/memory';
import { AIMessage } from 'langchain/schema';
import { BufferMemory } from 'langchain/memory';
import { RedisChatMessageHistory } from 'langchain/stores/message/redis';
import { RedisClientType } from 'redis';

Expand All @@ -18,11 +17,11 @@ export class MemoryService {
this.memoryMap = new Map<string, BufferMemory>();
}

getMemory(roomId: string, summary?: string): BufferMemory {
async getMemory(roomId: string, summary?: string): Promise<BufferMemory> {
if (!!summary) {
this.createMemoryWithSummary(roomId, summary);
await this.memoryMap.get(roomId).clear();
await this.createMemoryWithSummary(roomId, summary);
}

if (!this.hasMemory(roomId)) {
this.createMemory(roomId);
}
Expand Down Expand Up @@ -50,13 +49,19 @@ export class MemoryService {
);
}

private createMemoryWithSummary(roomId: string, summary: string) {
private async createMemoryWithSummary(roomId: string, summary: string) {
const redisChatSummary = new RedisChatMessageHistory({
sessionId: roomId,
client: this.cacheClient,
sessionTTL: this.appConfigService.getAiAppConfig().defaultChatContextTTL,
});
await redisChatSummary.addAIChatMessage(summary);
this.memoryMap.set(
roomId,
new BufferMemory({
returnMessages: true,
memoryKey: 'history',
chatHistory: new ChatMessageHistory([new AIMessage(summary)]),
chatHistory: redisChatSummary,
})
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ export class SimpleConversationChainService {
]);
}

getChain(
async getChain(
roomId: string,
llmModel: BaseChatModel,
summary?: string
): ConversationChain {
): Promise<ConversationChain> {
if (!this.hasChain(roomId) || !!summary) {
this.createChain(roomId, llmModel, summary);
await this.createChain(roomId, llmModel, summary);
}

return this.chainMap.get(roomId);
Expand All @@ -37,7 +37,7 @@ export class SimpleConversationChainService {
return this.chainMap.has(roomId);
}

private createChain(
private async createChain(
roomId: string,
llmModel: BaseChatModel,
summary?: string
Expand All @@ -47,7 +47,7 @@ export class SimpleConversationChainService {
new ConversationChain({
llm: llmModel,
prompt: this.defaultChatPrompt,
memory: this.memoryService.getMemory(roomId, summary),
memory: await this.memoryService.getMemory(roomId, summary),
})
);
}
Expand Down
1 change: 0 additions & 1 deletion packages/api/src/chats/chat-socket.gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ export class ChatSocketGateway {
}),
data.userId
);

const chat = await this.chatsRepository.findChatByRoomId(data.roomId);
const allDocumentsReadyToQuery = chat.documents.every(
(document) => document.meta.queryable
Expand Down
4 changes: 4 additions & 0 deletions packages/api/src/chats/chats.module.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { AiModule } from '@/ai/ai.module';
import { MemoryService } from '@/ai/services/memory.service';
import { AppConfigModule } from '@/app-config/app-config.module';
import { AuthModule } from '@/auth/auth.module';
import { ClerkAuthGuard } from '@/auth/guards/clerk/clerk.auth.guard';
import { CacheModule } from '@/cache/cache.module';
import { ChatSocketGateway } from '@/chats/chat-socket.gateway';
import { ChatsController } from '@/chats/chats.controller';
import { chatsMongooseProviders } from '@/chats/chats.mongoose.providers';
Expand All @@ -26,6 +28,7 @@ import { Module } from '@nestjs/common';
AiModule,
AppConfigModule,
BullModule.registerQueue({ name: CHAT_DOCUMENT_UPLOAD_QUEUE }),
CacheModule,
],
controllers: [ChatsController],
providers: [
Expand All @@ -34,6 +37,7 @@ import { Module } from '@nestjs/common';
// DB Providers
...chatsMongooseProviders,
// Services
MemoryService,
ChatsRepository,
ChatSocketGateway,
// Usecases
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { HttpException, HttpStatus } from '@nestjs/common';

/**
 * OpenAPI (Swagger) schema describing the JSON error body produced by
 * RedisChatMemoryNotFoundException: an HTTP 404 with a machine-readable
 * `message` code list and a human-readable `error` description.
 */
export const RedisChatMemoryNotFoundExceptionSchema = {
  type: 'object',
  properties: {
    statusCode: {
      type: 'number',
      example: 404,
    },
    message: {
      // Machine-readable error codes; mirrors NestJS's array-style message field.
      type: 'array',
      items: {
        type: 'string',
        example: 'redis_chat: not_found',
      },
    },
    error: {
      // Human-readable description of the failure.
      type: 'string',
      example: 'Redis chat memory does not exist',
    },
  },
  required: ['statusCode', 'message', 'error'],
};

/**
 * Thrown when no chat memory exists in Redis for a requested room.
 * Serializes to the 404 payload documented by
 * RedisChatMemoryNotFoundExceptionSchema.
 */
export class RedisChatMemoryNotFoundException extends HttpException {
  constructor() {
    super(
      {
        statusCode: HttpStatus.NOT_FOUND,
        message: ['redis_chat: not_found'],
        // Fixed: was 'Redis chat does not exist', which disagreed with the
        // documented schema example for this exception.
        error: 'Redis chat memory does not exist',
      },
      HttpStatus.NOT_FOUND
    );
  }
}
19 changes: 14 additions & 5 deletions packages/api/src/chats/usecases/add-message-to-chat.usecase.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { AiService } from '@/ai/services/ai.service';
import { AppConfigService } from '@/app-config/app-config.service';
import { ChatsRepository } from '@/chats/chats.repository';
import { AddMessageToChatRequestDto } from '@/chats/dtos/add-message-to-chat.request.dto';
Expand All @@ -8,12 +9,14 @@ import { InternalServerErrorException } from '@/common/exceptions/internal-serve
import { Usecase } from '@/common/types/usecase';
import { ChatMessageResponseSchema } from '@/contract/chats/chat-message.response.dto';
import { Injectable } from '@nestjs/common';
import { encode } from 'gpt-3-encoder';

@Injectable()
export class AddMessageToChatUsecase implements Usecase {
constructor(
private readonly chatsRepository: ChatsRepository,
private readonly appConfigService: AppConfigService
private readonly appConfigService: AppConfigService,
private readonly aiService: AiService
) {}

async execute(
Expand All @@ -30,11 +33,17 @@ export class AddMessageToChatUsecase implements Usecase {
throw new ChatMessageMustHaveSenderException();
}

const numberOfTokens = existingChat.messageHistory.reduce(
(acc, curr) => (acc += curr.meta.tokens),
0
const chatMessageHistory = await this.aiService.getChatHistoryMessages(
roomId
);

// Total token count of the Redis-stored history; used below to decide
// whether the conversation should be summarized.
const numberOfTokensRedis = chatMessageHistory.reduce((acc, curr) => {
  // Skip messages with no text but KEEP the running total.
  // Fixed: the previous `return 0` reset the accumulator whenever a
  // text-less message appeared, discarding every token counted so far.
  if (!curr.text) {
    return acc;
  }
  // Avoid mutating the `acc` parameter (`acc +=`); return the new sum.
  return acc + encode(curr.text).length;
}, 0);

try {
const chatMessage = await this.chatsRepository.addMessageToChat(
existingChat,
Expand All @@ -45,7 +54,7 @@ export class AddMessageToChatUsecase implements Usecase {
return {
message: ChatMessageResponseSchema.parse(chatMessage),
shouldSummarize:
numberOfTokens >
numberOfTokensRedis >
this.appConfigService.getAiAppConfig()
.defaultTokenLimitForSummarization,
};
Expand Down

0 comments on commit 2706f49

Please sign in to comment.