From 202148d45bf9c226a04768f7ff9836a89e0bee9c Mon Sep 17 00:00:00 2001
From: Soby Chacko
Date: Mon, 23 Sep 2024 21:37:04 -0400
Subject: [PATCH] Prevent timeouts with configurable batching for PgVectorStore
 inserts

Resolves https://github.com/spring-projects/spring-ai/issues/1199

- Implement configurable maxDocumentBatchSize to prevent insert timeouts
  when adding large numbers of documents
- Update PgVectorStore to process document inserts in controlled batches
- Add maxDocumentBatchSize property to PgVectorStoreProperties
- Update PgVectorStoreAutoConfiguration to use the new batching property
- Add tests to verify batching behavior

This change addresses the issue of PgVectorStore inserts timing out due to
large document volumes. By introducing configurable batching, users can now
control the insert process to avoid timeouts while maintaining performance
and reducing memory overhead for large-scale document additions.
---
 .../PgVectorStoreAutoConfiguration.java       |   1 +
 .../pgvector/PgVectorStoreProperties.java     |  11 ++
 .../ai/vectorstore/PgVectorStore.java         | 121 ++++++++++--------
 .../ai/vectorstore/PgVectorStoreTests.java    |  51 +++++++-
 4 files changed, 132 insertions(+), 52 deletions(-)

diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java
index 1b5a62507a..ec4d76e074 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java
@@ -71,6 +71,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed
 			.withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
 			.withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
 			.withBatchingStrategy(batchingStrategy)
+			.withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize())
 			.build();
 	}
 
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java
index b455417461..47a12c36d3 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java
@@ -24,6 +24,7 @@
 /**
  * @author Christian Tzolov
  * @author Muthukumaran Navaneethakrishnan
+ * @author Soby Chacko
  */
 @ConfigurationProperties(PgVectorStoreProperties.CONFIG_PREFIX)
 public class PgVectorStoreProperties extends CommonVectorStoreProperties {
@@ -45,6 +46,8 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties {
 
 	private boolean schemaValidation = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;
 
+	private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE;
+
 	public int getDimensions() {
 		return dimensions;
 	}
@@ -101,4 +104,12 @@ public void setSchemaValidation(boolean schemaValidation) {
 		this.schemaValidation = schemaValidation;
 	}
 
+	public int getMaxDocumentBatchSize() {
+		return this.maxDocumentBatchSize;
+	}
+
+	public void setMaxDocumentBatchSize(int maxDocumentBatchSize) {
+		this.maxDocumentBatchSize = maxDocumentBatchSize;
+	}
+
 }
diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
index 8ab0546fc5..697960f15d 100644
--- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
+++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
@@ -15,14 +15,10 @@
  */
 package org.springframework.ai.vectorstore;
 
-import java.sql.PreparedStatement;
-import java.sql.ResultSet;
-import java.sql.SQLException;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.UUID;
-
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.pgvector.PGvector;
+import io.micrometer.observation.ObservationRegistry;
 import org.postgresql.util.PGobject;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -46,11 +42,14 @@
 import org.springframework.lang.Nullable;
 import org.springframework.util.StringUtils;
 
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.pgvector.PGvector;
-
-import io.micrometer.observation.ObservationRegistry;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;
 
 /**
  * Uses the "vector_store" table to store the Spring AI vector data. The table and the
@@ -81,6 +80,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
 
 	public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter();
 
+	public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000;
+
 	private final String vectorTableName;
 
 	private final String vectorIndexName;
@@ -109,6 +110,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
 
 	private final BatchingStrategy batchingStrategy;
 
+	private final int maxDocumentBatchSize;
+
 	public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
 		this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
 				PgIndexType.NONE, false);
@@ -132,7 +135,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin
 
 		this(DEFAULT_SCHEMA_NAME, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel,
 				dimensions, distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema);
-
 	}
 
 	private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
@@ -141,14 +143,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
 
 		this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
 				distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
-				ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy());
+				ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE);
 	}
 
 	private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
 			JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
 			boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
 			ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
-			BatchingStrategy batchingStrategy) {
+			BatchingStrategy batchingStrategy, int maxDocumentBatchSize) {
 
 		super(observationRegistry, customObservationConvention);
 
@@ -172,6 +174,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
 		this.initializeSchema = initializeSchema;
 		this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate);
 		this.batchingStrategy = batchingStrategy;
+		this.maxDocumentBatchSize = maxDocumentBatchSize;
 	}
 
 	public PgDistanceType getDistanceType() {
@@ -180,40 +183,50 @@ public PgDistanceType getDistanceType() {
 
 	@Override
 	public void doAdd(List<Document> documents) {
+		this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
 
-		int size = documents.size();
+		List<List<Document>> batchedDocuments = batchDocuments(documents);
+		batchedDocuments.forEach(this::insertOrUpdateBatch);
+	}
 
-		this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
+	private List<List<Document>> batchDocuments(List<Document> documents) {
+		List<List<Document>> batches = new ArrayList<>();
+		for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) {
+			batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size())));
+		}
+		return batches;
+	}
 
-		this.jdbcTemplate.batchUpdate(
-				"INSERT INTO " + getFullyQualifiedTableName()
-						+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
-						+ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ",
-				new BatchPreparedStatementSetter() {
-					@Override
-					public void setValues(PreparedStatement ps, int i) throws SQLException {
-
-						var document = documents.get(i);
-						var content = document.getContent();
-						var json = toJson(document.getMetadata());
-						var embedding = document.getEmbedding();
-						var pGvector = new PGvector(embedding);
-
-						StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
-								UUID.fromString(document.getId()));
-						StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
-						StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
-						StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-						StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
-						StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
-						StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-					}
+	private void insertOrUpdateBatch(List<Document> batch) {
+		String sql = "INSERT INTO " + getFullyQualifiedTableName()
+				+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
+				+ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ";
+
+		this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
+			@Override
+			public void setValues(PreparedStatement ps, int i) throws SQLException {
+
+				var document = batch.get(i);
+				var content = document.getContent();
+				var json = toJson(document.getMetadata());
+				var embedding = document.getEmbedding();
+				var pGvector = new PGvector(embedding);
+
+				StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
+						UUID.fromString(document.getId()));
+				StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
+				StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
+				StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+				StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
+				StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
+				StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+			}
 
-					@Override
-					public int getBatchSize() {
-						return size;
-					}
-				});
+			@Override
+			public int getBatchSize() {
+				return batch.size();
+			}
+		});
 	}
 
 	private String toJson(Map<String, Object> map) {
@@ -285,7 +298,7 @@ private String comparisonOperator() {
 	// Initialize
 	// ---------------------------------------------------------------------------------
 	@Override
-	public void afterPropertiesSet() throws Exception {
+	public void afterPropertiesSet() {
 
 		logger.info("Initializing PGVectorStore schema for table: {} in schema: {}", this.getVectorTableName(),
 				this.getSchemaName());
@@ -390,7 +403,7 @@ public enum PgIndexType {
	 * speed-recall tradeoff). There’s no training step like IVFFlat, so the index can
	 * be created without any data in the table.
	 */
-	HNSW;
+	HNSW
 
 	}
 
@@ -443,7 +456,7 @@ private static class DocumentRowMapper implements RowMapper<Document> {
 		private static final String COLUMN_DISTANCE = "distance";
 
-		private ObjectMapper objectMapper;
+		private final ObjectMapper objectMapper;
 
 		public DocumentRowMapper(ObjectMapper objectMapper) {
 			this.objectMapper = objectMapper;
 		}
@@ -509,6 +522,8 @@ public static class Builder {
 
 		private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
 
+		private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE;
+
 		@Nullable
 		private VectorStoreObservationConvention searchObservationConvention;
 
@@ -576,11 +591,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
 			return this;
 		}
 
+		public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) {
+			this.maxDocumentBatchSize = maxDocumentBatchSize;
+			return this;
+		}
+
 		public PgVectorStore build() {
 			return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled,
 					this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType,
 					this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema,
-					this.observationRegistry, this.searchObservationConvention, this.batchingStrategy);
+					this.observationRegistry, this.searchObservationConvention, this.batchingStrategy,
+					this.maxDocumentBatchSize);
 		}
 
 	}
diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java
index f5b69a922c..488dbd3f73 100644
--- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java
+++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java
@@ -15,15 +15,31 @@
  */
 package org.springframework.ai.vectorstore;
 
+import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
+import org.mockito.ArgumentCaptor;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.only;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.util.Collections;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.jdbc.core.BatchPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcTemplate;
 
 /**
  * @author Muthukumaran Navaneethakrishnan
+ * @author Soby Chacko
  */
-
 public class PgVectorStoreTests {
 
	@ParameterizedTest(name = "{0} - Verifies valid Table name")
@@ -53,8 +69,39 @@ public class PgVectorStoreTests {
 			// 64
 			// characters
 	})
-	public void isValidTable(String tableName, Boolean expected) {
+	void isValidTable(String tableName, Boolean expected) {
 		assertThat(PgVectorSchemaValidator.isValidNameForDatabaseObject(tableName)).isEqualTo(expected);
 	}
 
+	@Test
+	void shouldAddDocumentsInBatchesAndEmbedOnce() {
+		// Given
+		var jdbcTemplate = mock(JdbcTemplate.class);
+		var embeddingModel = mock(EmbeddingModel.class);
+		var pgVectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(1000)
+			.build();
+
+		// Testing with 9989 documents
+		var documents = Collections.nCopies(9989, new Document("foo"));
+
+		// When
+		pgVectorStore.doAdd(documents);
+
+		// Then
+		verify(embeddingModel, only()).embed(eq(documents), any(), any());
+
+		var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class);
+		verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture());
+
+		assertThat(batchUpdateCaptor.getAllValues()).hasSize(10)
+			.allSatisfy(BatchPreparedStatementSetter::getBatchSize)
+			.satisfies(batches -> {
+				for (int i = 0; i < 9; i++) {
+					assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 1000", i)
+						.isEqualTo(1000);
+				}
+				assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989);
+			});
+	}
+
 }
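
A usage sketch of the new knob (illustrative, not part of the diff): the Builder calls below come straight from this patch, while the helper class and method names are hypothetical, and jdbcTemplate/embeddingModel are assumed to be wired by the application. With the Spring Boot auto-configuration, the same limit should also be settable through the property bound to PgVectorStoreProperties; the key spring.ai.vectorstore.pgvector.max-document-batch-size is an assumption, since CONFIG_PREFIX's value is not shown in this diff.

    import java.util.List;

    import org.springframework.ai.document.Document;
    import org.springframework.ai.embedding.EmbeddingModel;
    import org.springframework.ai.vectorstore.PgVectorStore;
    import org.springframework.jdbc.core.JdbcTemplate;

    class MaxBatchSizeExample {

        // Cap each jdbcTemplate.batchUpdate at 5_000 rows instead of the
        // 10_000 default (PgVectorStore.MAX_DOCUMENT_BATCH_SIZE).
        static PgVectorStore buildStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
            return new PgVectorStore.Builder(jdbcTemplate, embeddingModel)
                .withMaxDocumentBatchSize(5_000)
                .build();
        }

        // Documents are embedded once up front, then inserted in sub-lists of
        // at most 5_000 rows per INSERT ... ON CONFLICT batch.
        static void addAll(PgVectorStore store, List<Document> documents) {
            store.add(documents);
        }
    }

Smaller values trade a few extra round trips for shorter-lived statements, which is the point of the change: each batch commits work in bounded chunks instead of one oversized insert that can hit timeouts.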