Skip to content

Commit

Permalink
Add batching strategy for embedding documents in PgVectorStore
Browse files Browse the repository at this point in the history
- Precompute all embeddings using a BatchingStrategy before inserting into the vector store

This optimization improves efficiency when adding multiple documents

Related to #1261
  • Loading branch information
sobychacko authored and markpollack committed Sep 4, 2024
1 parent 73d0b30 commit 087de16
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.pgvector;

import javax.sql.DataSource;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -34,17 +37,26 @@
/**
* @author Christian Tzolov
* @author Josh Long
* @author Soby Chacko
* @since 1.0.0
*/
@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class)
@ConditionalOnClass({ PgVectorStore.class, DataSource.class, JdbcTemplate.class })
@EnableConfigurationProperties(PgVectorStoreProperties.class)
public class PgVectorStoreAutoConfiguration {

@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy pgVectorStoreBatchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel,
PgVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

var initializeSchema = properties.isInitializeSchema();

Expand All @@ -58,6 +70,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed
.withInitializeSchema(initializeSchema)
.withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
.withBatchingStrategy(batchingStrategy)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
Expand Down Expand Up @@ -57,6 +60,8 @@
* @author Josh Long
* @author Muthukumaran Navaneethakrishnan
* @author Thomas Vitale
* @author Soby Chacko
* @since 1.0.0
*/
public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean {

Expand Down Expand Up @@ -90,17 +95,19 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

private final boolean initializeSchema;

private int dimensions;
private final int dimensions;

private PgDistanceType distanceType;
private final PgDistanceType distanceType;

private ObjectMapper objectMapper = new ObjectMapper();
private final ObjectMapper objectMapper = new ObjectMapper();

private boolean removeExistingVectorStoreTable;
private final boolean removeExistingVectorStoreTable;

private PgIndexType createIndexMethod;
private final PgIndexType createIndexMethod;

private PgVectorSchemaValidator schemaValidator;
private final PgVectorSchemaValidator schemaValidator;

private final BatchingStrategy batchingStrategy;

public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
Expand Down Expand Up @@ -134,13 +141,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT

this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
ObservationRegistry.NOOP, null);
ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy());
}

private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) {
ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
BatchingStrategy batchingStrategy) {

super(observationRegistry, customObservationConvention);

Expand All @@ -163,6 +171,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
this.createIndexMethod = createIndexMethod;
this.initializeSchema = initializeSchema;
this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate);
this.batchingStrategy = batchingStrategy;
}

public PgDistanceType getDistanceType() {
Expand All @@ -174,6 +183,8 @@ public void doAdd(List<Document> documents) {

int size = documents.size();

this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);

this.jdbcTemplate.batchUpdate(
"INSERT INTO " + getFullyQualifiedTableName()
+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
Expand All @@ -185,8 +196,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException {
var document = documents.get(i);
var content = document.getContent();
var json = toJson(document.getMetadata());
var embedding = embeddingModel.embed(document);
document.setEmbedding(embedding);
var embedding = document.getEmbedding();
var pGvector = new PGvector(embedding);

StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
Expand Down Expand Up @@ -497,6 +507,8 @@ public static class Builder {

private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

@Nullable
private VectorStoreObservationConvention searchObservationConvention;

Expand Down Expand Up @@ -559,10 +571,16 @@ public Builder withSearchObservationConvention(VectorStoreObservationConvention
return this;
}

public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
this.batchingStrategy = batchingStrategy;
return this;
}

public PgVectorStore build() {
return new PgVectorStore(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate,
embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, indexType,
initializeSchema, observationRegistry, searchObservationConvention);
return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled,
this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType,
this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema,
this.observationRegistry, this.searchObservationConvention, this.batchingStrategy);
}

}
Expand Down

0 comments on commit 087de16

Please sign in to comment.