Prevent timeouts with configurable batching for PgVectorStore inserts
Resolves #1199

- Implement configurable maxDocumentBatchSize to prevent insert timeouts
  when adding large numbers of documents
- Update PgVectorStore to process document inserts in controlled batches
- Add maxDocumentBatchSize property to PgVectorStoreProperties
- Update PgVectorStoreAutoConfiguration to use the new batching property
- Add tests to verify batching behavior and performance

This change addresses PgVectorStore inserts timing out on large document
volumes. Configurable batching lets users bound the size of each insert
round trip, avoiding timeouts while preserving throughput and keeping
memory overhead predictable for large-scale document additions.
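
A minimal usage sketch of the new knob (builder and method names are taken
from this diff; the jdbcTemplate/embeddingModel wiring and the 1_000 value
are illustrative assumptions, and add() is the public VectorStore entry
point that delegates to doAdd):

    // Cap each JDBC insert round trip at 1_000 documents instead of the
    // default PgVectorStore.MAX_DOCUMENT_BATCH_SIZE (10_000).
    PgVectorStore vectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel)
        .withMaxDocumentBatchSize(1_000)
        .build();

    // Embeds all documents once, then issues ceil(9_989 / 1_000) = 10
    // batchUpdate calls: nine batches of 1_000 and a final batch of 989.
    vectorStore.add(documents);

With the auto-configuration change below, the same cap should also be
settable as a Boot property, presumably
spring.ai.vectorstore.pgvector.max-document-batch-size, given the
PgVectorStoreProperties prefix and relaxed binding.
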
sobychacko authored and Mark Pollack committed Sep 24, 2024
1 parent 42dcb45 commit 202148d
Showing 4 changed files with 132 additions and 52 deletions.
PgVectorStoreAutoConfiguration.java
@@ -71,6 +71,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed
.withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
.withBatchingStrategy(batchingStrategy)
.withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize())
.build();
}

PgVectorStoreProperties.java
@@ -24,6 +24,7 @@
/**
* @author Christian Tzolov
* @author Muthukumaran Navaneethakrishnan
* @author Soby Chacko
*/
@ConfigurationProperties(PgVectorStoreProperties.CONFIG_PREFIX)
public class PgVectorStoreProperties extends CommonVectorStoreProperties {
@@ -45,6 +46,8 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties {

private boolean schemaValidation = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;

private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE;

public int getDimensions() {
return dimensions;
}
@@ -101,4 +104,12 @@ public void setSchemaValidation(boolean schemaValidation) {
this.schemaValidation = schemaValidation;
}

public int getMaxDocumentBatchSize() {
return this.maxDocumentBatchSize;
}

public void setMaxDocumentBatchSize(int maxDocumentBatchSize) {
this.maxDocumentBatchSize = maxDocumentBatchSize;
}

}
PgVectorStore.java
@@ -15,14 +15,10 @@
*/
package org.springframework.ai.vectorstore;

-import java.sql.PreparedStatement;
-import java.sql.ResultSet;
-import java.sql.SQLException;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.UUID;
-
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.pgvector.PGvector;
+import io.micrometer.observation.ObservationRegistry;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -46,11 +42,14 @@
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.pgvector.PGvector;
-
-import io.micrometer.observation.ObservationRegistry;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;

/**
* Uses the "vector_store" table to store the Spring AI vector data. The table and the
@@ -81,6 +80,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter();

public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000;

private final String vectorTableName;

private final String vectorIndexName;
@@ -109,6 +110,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

private final BatchingStrategy batchingStrategy;

private final int maxDocumentBatchSize;

public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
PgIndexType.NONE, false);
@@ -132,7 +135,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin

this(DEFAULT_SCHEMA_NAME, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel, dimensions,
distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema);

}

private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
@@ -141,14 +143,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT

this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
-ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy());
+ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE);
}

private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
-BatchingStrategy batchingStrategy) {
+BatchingStrategy batchingStrategy, int maxDocumentBatchSize) {

super(observationRegistry, customObservationConvention);

@@ -172,6 +174,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
this.initializeSchema = initializeSchema;
this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate);
this.batchingStrategy = batchingStrategy;
this.maxDocumentBatchSize = maxDocumentBatchSize;
}

public PgDistanceType getDistanceType() {
@@ -180,40 +183,50 @@ public PgDistanceType getDistanceType() {

@Override
public void doAdd(List<Document> documents) {
+this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);

-int size = documents.size();
+List<List<Document>> batchedDocuments = batchDocuments(documents);
+batchedDocuments.forEach(this::insertOrUpdateBatch);
+}

-this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
+private List<List<Document>> batchDocuments(List<Document> documents) {
+List<List<Document>> batches = new ArrayList<>();
+for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) {
+batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size())));
+}
+return batches;
+}

-this.jdbcTemplate.batchUpdate(
-"INSERT INTO " + getFullyQualifiedTableName()
-+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
-+ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ",
-new BatchPreparedStatementSetter() {
-@Override
-public void setValues(PreparedStatement ps, int i) throws SQLException {
-
-var document = documents.get(i);
-var content = document.getContent();
-var json = toJson(document.getMetadata());
-var embedding = document.getEmbedding();
-var pGvector = new PGvector(embedding);
-
-StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
-UUID.fromString(document.getId()));
-StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
-StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
-StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
-StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
-StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-}
+private void insertOrUpdateBatch(List<Document> batch) {
+String sql = "INSERT INTO " + getFullyQualifiedTableName()
++ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
++ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ";
+
+this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
+@Override
+public void setValues(PreparedStatement ps, int i) throws SQLException {
+
+var document = batch.get(i);
+var content = document.getContent();
+var json = toJson(document.getMetadata());
+var embedding = document.getEmbedding();
+var pGvector = new PGvector(embedding);
+
+StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
+UUID.fromString(document.getId()));
+StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
+StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
+StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
+StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
+StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+}

-@Override
-public int getBatchSize() {
-return size;
-}
-});
+@Override
+public int getBatchSize() {
+return batch.size();
+}
+});
+}

private String toJson(Map<String, Object> map) {
@@ -285,7 +298,7 @@ private String comparisonOperator() {
// Initialize
// ---------------------------------------------------------------------------------
@Override
-public void afterPropertiesSet() throws Exception {
+public void afterPropertiesSet() {

logger.info("Initializing PGVectorStore schema for table: {} in schema: {}", this.getVectorTableName(),
this.getSchemaName());
@@ -390,7 +403,7 @@ public enum PgIndexType {
* speed-recall tradeoff). There’s no training step like IVFFlat, so the index can
* be created without any data in the table.
*/
-HNSW;
+HNSW

}

@@ -443,7 +456,7 @@ private static class DocumentRowMapper implements RowMapper<Document> {

private static final String COLUMN_DISTANCE = "distance";

-private ObjectMapper objectMapper;
+private final ObjectMapper objectMapper;

public DocumentRowMapper(ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
@@ -509,6 +522,8 @@ public static class Builder {

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE;

@Nullable
private VectorStoreObservationConvention searchObservationConvention;

@@ -576,11 +591,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
return this;
}

public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) {
this.maxDocumentBatchSize = maxDocumentBatchSize;
return this;
}

public PgVectorStore build() {
return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled,
this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType,
this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema,
-this.observationRegistry, this.searchObservationConvention, this.batchingStrategy);
+this.observationRegistry, this.searchObservationConvention, this.batchingStrategy,
+this.maxDocumentBatchSize);
}

}
PgVectorStoreTests.java
@@ -15,15 +15,31 @@
*/
package org.springframework.ai.vectorstore;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.only;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.util.Collections;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;

/**
* @author Muthukumaran Navaneethakrishnan
* @author Soby Chacko
*/

public class PgVectorStoreTests {

@ParameterizedTest(name = "{0} - Verifies valid Table name")
@@ -53,8 +69,39 @@ public class PgVectorStoreTests {
// 64
// characters
})
-public void isValidTable(String tableName, Boolean expected) {
+void isValidTable(String tableName, Boolean expected) {
assertThat(PgVectorSchemaValidator.isValidNameForDatabaseObject(tableName)).isEqualTo(expected);
}

@Test
void shouldAddDocumentsInBatchesAndEmbedOnce() {
// Given
var jdbcTemplate = mock(JdbcTemplate.class);
var embeddingModel = mock(EmbeddingModel.class);
var pgVectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(1000)
.build();

// Testing with 9989 documents
var documents = Collections.nCopies(9989, new Document("foo"));

// When
pgVectorStore.doAdd(documents);

// Then
verify(embeddingModel, only()).embed(eq(documents), any(), any());

var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class);
verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture());

assertThat(batchUpdateCaptor.getAllValues()).hasSize(10)
.allSatisfy(BatchPreparedStatementSetter::getBatchSize)
.satisfies(batches -> {
for (int i = 0; i < 9; i++) {
assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 1000", i)
.isEqualTo(1000);
}
assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989);
});
}

}
