Skip to content

Commit

Permalink
Batching strategy call on vector stores
Browse files Browse the repository at this point in the history
- Qdrant
- Redis
- Typesense
- Weaviate
  • Loading branch information
sobychacko committed Sep 6, 2024
1 parent 0bfd011 commit 01733d7
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,12 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.qdrant;

import io.micrometer.observation.ObservationRegistry;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -32,6 +36,7 @@
* @author Anush Shetty
* @author Eddú Meléndez
* @author Christian Tzolov
* @author Soby Chacko
* @since 0.8.1
*/
@AutoConfiguration
Expand All @@ -58,14 +63,21 @@ public QdrantClient qdrantClient(QdrantVectorStoreProperties properties,
return new QdrantClient(grpcClientBuilder.build());
}

/**
 * Supplies the fallback {@link BatchingStrategy} bean for this auto-configuration.
 * The bean is only registered when no other {@link BatchingStrategy} bean is already
 * defined, as expressed by {@code @ConditionalOnMissingBean(BatchingStrategy.class)}.
 * @return a new {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public QdrantVectorStore vectorStore(EmbeddingModel embeddingModel, QdrantVectorStoreProperties properties,
QdrantClient qdrantClient, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {
return new QdrantVectorStore(qdrantClient, properties.getCollectionName(), embeddingModel,
properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

static class PropertiesQdrantConnectionDetails implements QdrantConnectionDetails {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.redis;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
Expand All @@ -35,18 +38,26 @@
/**
* @author Christian Tzolov
* @author Eddú Meléndez
* @author Soby Chacko
*/
@AutoConfiguration(after = RedisAutoConfiguration.class)
@ConditionalOnClass({ JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class })
@ConditionalOnBean(JedisConnectionFactory.class)
@EnableConfigurationProperties(RedisVectorStoreProperties.class)
public class RedisVectorStoreAutoConfiguration {

/**
 * Supplies the fallback {@link BatchingStrategy} bean for this auto-configuration.
 * The bean is only registered when no other {@link BatchingStrategy} bean is already
 * defined, as expressed by {@code @ConditionalOnMissingBean(BatchingStrategy.class)}.
 * @return a new {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
JedisConnectionFactory jedisConnectionFactory, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

var config = RedisVectorStoreConfig.builder()
.withIndexName(properties.getIndex())
Expand All @@ -56,7 +67,7 @@ public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorSt
return new RedisVectorStore(config, embeddingModel,
new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()),
properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.typesense;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.TypesenseVectorStore;
import org.springframework.ai.vectorstore.TypesenseVectorStore.TypesenseVectorStoreConfig;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
Expand All @@ -38,6 +41,7 @@
/**
* @author Pablo Sanchidrian Herrera
* @author Eddú Meléndez
* @author Soby Chacko
*/
@AutoConfiguration
@ConditionalOnClass({ TypesenseVectorStore.class, EmbeddingModel.class })
Expand All @@ -51,11 +55,18 @@ TypesenseVectorStoreAutoConfiguration.PropertiesTypesenseConnectionDetails types
return new TypesenseVectorStoreAutoConfiguration.PropertiesTypesenseConnectionDetails(properties);
}

/**
 * Supplies the fallback {@link BatchingStrategy} bean for this auto-configuration.
 * The bean is only registered when no other {@link BatchingStrategy} bean is already
 * defined, as expressed by {@code @ConditionalOnMissingBean(BatchingStrategy.class)}.
 * @return a new {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public TypesenseVectorStore vectorStore(Client typesenseClient, EmbeddingModel embeddingModel,
TypesenseVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

TypesenseVectorStoreConfig config = TypesenseVectorStoreConfig.builder()
.withCollectionName(properties.getCollectionName())
Expand All @@ -64,7 +75,7 @@ public TypesenseVectorStore vectorStore(Client typesenseClient, EmbeddingModel e

return new TypesenseVectorStore(typesenseClient, embeddingModel, config, properties.isInitializeSchema(),
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.weaviate;

import io.micrometer.observation.ObservationRegistry;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateAuthClient;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.v1.auth.exception.AuthException;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.WeaviateVectorStore;
import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig;
import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField;
Expand Down Expand Up @@ -62,11 +66,18 @@ public WeaviateClient weaviateClient(WeaviateVectorStoreProperties properties,
}
}

/**
 * Supplies the fallback {@link BatchingStrategy} bean for this auto-configuration.
 * The bean is only registered when no other {@link BatchingStrategy} bean is already
 * defined, as expressed by {@code @ConditionalOnMissingBean(BatchingStrategy.class)}.
 * @return a new {@link TokenCountBatchingStrategy} built with its no-argument
 * constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public WeaviateVectorStore vectorStore(EmbeddingModel embeddingModel, WeaviateClient weaviateClient,
WeaviateVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

WeaviateVectorStoreConfig.Builder configBuilder = WeaviateVectorStore.WeaviateVectorStoreConfig.builder()
.withObjectClass(properties.getObjectClass())
Expand All @@ -79,7 +90,7 @@ public WeaviateVectorStore vectorStore(EmbeddingModel embeddingModel, WeaviateCl

return new WeaviateVectorStore(configBuilder.build(), embeddingModel, weaviateClient,
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

static class PropertiesWeaviateConnectionDetails implements WeaviateConnectionDetails {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.vectorstore.qdrant;

import static io.qdrant.client.PointIdFactory.id;
Expand All @@ -27,7 +28,10 @@
import java.util.concurrent.ExecutionException;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
Expand Down Expand Up @@ -58,6 +62,7 @@
* @author Christian Tzolov
* @author Eddú Meléndez
* @author Josh Long
* @author Soby Chacko
* @since 0.8.1
*/
public class QdrantVectorStore extends AbstractObservationVectorStore implements InitializingBean {
Expand All @@ -78,6 +83,8 @@ public class QdrantVectorStore extends AbstractObservationVectorStore implements

private final boolean initializeSchema;

private final BatchingStrategy batchingStrategy;

/**
* Configuration class for the QdrantVectorStore.
*
Expand Down Expand Up @@ -161,7 +168,8 @@ public QdrantVectorStore(QdrantClient qdrantClient, QdrantVectorStoreConfig conf
*/
public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, EmbeddingModel embeddingModel,
boolean initializeSchema) {
this(qdrantClient, collectionName, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null);
this(qdrantClient, collectionName, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null,
new TokenCountBatchingStrategy());
}

/**
Expand All @@ -175,7 +183,7 @@ public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, Embed
*/
public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, EmbeddingModel embeddingModel,
boolean initializeSchema, ObservationRegistry observationRegistry,
VectorStoreObservationConvention customObservationConvention) {
VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) {

super(observationRegistry, customObservationConvention);

Expand All @@ -187,6 +195,7 @@ public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, Embed
this.embeddingModel = embeddingModel;
this.collectionName = collectionName;
this.qdrantClient = qdrantClient;
this.batchingStrategy = batchingStrategy;
}

/**
Expand All @@ -196,16 +205,17 @@ public QdrantVectorStore(QdrantClient qdrantClient, String collectionName, Embed
@Override
public void doAdd(List<Document> documents) {
try {
List<PointStruct> points = documents.stream().map(document -> {
// Compute and assign an embedding to the document.
document.setEmbedding(this.embeddingModel.embed(document));

return PointStruct.newBuilder()
// Compute and assign an embedding to the document.
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);

List<PointStruct> points = documents.stream()
.map(document -> PointStruct.newBuilder()
.setId(id(UUID.fromString(document.getId())))
.setVectors(vectors(document.getEmbedding()))
.putAllPayload(toPayload(document))
.build();
}).toList();
.build())
.toList();

this.qdrantClient.upsertAsync(this.collectionName, points).get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.mistralai.MistralAiEmbeddingModel;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.observation.conventions.SpringAiKind;
Expand Down Expand Up @@ -191,8 +192,8 @@ public QdrantClient qdrantClient() {
@Bean
public VectorStore qdrantVectorStore(EmbeddingModel embeddingModel, QdrantClient qdrantClient,
ObservationRegistry observationRegistry) {
return new QdrantVectorStore(qdrantClient, COLLECTION_NAME, embeddingModel, true, observationRegistry,
null);
return new QdrantVectorStore(qdrantClient, COLLECTION_NAME, embeddingModel, true, observationRegistry, null,
new TokenCountBatchingStrategy());
}

@Bean
Expand Down
Loading

0 comments on commit 01733d7

Please sign in to comment.