Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PgVectorStore: Add batching strategy for adding documents #1304

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.vectorstore.pgvector;

import javax.sql.DataSource;

import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -34,17 +37,26 @@
/**
* @author Christian Tzolov
* @author Josh Long
* @author Soby Chacko
* @since 1.0.0
*/
@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class)
@ConditionalOnClass({ PgVectorStore.class, DataSource.class, JdbcTemplate.class })
@EnableConfigurationProperties(PgVectorStoreProperties.class)
public class PgVectorStoreAutoConfiguration {

/**
 * Supplies the default {@link BatchingStrategy} bean: a {@link TokenCountBatchingStrategy}
 * created with its no-argument constructor.
 * <p>
 * Guarded by {@code @ConditionalOnMissingBean(BatchingStrategy.class)}, so this definition
 * backs off when another {@link BatchingStrategy} bean is already present.
 * @return a new {@link TokenCountBatchingStrategy}
 */
@Bean
@ConditionalOnMissingBean(BatchingStrategy.class)
BatchingStrategy pgVectorStoreBatchingStrategy() {
return new TokenCountBatchingStrategy();
}

@Bean
@ConditionalOnMissingBean
public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel,
PgVectorStoreProperties properties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<VectorStoreObservationConvention> customObservationConvention) {
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
BatchingStrategy batchingStrategy) {

var initializeSchema = properties.isInitializeSchema();

Expand All @@ -58,6 +70,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed
.withInitializeSchema(initializeSchema)
.withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
.withBatchingStrategy(batchingStrategy)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
Expand Down Expand Up @@ -57,6 +60,8 @@
* @author Josh Long
* @author Muthukumaran Navaneethakrishnan
* @author Thomas Vitale
* @author Soby Chacko
* @since 1.0.0
*/
public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean {

Expand Down Expand Up @@ -90,17 +95,19 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

private final boolean initializeSchema;

private int dimensions;
private final int dimensions;

private PgDistanceType distanceType;
private final PgDistanceType distanceType;

private ObjectMapper objectMapper = new ObjectMapper();
private final ObjectMapper objectMapper = new ObjectMapper();

private boolean removeExistingVectorStoreTable;
private final boolean removeExistingVectorStoreTable;

private PgIndexType createIndexMethod;
private final PgIndexType createIndexMethod;

private PgVectorSchemaValidator schemaValidator;
private final PgVectorSchemaValidator schemaValidator;

private final BatchingStrategy batchingStrategy;

public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
Expand Down Expand Up @@ -134,13 +141,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT

this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
ObservationRegistry.NOOP, null);
ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy());
}

private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) {
ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
BatchingStrategy batchingStrategy) {

super(observationRegistry, customObservationConvention);

Expand All @@ -163,6 +171,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
this.createIndexMethod = createIndexMethod;
this.initializeSchema = initializeSchema;
this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate);
this.batchingStrategy = batchingStrategy;
}

public PgDistanceType getDistanceType() {
Expand All @@ -174,6 +183,8 @@ public void doAdd(List<Document> documents) {

int size = documents.size();

this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);

this.jdbcTemplate.batchUpdate(
"INSERT INTO " + getFullyQualifiedTableName()
+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
Expand All @@ -185,8 +196,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException {
var document = documents.get(i);
var content = document.getContent();
var json = toJson(document.getMetadata());
var embedding = embeddingModel.embed(document);
document.setEmbedding(embedding);
var embedding = document.getEmbedding();
var pGvector = new PGvector(embedding);

StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
Expand Down Expand Up @@ -497,6 +507,8 @@ public static class Builder {

private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

@Nullable
private VectorStoreObservationConvention searchObservationConvention;

Expand Down Expand Up @@ -559,10 +571,16 @@ public Builder withSearchObservationConvention(VectorStoreObservationConvention
return this;
}

/**
 * Sets the {@link BatchingStrategy} that {@link #build()} hands to the
 * {@link PgVectorStore} constructor, replacing the builder's default
 * {@link TokenCountBatchingStrategy}.
 * @param batchingStrategy the strategy to use; must not be {@code null}
 * @return this builder, for call chaining
 * @throws NullPointerException if {@code batchingStrategy} is {@code null}
 */
public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
// Fail fast: a null strategy would otherwise silently replace the non-null default,
// be stored in the store, and only be passed on to the embedding call when documents are added.
this.batchingStrategy = java.util.Objects.requireNonNull(batchingStrategy, "batchingStrategy must not be null");
return this;
}

public PgVectorStore build() {
return new PgVectorStore(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate,
embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, indexType,
initializeSchema, observationRegistry, searchObservationConvention);
return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled,
this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType,
this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema,
this.observationRegistry, this.searchObservationConvention, this.batchingStrategy);
}

}
Expand Down