Skip to content

Commit

Permalink
Stabilize the vector store ITs
Browse files Browse the repository at this point in the history
 - Ensures that the vector store similarity search by threshold tests
   use dynamically computed threshold that is between the top 2 results
   from the ordered search. This ensures that the threshold value is not
   affected by changes in the embedding API results.

 - Minor code style improvments.
  • Loading branch information
tzolov committed Sep 28, 2023
1 parent a45e34e commit 12e3c7b
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,15 @@ public List<Double> embed(Document document) {
}

private List<Double> extractEmbeddingsList(Embeddings embeddings) {
return embeddings.getData()
.stream()
.map(EmbeddingItem::getEmbedding)
.flatMap(List::stream)
.collect(Collectors.toList());
return embeddings.getData().stream().map(EmbeddingItem::getEmbedding).flatMap(List::stream).toList();
}

@Override
public List<List<Double>> embed(List<String> texts) {
logger.debug("Retrieving embeddings");
Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, new EmbeddingsOptions(texts));
logger.debug("Embeddings retrieved");
return embeddings.getData().stream().map(emb -> emb.getEmbedding()).collect(Collectors.toList());
return embeddings.getData().stream().map(emb -> emb.getEmbedding()).toList();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public List<Double> embed(Document document) {

public List<List<Double>> embed(List<String> texts) {
EmbeddingResponse embeddingResponse = embedForResponse(texts);
return embeddingResponse.getData().stream().map(emb -> emb.getEmbedding()).collect(Collectors.toList());
return embeddingResponse.getData().stream().map(emb -> emb.getEmbedding()).toList();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,6 @@ void createCollection() {
CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
.withDatabaseName(this.config.databaseName)
.withCollectionName(this.config.collectionName)
// .withDatabaseName(this.collectionName)
.withDescription("Spring AI Vector Store")
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
.withShardsNum(2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ public void searchThresholdTest(String metricType) {

assertThat(distances).hasSize(3);

List<Document> results = vectorStore.similaritySearch("Great", 5, (1 - (distances.get(0) + 0.001)));
float threshold = (distances.get(0) + distances.get(1)) / 2;

List<Document> results = vectorStore.similaritySearch("Great", 5, (1 - threshold));

assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ private static float[] toFloatArray(List<Double> embeddingDouble) {

private static Document recordToDocument(org.neo4j.driver.Record neoRecord) {
var node = neoRecord.get("node").asNode();
var score = neoRecord.get("score").asFloat();
var metaData = new HashMap<String, Object>();
metaData.put("distance", 1 - score);
node.keys().forEach(key -> {
if (key.startsWith("metadata.")) {
metaData.put(key.substring(key.indexOf(".") + 1), node.get(key).asObject());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
package org.springframework.ai.vectorstore;

import java.util.Collections;
import java.util.List;
import java.util.UUID;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;

import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
Expand All @@ -14,15 +23,6 @@
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;

import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -72,10 +72,11 @@ void addAndSearchTest() {
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getText()).isEqualTo(
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
assertThat(resultDoc.getMetadata()).isEqualTo(Collections.singletonMap("meta2", "meta2"));
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");

// Remove all documents from the store
vectorStore.delete(this.documents.stream().map(Document::getId).collect(Collectors.toList()));
vectorStore.delete(this.documents.stream().map(Document::getId).toList());

List<Document> results2 = vectorStore.similaritySearch("Great", 1);
assertThat(results2).isEmpty();
Expand All @@ -100,7 +101,8 @@ void documentUpdateTest() {
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getText()).isEqualTo("Spring AI rocks!!");
assertThat(resultDoc.getMetadata()).isEqualTo(Collections.singletonMap("meta1", "meta1"));
assertThat(resultDoc.getMetadata()).containsKey("meta1");
assertThat(resultDoc.getMetadata()).containsKey("distance");

Document sameIdDocument = new Document(document.getId(),
"The World is Big and Salvation Lurks Around the Corner",
Expand All @@ -114,7 +116,8 @@ void documentUpdateTest() {
resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(document.getId());
assertThat(resultDoc.getText()).isEqualTo("The World is Big and Salvation Lurks Around the Corner");
assertThat(resultDoc.getMetadata()).isEqualTo(Collections.singletonMap("meta2", "meta2"));
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");

});
}
Expand All @@ -128,16 +131,23 @@ void searchThresholdTest() {

vectorStore.add(this.documents);

assertThat(vectorStore.similaritySearch("Great", 5, 0)).hasSize(3);
List<Document> fullResult = vectorStore.similaritySearch("Great", 5, 0);

List<Float> distances = fullResult.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();

assertThat(distances).hasSize(3);

float threshold = (distances.get(0) + distances.get(1)) / 2;

List<Document> results = vectorStore.similaritySearch("Great", 5, 0.89);
List<Document> results = vectorStore.similaritySearch("Great", 5, 1 - threshold);

assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getText()).isEqualTo(
"Great Depression Great Depression Great Depression Great Depression Great Depression Great Depression");
assertThat(resultDoc.getMetadata()).isEqualTo(Collections.singletonMap("meta2", "meta2"));
assertThat(resultDoc.getMetadata()).containsKey("meta2");
assertThat(resultDoc.getMetadata()).containsKey("distance");

});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.fasterxml.jackson.core.JsonProcessingException;
Expand Down Expand Up @@ -158,7 +157,7 @@ private List<Double> toDoubleList(PGobject embedding) throws SQLException {
List<Double> doubleEmbedding = IntStream.range(0, floatArray.length)
.mapToDouble(i -> floatArray[i])
.boxed()
.collect(Collectors.toList());
.toList();
return doubleEmbedding;

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.Iterator;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;

import javax.sql.DataSource;

Expand Down Expand Up @@ -102,7 +101,7 @@ public void addAndSearchTest(String distanceType) {
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");

// Remove all documents from the store
vectorStore.delete(documents.stream().map(doc -> doc.getId()).collect(Collectors.toList()));
vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());

List<Document> results2 = vectorStore.similaritySearch("Great", 1);
assertThat(results2).hasSize(0);
Expand Down Expand Up @@ -165,17 +164,17 @@ public void searchThresholdTest(String distanceType) {

List<Float> distances = fullResult.stream()
.map(doc -> (Float) doc.getMetadata().get("distance"))
.collect(Collectors.toList());
.toList();

assertThat(fullResult).hasSize(3);

assertThat(isSortedByDistance(fullResult)).isTrue();

fullResult.stream().forEach(doc -> System.out.println(doc.getMetadata().get("distance")));

List<Double> embeddingDistance = ((PgVectorStore) vectorStore).embeddingDistance("Great");
float threshold = (distances.get(0) + distances.get(1)) / 2;

List<Document> results = vectorStore.similaritySearch("Great", 5, (1 - (distances.get(0) + 0.01)));
List<Document> results = vectorStore.similaritySearch("Great", 5, (1 - threshold));

assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
Expand All @@ -189,9 +188,7 @@ public void searchThresholdTest(String distanceType) {

private static boolean isSortedByDistance(List<Document> docs) {

List<Float> distances = docs.stream()
.map(doc -> (Float) doc.getMetadata().get("distance"))
.collect(Collectors.toList());
List<Float> distances = docs.stream().map(doc -> (Float) doc.getMetadata().get("distance")).toList();

if (CollectionUtils.isEmpty(distances) || distances.size() == 1) {
return true;
Expand Down

0 comments on commit 12e3c7b

Please sign in to comment.