Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batching strategy for more vector stores #1318

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,10 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.chroma;

import org.springframework.ai.chroma.ChromaApi;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.ChromaVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -36,6 +39,7 @@
/**
* @author Christian Tzolov
* @author Eddú Meléndez
* @author Soby Chacko
*/
@AutoConfiguration
@ConditionalOnClass({ EmbeddingModel.class, RestClient.class, ChromaVectorStore.class, ObjectMapper.class })
Expand Down Expand Up @@ -73,14 +77,21 @@ else if (StringUtils.hasText(apiProperties.getUsername()) && StringUtils.hasText
return chromaApi;
}

/**
 * Supplies the fallback {@link BatchingStrategy} for the Chroma vector store.
 * The bean is only registered when the context holds no other
 * {@link BatchingStrategy} bean, so an application-defined strategy takes precedence.
 * @return a new {@link TokenCountBatchingStrategy} created with its no-arg constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy chromaBatchingStrategy() {
return new TokenCountBatchingStrategy();
}

/**
 * Builds the auto-configured {@link ChromaVectorStore} bean.
 * <p>
 * The store is created from the embedding model, the Chroma API client, the
 * configured collection name and the schema-initialization flag. When no unique
 * {@link ObservationRegistry} can be resolved, {@code ObservationRegistry.NOOP} is
 * used; when no custom {@link VectorStoreObservationConvention} is available,
 * {@code null} is passed to the constructor.
 * @param embeddingModel embedding model handed to the store
 * @param chromaApi Chroma API client handed to the store
 * @param storeProperties source of the collection name and initialize-schema flag
 * @param observationRegistry lazy lookup of the observation registry
 * @param customObservationConvention lazy lookup of an optional custom convention
 * @param chromaBatchingStrategy batching strategy forwarded to the store constructor
 * @return the configured {@link ChromaVectorStore}
 */
@Bean
@ConditionalOnMissingBean
public ChromaVectorStore vectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi,
ChromaVectorStoreProperties storeProperties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy chromaBatchingStrategy) {
return new ChromaVectorStore(embeddingModel, chromaApi, storeProperties.getCollectionName(),
storeProperties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), chromaBatchingStrategy);
}

static class PropertiesChromaConnectionDetails implements ChromaConnectionDetails {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,11 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.elasticsearch;

import org.elasticsearch.client.RestClient;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.ElasticsearchVectorStore;
import org.springframework.ai.vectorstore.ElasticsearchVectorStoreOptions;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
Expand All @@ -37,19 +40,26 @@
* @author Wei Jiang
* @author Josh Long
* @author Christian Tzolov
* @author Soby Chacko
* @since 1.0.0
*/

@AutoConfiguration(after = ElasticsearchRestClientAutoConfiguration.class)
@ConditionalOnClass({ ElasticsearchVectorStore.class, EmbeddingModel.class, RestClient.class })
@EnableConfigurationProperties(ElasticsearchVectorStoreProperties.class)
class ElasticsearchVectorStoreAutoConfiguration {

/**
 * Supplies the fallback {@link BatchingStrategy} for the Elasticsearch vector store.
 * The bean is only registered when the context holds no other
 * {@link BatchingStrategy} bean, so an application-defined strategy takes precedence.
 * @return a new {@link TokenCountBatchingStrategy} created with its no-arg constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properties, RestClient restClient,
EmbeddingModel embeddingModel, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {
ElasticsearchVectorStoreOptions elasticsearchVectorStoreOptions = new ElasticsearchVectorStoreOptions();

if (StringUtils.hasText(properties.getIndexName())) {
Expand All @@ -64,7 +74,7 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti

return new ElasticsearchVectorStore(elasticsearchVectorStoreOptions, restClient, embeddingModel,
properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,11 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.neo4j;

import org.neo4j.driver.Driver;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.Neo4jVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -34,17 +37,25 @@
* @author Jingzhou Ou
* @author Josh Long
* @author Christian Tzolov
* @author Soby Chacko
*/
@AutoConfiguration(after = Neo4jAutoConfiguration.class)
@ConditionalOnClass({ Neo4jVectorStore.class, EmbeddingModel.class, Driver.class })
@EnableConfigurationProperties({ Neo4jVectorStoreProperties.class })
public class Neo4jVectorStoreAutoConfiguration {

/**
 * Supplies the fallback {@link BatchingStrategy} for the Neo4j vector store.
 * The bean is only registered when the context holds no other
 * {@link BatchingStrategy} bean, so an application-defined strategy takes precedence.
 * @return a new {@link TokenCountBatchingStrategy} created with its no-arg constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public Neo4jVectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel,
Neo4jVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {
Neo4jVectorStore.Neo4jVectorStoreConfig config = Neo4jVectorStore.Neo4jVectorStoreConfig.builder()
.withDatabaseName(properties.getDatabaseName())
.withEmbeddingDimension(properties.getEmbeddingDimension())
Expand All @@ -58,7 +69,7 @@ public Neo4jVectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel

return new Neo4jVectorStore(driver, embeddingModel, config, properties.isInitializeSchema(),
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,12 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.qdrant;

import io.micrometer.observation.ObservationRegistry;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -32,6 +36,7 @@
* @author Anush Shetty
* @author Eddú Meléndez
* @author Christian Tzolov
* @author Soby Chacko
* @since 0.8.1
*/
@AutoConfiguration
Expand All @@ -58,14 +63,21 @@ public QdrantClient qdrantClient(QdrantVectorStoreProperties properties,
return new QdrantClient(grpcClientBuilder.build());
}

/**
 * Supplies the fallback {@link BatchingStrategy} for the Qdrant vector store.
 * The bean is only registered when the context holds no other
 * {@link BatchingStrategy} bean, so an application-defined strategy takes precedence.
 * @return a new {@link TokenCountBatchingStrategy} created with its no-arg constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

/**
 * Builds the auto-configured {@link QdrantVectorStore} bean.
 * <p>
 * The store is created from the Qdrant client, the configured collection name,
 * the embedding model and the schema-initialization flag. When no unique
 * {@link ObservationRegistry} can be resolved, {@code ObservationRegistry.NOOP} is
 * used; when no custom {@link VectorStoreObservationConvention} is available,
 * {@code null} is passed to the constructor.
 * @param embeddingModel embedding model handed to the store
 * @param properties source of the collection name and initialize-schema flag
 * @param qdrantClient Qdrant client handed to the store
 * @param observationRegistry lazy lookup of the observation registry
 * @param customObservationConvention lazy lookup of an optional custom convention
 * @param batchingStrategy batching strategy forwarded to the store constructor
 * @return the configured {@link QdrantVectorStore}
 */
@Bean
@ConditionalOnMissingBean
public QdrantVectorStore vectorStore(EmbeddingModel embeddingModel, QdrantVectorStoreProperties properties,
QdrantClient qdrantClient, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {
return new QdrantVectorStore(qdrantClient, properties.getCollectionName(), embeddingModel,
properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

static class PropertiesQdrantConnectionDetails implements QdrantConnectionDetails {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.redis;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
Expand All @@ -35,18 +38,26 @@
/**
* @author Christian Tzolov
* @author Eddú Meléndez
* @author Soby Chacko
*/
@AutoConfiguration(after = RedisAutoConfiguration.class)
@ConditionalOnClass({ JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class })
@ConditionalOnBean(JedisConnectionFactory.class)
@EnableConfigurationProperties(RedisVectorStoreProperties.class)
public class RedisVectorStoreAutoConfiguration {

/**
 * Supplies the fallback {@link BatchingStrategy} for the Redis vector store.
 * The bean is only registered when the context holds no other
 * {@link BatchingStrategy} bean, so an application-defined strategy takes precedence.
 * @return a new {@link TokenCountBatchingStrategy} created with its no-arg constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
JedisConnectionFactory jedisConnectionFactory, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

var config = RedisVectorStoreConfig.builder()
.withIndexName(properties.getIndex())
Expand All @@ -56,7 +67,7 @@ public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorSt
return new RedisVectorStore(config, embeddingModel,
new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()),
properties.isInitializeSchema(), observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.typesense;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.TypesenseVectorStore;
import org.springframework.ai.vectorstore.TypesenseVectorStore.TypesenseVectorStoreConfig;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
Expand All @@ -38,6 +41,7 @@
/**
* @author Pablo Sanchidrian Herrera
* @author Eddú Meléndez
* @author Soby Chacko
*/
@AutoConfiguration
@ConditionalOnClass({ TypesenseVectorStore.class, EmbeddingModel.class })
Expand All @@ -51,11 +55,18 @@ TypesenseVectorStoreAutoConfiguration.PropertiesTypesenseConnectionDetails types
return new TypesenseVectorStoreAutoConfiguration.PropertiesTypesenseConnectionDetails(properties);
}

/**
 * Supplies the fallback {@link BatchingStrategy} for the Typesense vector store.
 * The bean is only registered when the context holds no other
 * {@link BatchingStrategy} bean, so an application-defined strategy takes precedence.
 * @return a new {@link TokenCountBatchingStrategy} created with its no-arg constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public TypesenseVectorStore vectorStore(Client typesenseClient, EmbeddingModel embeddingModel,
TypesenseVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

TypesenseVectorStoreConfig config = TypesenseVectorStoreConfig.builder()
.withCollectionName(properties.getCollectionName())
Expand All @@ -64,7 +75,7 @@ public TypesenseVectorStore vectorStore(Client typesenseClient, EmbeddingModel e

return new TypesenseVectorStore(typesenseClient, embeddingModel, config, properties.isInitializeSchema(),
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.weaviate;

import io.micrometer.observation.ObservationRegistry;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateAuthClient;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.v1.auth.exception.AuthException;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.WeaviateVectorStore;
import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig;
import org.springframework.ai.vectorstore.WeaviateVectorStore.WeaviateVectorStoreConfig.MetadataField;
Expand Down Expand Up @@ -62,11 +66,18 @@ public WeaviateClient weaviateClient(WeaviateVectorStoreProperties properties,
}
}

/**
 * Supplies the fallback {@link BatchingStrategy} for the Weaviate vector store.
 * The bean is only registered when the context holds no other
 * {@link BatchingStrategy} bean, so an application-defined strategy takes precedence.
 * @return a new {@link TokenCountBatchingStrategy} created with its no-arg constructor
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public WeaviateVectorStore vectorStore(EmbeddingModel embeddingModel, WeaviateClient weaviateClient,
WeaviateVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

WeaviateVectorStoreConfig.Builder configBuilder = WeaviateVectorStore.WeaviateVectorStoreConfig.builder()
.withObjectClass(properties.getObjectClass())
Expand All @@ -79,7 +90,7 @@ public WeaviateVectorStore vectorStore(EmbeddingModel embeddingModel, WeaviateCl

return new WeaviateVectorStore(configBuilder.build(), embeddingModel, weaviateClient,
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
customObservationConvention.getIfAvailable(() -> null));
customObservationConvention.getIfAvailable(() -> null), batchingStrategy);
}

static class PropertiesWeaviateConnectionDetails implements WeaviateConnectionDetails {
Expand Down
Loading