diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java index 53aeb23ad4..1b5a62507a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java @@ -13,11 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.pgvector; import javax.sql.DataSource; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.vectorstore.PgVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; import org.springframework.beans.factory.ObjectProvider; @@ -34,17 +37,26 @@ /** * @author Christian Tzolov * @author Josh Long + * @author Soby Chacko + * @since 1.0.0 */ @AutoConfiguration(after = JdbcTemplateAutoConfiguration.class) @ConditionalOnClass({ PgVectorStore.class, DataSource.class, JdbcTemplate.class }) @EnableConfigurationProperties(PgVectorStoreProperties.class) public class PgVectorStoreAutoConfiguration { + @Bean + @ConditionalOnMissingBean(BatchingStrategy.class) + BatchingStrategy pgVectorStoreBatchingStrategy() { + return new TokenCountBatchingStrategy(); + } + @Bean @ConditionalOnMissingBean public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, PgVectorStoreProperties properties, ObjectProvider observationRegistry, - ObjectProvider customObservationConvention) { + ObjectProvider customObservationConvention, + BatchingStrategy batchingStrategy) { var initializeSchema = properties.isInitializeSchema(); @@ -58,6 +70,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed .withInitializeSchema(initializeSchema) .withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null)) + .withBatchingStrategy(batchingStrategy) .build(); } diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index ef385c51dd..8ab0546fc5 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -27,7 +27,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.BatchingStrategy; import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.embedding.EmbeddingOptionsBuilder; +import org.springframework.ai.embedding.TokenCountBatchingStrategy; import org.springframework.ai.observation.conventions.VectorStoreProvider; import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric; import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; @@ -57,6 +60,8 @@ * @author Josh Long * @author Muthukumaran Navaneethakrishnan * @author Thomas Vitale + * @author Soby Chacko + * @since 1.0.0 */ public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean { @@ -90,17 +95,19 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini private final boolean initializeSchema; - private int dimensions; + private final int dimensions; - private PgDistanceType distanceType; + private final PgDistanceType distanceType; - private ObjectMapper objectMapper = new ObjectMapper(); + private final ObjectMapper objectMapper = new ObjectMapper(); - private boolean removeExistingVectorStoreTable; + private final boolean removeExistingVectorStoreTable; - private PgIndexType createIndexMethod; + private final PgIndexType createIndexMethod; - private PgVectorSchemaValidator schemaValidator; + private final PgVectorSchemaValidator schemaValidator; + + private final BatchingStrategy batchingStrategy; public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false, @@ -134,13 +141,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema, - ObservationRegistry.NOOP, null); + ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy()); } private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema, - ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) { + ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, + BatchingStrategy batchingStrategy) { super(observationRegistry, customObservationConvention); @@ -163,6 +171,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT this.createIndexMethod = createIndexMethod; this.initializeSchema = initializeSchema; this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate); + this.batchingStrategy = batchingStrategy; } public PgDistanceType getDistanceType() { @@ -174,6 +183,8 @@ public void doAdd(List documents) { int size = documents.size(); + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + this.jdbcTemplate.batchUpdate( "INSERT INTO " + getFullyQualifiedTableName() + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " @@ -185,8 +196,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException { var document = documents.get(i); var content = document.getContent(); var json = toJson(document.getMetadata()); - var embedding = embeddingModel.embed(document); - document.setEmbedding(embedding); + var embedding = document.getEmbedding(); var pGvector = new PGvector(embedding); StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, @@ -497,6 +507,8 @@ public static class Builder { private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + @Nullable private VectorStoreObservationConvention searchObservationConvention; @@ -559,10 +571,16 @@ public Builder withSearchObservationConvention(VectorStoreObservationConvention return this; } + public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) { + this.batchingStrategy = batchingStrategy; + return this; + } + public PgVectorStore build() { - return new PgVectorStore(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, - embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, indexType, - initializeSchema, observationRegistry, searchObservationConvention); + return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled, + this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType, + this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema, + this.observationRegistry, this.searchObservationConvention, this.batchingStrategy); } }