From c0299f1f84c50b24f071df537892d88a1e23a911 Mon Sep 17 00:00:00 2001
From: sinsy <550569627@qq.com>
Date: Fri, 9 Aug 2024 18:19:29 +0800
Subject: [PATCH 1/2] feat: insert documents in segments (one batch update per segment) in PgVectorStore.add
---
.../spring-ai-pgvector-store/pom.xml | 6 ++
.../ai/vectorstore/PgVectorStore.java | 68 +++++++++++--------
2 files changed, 44 insertions(+), 30 deletions(-)
diff --git a/vector-stores/spring-ai-pgvector-store/pom.xml b/vector-stores/spring-ai-pgvector-store/pom.xml
index 0da1ef5752..9cdd33ed99 100644
--- a/vector-stores/spring-ai-pgvector-store/pom.xml
+++ b/vector-stores/spring-ai-pgvector-store/pom.xml
@@ -49,6 +49,12 @@
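<version>${pgvector.version}</version>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-collections4</artifactId>
+ <version>4.5.0-M2</version>
+ </dependency>
+
<groupId>org.springframework.ai</groupId>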
diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
index bf2c662f71..92be091bed 100644
--- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
+++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
@@ -24,6 +24,7 @@
import java.util.UUID;
import java.util.stream.IntStream;
+import org.apache.commons.collections4.ListUtils;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -153,38 +154,45 @@ public PgDistanceType getDistanceType() {
@Override
public void add(List documents) {
- int size = documents.size();
+ int segmentNum = 10;
+
+ List<List<Document>> segments = ListUtils.partition(documents, segmentNum);
+
+ for (List segment : segments) {
+ int size = segment.size();
+ this.jdbcTemplate.batchUpdate(
+ "INSERT INTO " + getFullyQualifiedTableName()
+ + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
+ + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ",
+ new BatchPreparedStatementSetter() {
+ @Override
+ public void setValues(PreparedStatement ps, int i) throws SQLException {
+
+ var document = segment.get(i);
+ var content = document.getContent();
+ var json = toJson(document.getMetadata());
+ var embedding = embeddingModel.embed(document);
+ document.setEmbedding(embedding);
+ var pGvector = new PGvector(toFloatArray(embedding));
+
+ StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
+ UUID.fromString(document.getId()));
+ StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
+ StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
+ StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+ StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
+ StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
+ StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+ }
+
+ @Override
+ public int getBatchSize() {
+ return size;
+ }
+ });
- this.jdbcTemplate.batchUpdate(
- "INSERT INTO " + getFullyQualifiedTableName()
- + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
- + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ",
- new BatchPreparedStatementSetter() {
- @Override
- public void setValues(PreparedStatement ps, int i) throws SQLException {
-
- var document = documents.get(i);
- var content = document.getContent();
- var json = toJson(document.getMetadata());
- var embedding = embeddingModel.embed(document);
- document.setEmbedding(embedding);
- var pGvector = new PGvector(toFloatArray(embedding));
-
- StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
- UUID.fromString(document.getId()));
- StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
- StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
- StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
- StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
- StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
- StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
- }
+ }
- @Override
- public int getBatchSize() {
- return size;
- }
- });
}
private String toJson(Map map) {
From b989c06ed4e7bf1c8676dcc0ab66177b4f0f75ce Mon Sep 17 00:00:00 2001
From: sinsy <550569627@qq.com>
Date: Fri, 9 Aug 2024 18:32:52 +0800
Subject: [PATCH 2/2] change segmentNum from 10 to 100
---
.../java/org/springframework/ai/vectorstore/PgVectorStore.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
index 92be091bed..23b489a3f8 100644
--- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
+++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java
@@ -154,7 +154,7 @@ public PgDistanceType getDistanceType() {
@Override
public void add(List documents) {
- int segmentNum = 10;
+ int segmentNum = 100;
List<List<Document>> segments = ListUtils.partition(documents, segmentNum);