Skip to content

Commit

Permalink
CB-13 Fix: Track conversations and graphs on SendMessageCommand
Browse files Browse the repository at this point in the history
  • Loading branch information
izzat5233 committed Dec 30, 2024
1 parent 0df7f9a commit 5dbfb02
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,38 @@ public interface IConversationRepository
void Update(Conversation conversation);
void Delete(Conversation conversation);

/// <remarks>
/// Includes the chatbot graph.
/// </remarks>
Task<Chatbot?> GetChatbotByIdIfAuthorizedAsync(
ChatbotId chatbotId,
UserId userId,
CancellationToken cancellationToken);

/// <remarks>
/// Does not include messages.
/// </remarks>
Task<Conversation?> GetByIdAndUserAsync(
ConversationId conversationId,
UserId userId,
CancellationToken cancellationToken);

/// <remarks>
/// Use when you want to include and track all changes and new messages.
/// </remarks>
Task<Conversation?> LoadByIdAndUserAsync(
ConversationId conversationId,
UserId userId,
CancellationToken cancellationToken);

/// <remarks>
/// Loads and tracks the graph for the conversation.
/// Use only when you want to update the conversation with new messages.
/// </remarks>
Task<Graph?> LoadGraphAsync(
ConversationId conversationId,
CancellationToken cancellationToken);

Task<PageResponse<ListConversationResponseItem>> ListByQueryAsync(
ListConversationsQuery query,
CancellationToken cancellationToken);
Expand All @@ -36,10 +58,6 @@ Task<ListMessagesResponse> ListMessagesAsync(
PageParams pageParams,
CancellationToken cancellationToken);

Task<Graph?> GetGraphAsync(
ConversationId conversationId,
CancellationToken cancellationToken);

Task<List<ConversationId>> ListByChatbotIdAsync(
ChatbotId chatbotId,
CancellationToken cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public async Task<Result<SendMessageResponse>> Handle(
SendMessageCommand request,
CancellationToken cancellationToken)
{
var conversation = await _repository.GetByIdAndUserAsync(
var conversation = await _repository.LoadByIdAndUserAsync(
request.ConversationId,
request.UserId,
cancellationToken);
Expand All @@ -35,7 +35,7 @@ public async Task<Result<SendMessageResponse>> Handle(
return Result.Failure<SendMessageResponse>(ConversationsApplicationErrors.ConversationNotFound);
}

var graph = (await _repository.GetGraphAsync(conversation.Id, cancellationToken))!;
var graph = (await _repository.LoadGraphAsync(conversation.Id, cancellationToken))!;

_conversationFlowService.GraphTraversalService.Graph = graph;
_conversationFlowService.Conversation = conversation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,51 @@ public void Add(Conversation conversation, Graph conversationGraph)
.FirstOrDefaultAsync(cancellationToken);
}

public async Task<Conversation?> GetByIdAndUserAsync(
private IQueryable<Conversation> GetByIdAndUser(
ConversationId conversationId,
UserId userId,
CancellationToken cancellationToken)
UserId userId)
{
return await Context.Set<Conversation>()
return Context.Set<Conversation>()
.AsTracking()
.Where(c =>
c.Id == conversationId &&
Context.Set<Workflow>()
.First(w => w.Id == Context.Set<Chatbot>()
.First(cb => cb.Id == c.ChatbotId)
.WorkflowId)
.OwnerId == userId)
.OwnerId == userId);
}

public async Task<Conversation?> GetByIdAndUserAsync(
ConversationId conversationId,
UserId userId,
CancellationToken cancellationToken)
{
return await GetByIdAndUser(conversationId, userId)
.FirstOrDefaultAsync(cancellationToken);
}

public async Task<Conversation?> LoadByIdAndUserAsync(
ConversationId conversationId,
UserId userId,
CancellationToken cancellationToken)
{
return await GetByIdAndUser(conversationId, userId)
.AsTracking()
.FirstOrDefaultAsync(cancellationToken);
}

public async Task<Graph?> LoadGraphAsync(
ConversationId conversationId,
CancellationToken cancellationToken)
{
return await Context.Set<Graph>()
.AsTracking()
.AsSplitQuery()
.Where(g =>
g.Id == Context.Set<Conversation>()
.First(c => c.Id == conversationId)
.GraphId)
.FirstOrDefaultAsync(cancellationToken);
}

Expand Down Expand Up @@ -110,19 +142,6 @@ public async Task<ListMessagesResponse> ListMessagesAsync(
return new ListMessagesResponse(inputMessages, outputMessages);
}

public async Task<Graph?> GetGraphAsync(
ConversationId conversationId,
CancellationToken cancellationToken)
{
return await Context.Set<Graph>()
.AsSplitQuery()
.Where(g =>
g.Id == Context.Set<Conversation>()
.First(c => c.Id == conversationId)
.GraphId)
.FirstOrDefaultAsync(cancellationToken);
}

public async Task<List<ConversationId>> ListByChatbotIdAsync(
ChatbotId chatbotId,
CancellationToken cancellationToken)
Expand Down

0 comments on commit 5dbfb02

Please sign in to comment.