Skip to content

Commit

Permalink
fix: include index name in OpenSearchVectorStore for similarity search
Browse files Browse the repository at this point in the history
- Resolved issue where index name was not being sent during similaritySearch.
- Updated similaritySearch method to include index in the SearchRequest.
- Implemented a test to verify that documents can be added and retrieved from two different indices using separate OpenSearchVectorStore instances. Ensured that similarity search results are correctly returned for the respective indices.
  • Loading branch information
inpink committed Sep 29, 2024
1 parent 4c8a6ee commit f68afb6
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
* @author Soby Chacko
* @author Christian Tzolov
* @author Thomas Vitale
* @author inpink
* @since 1.0.0
*/
public class OpenSearchVectorStore extends AbstractObservationVectorStore implements InitializingBean {
Expand Down Expand Up @@ -178,6 +179,7 @@ public List<Document> similaritySearch(float[] embedding, int topK, double simil
Filter.Expression filterExpression) {
return similaritySearch(new org.opensearch.client.opensearch.core.SearchRequest.Builder()
.query(getOpenSearchSimilarityQuery(embedding, filterExpression))
.index(this.index)
.sort(sortOptionsBuilder -> sortOptionsBuilder
.score(scoreSortBuilder -> scoreSortBuilder.order(SortOrder.Desc)))
.size(topK)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.awaitility.Awaitility;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
Expand All @@ -30,6 +31,7 @@
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
Expand Down Expand Up @@ -58,6 +60,7 @@
/**
* @author Jemin Huh
* @author Soby Chacko
* @author inpink
* @since 1.0.0
*/
@Testcontainers
Expand Down Expand Up @@ -99,8 +102,11 @@ private ApplicationContextRunner getContextRunner() {
@BeforeEach
void cleanDatabase() {
getContextRunner().run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);
VectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);
vectorStore.delete(List.of("_all"));

VectorStore anotherVectorStore = context.getBean("anotherVectorStore", OpenSearchVectorStore.class);
anotherVectorStore.delete(List.of("_all"));
});
}

Expand All @@ -109,7 +115,7 @@ void cleanDatabase() {
public void addAndSearchTest(String similarityFunction) {

getContextRunner().run(context -> {
OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);

if (!DEFAULT.equals(similarityFunction)) {
vectorStore.withSimilarityFunction(similarityFunction);
Expand Down Expand Up @@ -148,7 +154,7 @@ public void addAndSearchTest(String similarityFunction) {
public void searchWithFilters(String similarityFunction) {

getContextRunner().run(context -> {
OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);

if (!DEFAULT.equals(similarityFunction)) {
vectorStore.withSimilarityFunction(similarityFunction);
Expand Down Expand Up @@ -246,7 +252,7 @@ public void searchWithFilters(String similarityFunction) {
public void documentUpdateTest(String similarityFunction) {

getContextRunner().run(context -> {
OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);
if (!DEFAULT.equals(similarityFunction)) {
vectorStore.withSimilarityFunction(similarityFunction);
}
Expand Down Expand Up @@ -302,7 +308,7 @@ public void documentUpdateTest(String similarityFunction) {
public void searchThresholdTest(String similarityFunction) {

getContextRunner().run(context -> {
OpenSearchVectorStore vectorStore = context.getBean(OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore = context.getBean("vectorStore", OpenSearchVectorStore.class);
if (!DEFAULT.equals(similarityFunction)) {
vectorStore.withSimilarityFunction(similarityFunction);
}
Expand Down Expand Up @@ -343,11 +349,41 @@ public void searchThresholdTest(String similarityFunction) {
});
}

@Test
public void searchDocumentsInTwoIndicesTest() {
getContextRunner().run(context -> {
// given
OpenSearchVectorStore vectorStore1 = context.getBean("vectorStore", OpenSearchVectorStore.class);
OpenSearchVectorStore vectorStore2 = context.getBean("anotherVectorStore", OpenSearchVectorStore.class);

Document docInIndex1 = new Document("1", "Document in index 1", Map.of("meta", "index1"));
Document docInIndex2 = new Document("2", "Document in index 2", Map.of("meta", "index2"));

// when
vectorStore1.add(List.of(docInIndex1));
vectorStore2.add(List.of(docInIndex2));

List<Document> resultInIndex1 = vectorStore1
.similaritySearch(SearchRequest.query("Document in index 1").withTopK(1).withSimilarityThreshold(0));

List<Document> resultInIndex2 = vectorStore2
.similaritySearch(SearchRequest.query("Document in index 2").withTopK(1).withSimilarityThreshold(0));

// then
assertThat(resultInIndex1).hasSize(1);
assertThat(resultInIndex1.get(0).getId()).isEqualTo(docInIndex1.getId());

assertThat(resultInIndex2).hasSize(1);
assertThat(resultInIndex2.get(0).getId()).isEqualTo(docInIndex2.getId());
});
}

@SpringBootConfiguration
@EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
public static class TestApplication {

@Bean
@Qualifier("vectorStore")
public OpenSearchVectorStore vectorStore(EmbeddingModel embeddingModel) {
try {
return new OpenSearchVectorStore(new OpenSearchClient(ApacheHttpClient5TransportBuilder
Expand All @@ -359,6 +395,22 @@ public OpenSearchVectorStore vectorStore(EmbeddingModel embeddingModel) {
}
}

@Bean
@Qualifier("anotherVectorStore")
public OpenSearchVectorStore anotherVectorStore(EmbeddingModel embeddingModel) {
try {
return new OpenSearchVectorStore("another_index",
new OpenSearchClient(ApacheHttpClient5TransportBuilder
.builder(HttpHost.create(opensearchContainer.getHttpHostAddress()))
.build()),
embeddingModel, OpenSearchVectorStore.DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536,
true);
}
catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}

@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
Expand Down

0 comments on commit f68afb6

Please sign in to comment.