From 085a6a2ee08deaeaf3fedffe3d5df145fde6e200 Mon Sep 17 00:00:00 2001 From: Angular2guy Date: Sat, 20 Jan 2024 10:25:49 +0100 Subject: [PATCH] feat: use row embeddings --- .../repository/DocumentVSRepositoryBean.java | 30 +++++++++++-------- .../usecase/service/TableService.java | 7 +++++ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/backend/src/main/java/ch/xxx/aidoclibchat/adapter/repository/DocumentVSRepositoryBean.java b/backend/src/main/java/ch/xxx/aidoclibchat/adapter/repository/DocumentVSRepositoryBean.java index 57a46d2..8130af1 100644 --- a/backend/src/main/java/ch/xxx/aidoclibchat/adapter/repository/DocumentVSRepositoryBean.java +++ b/backend/src/main/java/ch/xxx/aidoclibchat/adapter/repository/DocumentVSRepositoryBean.java @@ -45,17 +45,17 @@ public void add(List documents) { @Override public List retrieve(String query, DataType dataType, int k, double threshold) { - return this.vectorStore.similaritySearch(SearchRequest.query(query) - .withFilterExpression( - new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), new Value(dataType.toString()))) - .withTopK(k).withSimilarityThreshold(threshold)); + return this.vectorStore + .similaritySearch(SearchRequest + .query(query).withFilterExpression(new Filter.Expression(ExpressionType.EQ, + new Key(MetaData.DATATYPE), new Value(dataType.toString()))) + .withTopK(k).withSimilarityThreshold(threshold)); } @Override public List retrieve(String query, DataType dataType, int k) { - return this.vectorStore.similaritySearch(SearchRequest.query(query) - .withFilterExpression( - new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), new Value(dataType.toString()))) + return this.vectorStore.similaritySearch(SearchRequest.query(query).withFilterExpression( + new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), new Value(dataType.toString()))) .withTopK(k)); } @@ -67,12 +67,18 @@ public List retrieve(String query, DataType dataType) { @Override public List findAllTableDocuments() { - return this.vectorStore.similaritySearch(SearchRequest.defaults().withSimilarityThresholdAll().withTopK(Integer.MAX_VALUE).withFilterExpression(new Filter.Expression( - ExpressionType.OR, - new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), new Value(DataType.COLUMN.toString())), - new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), new Value(DataType.TABLE.toString()))))); + return this.vectorStore + .similaritySearch(SearchRequest.defaults().withSimilarityThresholdAll().withTopK(Integer.MAX_VALUE) + .withFilterExpression(new Filter.Expression(ExpressionType.OR, + new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), + new Value(DataType.COLUMN.toString())), + new Filter.Expression(ExpressionType.OR, + new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), + new Value(DataType.TABLE.toString())), + new Filter.Expression(ExpressionType.EQ, new Key(MetaData.DATATYPE), + new Value(DataType.ROW.toString())))))); } - + @Override public void deleteByIds(List ids) { this.vectorStore.delete(ids); diff --git a/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java b/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java index 23c9a7f..b99cb07 100644 --- a/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java +++ b/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/TableService.java @@ -67,10 +67,17 @@ public void searchTables(SearchDto searchDto) { searchDto.getResultAmount()); var columnDocuments = this.documentVsRepository.retrieve(searchDto.getSearchString(), MetaData.DataType.COLUMN, searchDto.getResultAmount()); + var rowDocuments = this.documentVsRepository.retrieve(searchDto.getSearchString(), MetaData.DataType.ROW, + searchDto.getResultAmount()); + LOGGER.info("Table: "); tableDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}", myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getMetadata().get(MetaData.DISTANCE))); + LOGGER.info("Column: "); columnDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}", myDoc.getMetadata().get(MetaData.DATANAME), myDoc.getMetadata().get(MetaData.DISTANCE))); + LOGGER.info("Row: "); + rowDocuments.forEach(myDoc -> LOGGER.info("name: {}, distance: {}", myDoc.getMetadata().get(MetaData.DATANAME), + myDoc.getMetadata().get(MetaData.DISTANCE))); } @Async