From f68afb60129c935ead77ef9b3dc1e4901b5291af Mon Sep 17 00:00:00 2001 From: inpink Date: Sun, 29 Sep 2024 15:56:14 +0900 Subject: [PATCH] fix: include index name in OpenSearchVectorStore for similarity search - Resolved issue where index name was not being sent during similaritySearch. - Updated similaritySearch method to include index in the SearchRequest. - Implemented a test to verify that documents can be added and retrieved from two different indices using separate OpenSearchVectorStore instances. Ensured that similarity search results are correctly returned for the respective indices. --- .../ai/vectorstore/OpenSearchVectorStore.java | 2 + .../vectorstore/OpenSearchVectorStoreIT.java | 62 +++++++++++++++++-- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java index 1135241cf4..1756ded2ff 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java @@ -60,6 +60,7 @@ * @author Soby Chacko * @author Christian Tzolov * @author Thomas Vitale + * @author inpink * @since 1.0.0 */ public class OpenSearchVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -178,6 +179,7 @@ public List similaritySearch(float[] embedding, int topK, double simil Filter.Expression filterExpression) { return similaritySearch(new org.opensearch.client.opensearch.core.SearchRequest.Builder() .query(getOpenSearchSimilarityQuery(embedding, filterExpression)) + .index(this.index) .sort(sortOptionsBuilder -> sortOptionsBuilder .score(scoreSortBuilder -> scoreSortBuilder.order(SortOrder.Desc))) .size(topK) diff --git 
a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java index dbb417d24e..01237347de 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchVectorStoreIT.java @@ -20,6 +20,7 @@ import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -30,6 +31,7 @@ import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; @@ -58,6 +60,7 @@ /** * @author Jemin Huh * @author Soby Chacko + * @author inpink * @since 1.0.0 */ @Testcontainers @@ -99,8 +102,11 @@ private ApplicationContextRunner getContextRunner() { @BeforeEach void cleanDatabase() { getContextRunner().run(context -> { - VectorStore vectorStore = context.getBean(VectorStore.class); + VectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); vectorStore.delete(List.of("_all")); + + VectorStore anotherVectorStore = context.getBean("anotherVectorStore", OpenSearchVectorStore.class); + anotherVectorStore.delete(List.of("_all")); }); } @@ -109,7 +115,7 @@ void cleanDatabase() { public void 
addAndSearchTest(String similarityFunction) { getContextRunner().run(context -> { - OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class); + OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); if (!DEFAULT.equals(similarityFunction)) { vectorStore.withSimilarityFunction(similarityFunction); @@ -148,7 +154,7 @@ public void addAndSearchTest(String similarityFunction) { public void searchWithFilters(String similarityFunction) { getContextRunner().run(context -> { - OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class); + OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); if (!DEFAULT.equals(similarityFunction)) { vectorStore.withSimilarityFunction(similarityFunction); @@ -246,7 +252,7 @@ public void searchWithFilters(String similarityFunction) { public void documentUpdateTest(String similarityFunction) { getContextRunner().run(context -> { - OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class); + OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); if (!DEFAULT.equals(similarityFunction)) { vectorStore.withSimilarityFunction(similarityFunction); } @@ -302,7 +308,7 @@ public void documentUpdateTest(String similarityFunction) { public void searchThresholdTest(String similarityFunction) { getContextRunner().run(context -> { - OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class); + OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class); if (!DEFAULT.equals(similarityFunction)) { vectorStore.withSimilarityFunction(similarityFunction); } @@ -343,11 +349,41 @@ public void searchThresholdTest(String similarityFunction) { }); } + @Test + public void searchDocumentsInTwoIndicesTest() { + getContextRunner().run(context -> { + // given + OpenSearchVectorStore vectorStore1 = 
context.getBean("vectorStore", OpenSearchVectorStore.class); + OpenSearchVectorStore vectorStore2 = context.getBean("anotherVectorStore", OpenSearchVectorStore.class); + + Document docInIndex1 = new Document("1", "Document in index 1", Map.of("meta", "index1")); + Document docInIndex2 = new Document("2", "Document in index 2", Map.of("meta", "index2")); + + // when + vectorStore1.add(List.of(docInIndex1)); + vectorStore2.add(List.of(docInIndex2)); + + List<Document> resultInIndex1 = vectorStore1 + .similaritySearch(SearchRequest.query("Document in index 1").withTopK(1).withSimilarityThreshold(0)); + + List<Document> resultInIndex2 = vectorStore2 + .similaritySearch(SearchRequest.query("Document in index 2").withTopK(1).withSimilarityThreshold(0)); + + // then + assertThat(resultInIndex1).hasSize(1); + assertThat(resultInIndex1.get(0).getId()).isEqualTo(docInIndex1.getId()); + + assertThat(resultInIndex2).hasSize(1); + assertThat(resultInIndex2.get(0).getId()).isEqualTo(docInIndex2.getId()); + }); + } + @SpringBootConfiguration @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication { @Bean + @Qualifier("vectorStore") public OpenSearchVectorStore vectorStore(EmbeddingModel embeddingModel) { try { return new OpenSearchVectorStore(new OpenSearchClient(ApacheHttpClient5TransportBuilder @@ -359,6 +395,22 @@ public OpenSearchVectorStore vectorStore(EmbeddingModel embeddingModel) { } } + @Bean + @Qualifier("anotherVectorStore") + public OpenSearchVectorStore anotherVectorStore(EmbeddingModel embeddingModel) { + try { + return new OpenSearchVectorStore("another_index", + new OpenSearchClient(ApacheHttpClient5TransportBuilder + .builder(HttpHost.create(opensearchContainer.getHttpHostAddress())) + .build()), + embeddingModel, OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536, + true); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + @Bean public EmbeddingModel embeddingModel() { 
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));