From c0299f1f84c50b24f071df537892d88a1e23a911 Mon Sep 17 00:00:00 2001 From: sinsy <550569627@qq.com> Date: Fri, 9 Aug 2024 18:19:29 +0800 Subject: [PATCH 1/2] feat: segments insert data --- .../spring-ai-pgvector-store/pom.xml | 6 ++ .../ai/vectorstore/PgVectorStore.java | 68 +++++++++++-------- 2 files changed, 44 insertions(+), 30 deletions(-) diff --git a/vector-stores/spring-ai-pgvector-store/pom.xml b/vector-stores/spring-ai-pgvector-store/pom.xml index 0da1ef5752..9cdd33ed99 100644 --- a/vector-stores/spring-ai-pgvector-store/pom.xml +++ b/vector-stores/spring-ai-pgvector-store/pom.xml @@ -49,6 +49,12 @@ ${pgvector.version} + + org.apache.commons + commons-collections4 + 4.5.0-M2 + + org.springframework.ai diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index bf2c662f71..92be091bed 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -24,6 +24,7 @@ import java.util.UUID; import java.util.stream.IntStream; +import org.apache.commons.collections4.ListUtils; import org.postgresql.util.PGobject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -153,38 +154,45 @@ public PgDistanceType getDistanceType() { @Override public void add(List documents) { - int size = documents.size(); + int segmentNum = 10; + + List> segments = ListUtils.partition(documents, segmentNum); + + for (List segment : segments) { + int size = segment.size(); + this.jdbcTemplate.batchUpdate( + "INSERT INTO " + getFullyQualifiedTableName() + + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " + + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", + new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + + var document = segment.get(i); + var content = document.getContent(); + var json = toJson(document.getMetadata()); + var embedding = embeddingModel.embed(document); + document.setEmbedding(embedding); + var pGvector = new PGvector(toFloatArray(embedding)); + + StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, + UUID.fromString(document.getId())); + StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content); + StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json); + StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector); + StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content); + StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json); + StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector); + } + + @Override + public int getBatchSize() { + return size; + } + }); - this.jdbcTemplate.batchUpdate( - "INSERT INTO " + getFullyQualifiedTableName() - + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " - + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", - new BatchPreparedStatementSetter() { - @Override - public void setValues(PreparedStatement ps, int i) throws SQLException { - - var document = documents.get(i); - var content = document.getContent(); - var json = toJson(document.getMetadata()); - var embedding = embeddingModel.embed(document); - document.setEmbedding(embedding); - var pGvector = new PGvector(toFloatArray(embedding)); - - StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, - UUID.fromString(document.getId())); - StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content); - StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json); - StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector); - StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content); - StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json); - StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector); - } + } - @Override - public int getBatchSize() { - return size; - } - }); } private String toJson(Map map) { From b989c06ed4e7bf1c8676dcc0ab66177b4f0f75ce Mon Sep 17 00:00:00 2001 From: sinsy <550569627@qq.com> Date: Fri, 9 Aug 2024 18:32:52 +0800 Subject: [PATCH 2/2] change segmentNum --- .../java/org/springframework/ai/vectorstore/PgVectorStore.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 92be091bed..23b489a3f8 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -154,7 +154,7 @@ public PgDistanceType getDistanceType() { @Override public void add(List documents) { - int segmentNum = 10; + int segmentNum = 100; List> segments = ListUtils.partition(documents, segmentNum);