Skip to content

Commit

Permalink
Fix PgVectorStoreWithChatMemoryAdvisorIT from batching changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sobychacko committed Sep 5, 2024
1 parent 2c1f36c commit 53721c6
Showing 1 changed file with 15 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,13 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.vectorstore;

import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.postgresql.ds.PGSimpleDataSource;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor;
Expand All @@ -44,6 +49,7 @@

/**
* @author Fabian Krüger
* @author Soby Chacko
*/
@Testcontainers
class PgVectorStoreWithChatMemoryAdvisorIT {
Expand Down Expand Up @@ -117,9 +123,16 @@ private static PgVectorStore createPgVectorStoreUsingTestcontainer(EmbeddingMode
return new JdbcTemplate(ds);
}

@SuppressWarnings("unchecked")
private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() {
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);
when(embeddingModel.embed(any(Document.class))).thenReturn(embed);

Mockito.doAnswer(invocationOnMock -> {
Object[] arguments = invocationOnMock.getArguments();
List<Document> documents = (List<Document>) arguments[0];
documents.forEach(d -> d.setEmbedding(embed));
return List.of(embed, embed);
}).when(embeddingModel).embed(ArgumentMatchers.any(), any(), any());
when(embeddingModel.embed(any(String.class))).thenReturn(embed);
return embeddingModel;
}
Expand Down

0 comments on commit 53721c6

Please sign in to comment.