From d303d25895288fcdd96b662712d0aa376352e6e0 Mon Sep 17 00:00:00 2001 From: Pablo Sanchidrian Date: Wed, 25 Dec 2024 13:25:39 +0100 Subject: [PATCH 1/2] refactor: minor code improvements --- .../main/java/jchunk/chunker/fixed/Utils.java | 60 +++++++++++++------ 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/jchunk-fixed/src/main/java/jchunk/chunker/fixed/Utils.java b/jchunk-fixed/src/main/java/jchunk/chunker/fixed/Utils.java index baca4dc..c0a6493 100644 --- a/jchunk-fixed/src/main/java/jchunk/chunker/fixed/Utils.java +++ b/jchunk-fixed/src/main/java/jchunk/chunker/fixed/Utils.java @@ -2,9 +2,7 @@ import jchunk.chunker.core.chunk.Chunk; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Logger; import java.util.regex.Pattern; @@ -50,7 +48,7 @@ public static List splitIntoSentences(String content, Config config) { private static List splitWithDelimiter(String content, String delimiter, Config.Delimiter keepDelimiter) { if (keepDelimiter == Config.Delimiter.NONE) { - return Arrays.stream(content.split(Pattern.quote(delimiter))).filter(s -> !s.isBlank()).toList(); + return Arrays.stream(content.split(Pattern.quote(delimiter))).filter(s -> !s.isEmpty()).toList(); } String withDelimiter = "((?<=%1$s)|(?=%1$s))"; @@ -70,11 +68,11 @@ private static List splitWithDelimiterStart(List preSplits) { List splits = new ArrayList<>(); splits.add(preSplits.getFirst()); - IntStream.range(1, preSplits.size()) + IntStream.range(1, preSplits.size() - 1) .filter(i -> i % 2 == 1) .forEach(i -> splits.add(preSplits.get(i).concat(preSplits.get(i + 1)))); - return splits.stream().filter(s -> !s.isBlank()).toList(); + return splits.stream().filter(s -> !s.isEmpty()).toList(); } /** @@ -91,7 +89,7 @@ private static List splitWithDelimiterEnd(List preSplits) { .forEach(i -> splits.add(preSplits.get(i).concat(preSplits.get(i + 1)))); splits.add(preSplits.getLast()); - return splits.stream().filter(s -> !s.isBlank()).toList(); + return splits.stream().filter(s -> !s.isEmpty()).toList(); } /** @@ -110,7 +108,7 @@ static List mergeSentences(List sentences, Config config) { int delimiterLen = delimiter.length(); List chunks = new ArrayList<>(); - List currentChunk = new ArrayList<>(); + Deque currentChunk = new LinkedList<>(); AtomicInteger chunkIndex = new AtomicInteger(0); @@ -123,14 +121,8 @@ static List mergeSentences(List sentences, Config config) { } if (!currentChunk.isEmpty()) { - String generatedSentence = joinSentences(currentChunk, delimiter, trimWhitespace); - chunks.add(new Chunk(chunkIndex.getAndIncrement(), generatedSentence)); - - while (currentLen > chunkOverlap - || (currentLen + sentenceLength + (currentChunk.isEmpty() ? 0 : delimiterLen) > chunkSize - && currentLen > 0)) { - currentLen -= currentChunk.removeFirst().length() + (currentChunk.isEmpty() ? 0 : delimiterLen); - } + addChunk(chunks, currentChunk, delimiter, trimWhitespace, chunkIndex); + currentLen = adjustCurrentChunkForOverlap(currentChunk, currentLen, chunkOverlap, delimiterLen); } } @@ -139,13 +131,43 @@ static List mergeSentences(List sentences, Config config) { } if (!currentChunk.isEmpty()) { - String generatedSentence = joinSentences(currentChunk, config.getDelimiter(), config.getTrimWhitespace()); - chunks.add(new Chunk(chunkIndex.getAndIncrement(), generatedSentence)); + addChunk(chunks, currentChunk, delimiter, trimWhitespace, chunkIndex); } return chunks; } + /** + * Adds the chunk to the list of chunks. + * @param chunks the list of chunks + * @param currentChunk the current chunk + * @param delimiter the delimiter + * @param trimWhitespace whether to trim the whitespace + * @param index the index of the chunk + */ + private static void addChunk(List chunks, Deque currentChunk, String delimiter, + boolean trimWhitespace, AtomicInteger index) { + String generatedSentence = joinSentences(currentChunk, delimiter, trimWhitespace); + Chunk chunk = Chunk.builder().id(index.getAndIncrement()).content(generatedSentence).build(); + chunks.add(chunk); + } + + /** + * Adjusts the current chunk for overlap. + * @param currentChunk the current chunk + * @param currentLen the current length of the chunk + * @param chunkOverlap the overlap between chunks + * @param delimiterLen the length of the delimiter + * @return the adjusted length of the chunk + */ + private static int adjustCurrentChunkForOverlap(Deque currentChunk, int currentLen, int chunkOverlap, + int delimiterLen) { + while (currentLen > chunkOverlap && !currentChunk.isEmpty()) { + currentLen -= currentChunk.removeFirst().length() + (currentChunk.isEmpty() ? 0 : delimiterLen); + } + return currentLen; + } + /** * Joins the sentences into a single sentence. * @param sentences the sentences to join @@ -153,7 +175,7 @@ static List mergeSentences(List sentences, Config config) { * @param trimWhitespace whether to trim the whitespace * @return the generated sentence */ - private static String joinSentences(List sentences, String delimiter, Boolean trimWhitespace) { + private static String joinSentences(Deque sentences, String delimiter, Boolean trimWhitespace) { String generatedSentence = String.join(delimiter, sentences); if (trimWhitespace) { generatedSentence = generatedSentence.trim(); From 0bb60b97b8a9ad693f769e798a45863fff8a2520 Mon Sep 17 00:00:00 2001 From: Pablo Sanchidrian Date: Wed, 25 Dec 2024 19:07:44 +0100 Subject: [PATCH 2/2] refactor: semantic chunker --- .../core/decorators/VisibleForTesting.java | 16 + .../chunker/semantic/SemanticChunker.java | 188 +++++++++- .../java/jchunk/chunker/semantic/Utils.java | 190 ---------- .../chunker/semantic/SemanticChunkerIT.java | 77 +++- .../chunker/semantic/SemanticChunkerTest.java | 355 ++++++++++++++++++ .../semantic/SemanticChunkerUtilsTest.java | 261 ------------- 6 files changed, 610 insertions(+), 477 deletions(-) create mode 100644 jchunk-core/src/main/java/jchunk/chunker/core/decorators/VisibleForTesting.java delete mode 100644 jchunk-semantic/src/main/java/jchunk/chunker/semantic/Utils.java create mode 100644 jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerTest.java delete mode 100644 jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerUtilsTest.java diff --git a/jchunk-core/src/main/java/jchunk/chunker/core/decorators/VisibleForTesting.java b/jchunk-core/src/main/java/jchunk/chunker/core/decorators/VisibleForTesting.java new file mode 100644 index 0000000..320153f --- /dev/null +++ b/jchunk-core/src/main/java/jchunk/chunker/core/decorators/VisibleForTesting.java @@ -0,0 +1,16 @@ +package jchunk.chunker.core.decorators; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to indicate that the visibility of a method, field, or class is increased + * for testing purposes. + */ +@Retention(RetentionPolicy.CLASS) +@Target({ ElementType.METHOD, ElementType.FIELD, ElementType.CONSTRUCTOR, ElementType.TYPE }) +public @interface VisibleForTesting { + +} \ No newline at end of file diff --git a/jchunk-semantic/src/main/java/jchunk/chunker/semantic/SemanticChunker.java b/jchunk-semantic/src/main/java/jchunk/chunker/semantic/SemanticChunker.java index 0182659..5fe7317 100644 --- a/jchunk-semantic/src/main/java/jchunk/chunker/semantic/SemanticChunker.java +++ b/jchunk-semantic/src/main/java/jchunk/chunker/semantic/SemanticChunker.java @@ -2,9 +2,17 @@ import jchunk.chunker.core.chunk.Chunk; import jchunk.chunker.core.chunk.IChunker; +import jchunk.chunker.core.decorators.VisibleForTesting; +import org.nd4j.common.io.Assert; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.springframework.ai.embedding.EmbeddingModel; +import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; /** * A semantic chunker that chunks the content based on the semantic meaning @@ -32,12 +40,180 @@ public SemanticChunker(EmbeddingModel embeddingModel, Config config) { @Override public List split(String content) { - List sentences = Utils.splitSentences(content, config.getSentenceSplitingStrategy()); - sentences = Utils.combineSentences(sentences, config.getBufferSize()); - sentences = Utils.embedSentences(embeddingModel, sentences); - List similarities = Utils.calculateSimilarities(sentences); - List breakPoints = Utils.calculateBreakPoints(similarities, config.getPercentile()); - return Utils.generateChunks(sentences, breakPoints); + List sentences = splitSentences(content, config.getSentenceSplitingStrategy()); + sentences = combineSentences(sentences, config.getBufferSize()); + sentences = embedSentences(embeddingModel, sentences); + List similarities = calculateSimilarities(sentences); + List breakPoints = calculateBreakPoints(similarities, config.getPercentile()); + return generateChunks(sentences, breakPoints); + } + + /** + * Split the content into sentences + * @param content the content to split + * @return the list of sentences + */ + @VisibleForTesting + private List splitSentences(String content, SentenceSplitingStrategy splitingStrategy) { + AtomicInteger index = new AtomicInteger(0); + return Arrays.stream(content.split(splitingStrategy.getStrategy())) + .map(sentence -> Sentence.builder().content(sentence).index(index.getAndIncrement()).build()) + .toList(); + } + + /** + * Combine the sentences based on the buffer size (append the buffer size of sentences + * behind and over the current sentence) + *

+ * Use the sliding window technique to reduce the time complexity + * @param sentences the list of sentences + * @param bufferSize the buffer size to use + * @return the list of combined sentences + */ + @VisibleForTesting + private List combineSentences(List sentences, Integer bufferSize) { + assert sentences != null : "The list of sentences cannot be null"; + assert !sentences.isEmpty() : "The list of sentences cannot be empty"; + assert bufferSize != null && bufferSize > 0 : "The buffer size cannot be null nor 0"; + assert bufferSize < sentences.size() : "The buffer size cannot be greater equal than the input length"; + + int n = sentences.size(); + int windowSize = bufferSize * 2 + 1; + int currentWindowSize = 0; + StringBuilder windowBuilder = new StringBuilder(); + + for (int i = 0; i <= Math.min(bufferSize, n - 1); i++) { + windowBuilder.append(sentences.get(i).getContent()).append(" "); + currentWindowSize++; + } + + windowBuilder.deleteCharAt(windowBuilder.length() - 1); + + for (int i = 0; i < n; ++i) { + sentences.get(i).setCombined(windowBuilder.toString()); + + if (currentWindowSize < windowSize && i + bufferSize + 1 < n) { + windowBuilder.append(" ").append(sentences.get(i + bufferSize + 1).getContent()); + currentWindowSize++; + } + else { + windowBuilder.delete(0, sentences.get(i - bufferSize).getContent().length() + 1); + if (i + bufferSize + 1 < n) { + windowBuilder.append(" ").append(sentences.get(i + bufferSize + 1).getContent()); + } + else { + currentWindowSize--; + } + } + } + + return sentences; + } + + /** + * Embed the sentences using the embedding model + * @param sentences the list of sentences + * @return the list of sentences with the embeddings + */ + @VisibleForTesting + private List embedSentences(EmbeddingModel embeddingModel, List sentences) { + + List sentencesText = sentences.stream().map(Sentence::getContent).toList(); + + List embeddings = embeddingModel.embed(sentencesText); + + return IntStream.range(0, sentences.size()).mapToObj(i -> { + Sentence sentence = sentences.get(i); + sentence.setEmbedding(embeddings.get(i)); + return sentence; + }).toList(); + } + + /** + * Calculate the similarity between the sentences embeddings + * @param sentence1 the first sentence embedding + * @param sentence2 the second sentence embedding + * @return the cosine similarity between the sentences + */ + @VisibleForTesting + private Double cosineSimilarity(float[] sentence1, float[] sentence2) { + assert sentence1 != null : "The first sentence embedding cannot be null"; + assert sentence2 != null : "The second sentence embedding cannot be null"; + assert sentence1.length == sentence2.length : "The sentence embeddings must have the same size"; + + INDArray arrayA = Nd4j.create(sentence1); + INDArray arrayB = Nd4j.create(sentence2); + + arrayA = arrayA.div(arrayA.norm2Number()); + arrayB = arrayB.div(arrayB.norm2Number()); + + return Nd4j.getBlasWrapper().dot(arrayA, arrayB); + } + + /** + * Calculate the similarity between the sentences embeddings + * @param sentences the list of sentences + * @return the list of similarities (List of double) + */ + @VisibleForTesting + private List calculateSimilarities(List sentences) { + return IntStream.range(0, sentences.size() - 1).parallel().mapToObj(i -> { + Sentence sentence1 = sentences.get(i); + Sentence sentence2 = sentences.get(i + 1); + return cosineSimilarity(sentence1.getEmbedding(), sentence2.getEmbedding()); + }).toList(); + } + + /** + * Calculate the break points indices based on the similarities and the threshold + * @param distances the list of cosine similarities between the sentences + * @return the list of break points indices + */ + @VisibleForTesting + private List calculateBreakPoints(List distances, Integer percentile) { + Assert.isTrue(distances != null, "The list of distances cannot be null"); + + double breakpointDistanceThreshold = calculatePercentile(distances, percentile); + + return IntStream.range(0, distances.size()) + .filter(i -> distances.get(i) >= breakpointDistanceThreshold) + .boxed() + .toList(); + } + + private static Double calculatePercentile(List distances, int percentile) { + Assert.isTrue(distances != null, "The list of distances cannot be null"); + Assert.isTrue(percentile > 0 && percentile < 100, "The percentile must be between 0 and 100"); + + distances = distances.stream().sorted().toList(); + + int rank = (int) Math.ceil(percentile / 100.0 * distances.size()); + return distances.get(rank - 1); + } + + /** + * Generate chunks combining the sentences based on the break points + * @param sentences the list of sentences + * @param breakPoints the list of break points indices + * @return the list of chunks + */ + @VisibleForTesting + private List generateChunks(List sentences, List breakPoints) { + Assert.isTrue(sentences != null, "The list of sentences cannot be null"); + Assert.isTrue(!sentences.isEmpty(), "The list of sentences cannot be empty"); + Assert.isTrue(breakPoints != null, "The list of break points cannot be null"); + + AtomicInteger index = new AtomicInteger(0); + + return IntStream.range(0, breakPoints.size() + 1).mapToObj(i -> { + int start = i == 0 ? 0 : breakPoints.get(i - 1) + 1; + int end = i == breakPoints.size() ? sentences.size() : breakPoints.get(i) + 1; + String content = sentences.subList(start, end) + .stream() + .map(Sentence::getContent) + .collect(Collectors.joining(" ")); + return new Chunk(index.getAndIncrement(), content); + }).toList(); } } diff --git a/jchunk-semantic/src/main/java/jchunk/chunker/semantic/Utils.java b/jchunk-semantic/src/main/java/jchunk/chunker/semantic/Utils.java deleted file mode 100644 index fe373a9..0000000 --- a/jchunk-semantic/src/main/java/jchunk/chunker/semantic/Utils.java +++ /dev/null @@ -1,190 +0,0 @@ -package jchunk.chunker.semantic; - -import jchunk.chunker.core.chunk.Chunk; -import org.nd4j.common.io.Assert; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.springframework.ai.embedding.EmbeddingModel; - -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -/** - * Utility class for semantic chunking. Wraps the different methods needed to chunk a text - * semantically. - * - * @author Pablo Sanchidrian Herrera - */ -public class Utils { - - /** - * Private constructor to hide the implicit public one - */ - private Utils() { - } - - /** - * Split the content into sentences - * @param content the content to split - * @return the list of sentences - */ - public static List splitSentences(String content, SentenceSplitingStrategy splitingStrategy) { - AtomicInteger index = new AtomicInteger(0); - return Arrays.stream(content.split(splitingStrategy.getStrategy())) - .map(sentence -> Sentence.builder().content(sentence).index(index.getAndIncrement()).build()) - .toList(); - } - - /** - * Combine the sentences based on the buffer size (append the buffer size of sentences - * behind and over the current sentence) - *

- * Use the sliding window technique to reduce the time complexity - * @param sentences the list of sentences - * @param bufferSize the buffer size to use - * @return the list of combined sentences - */ - public static List combineSentences(List sentences, Integer bufferSize) { - assert sentences != null : "The list of sentences cannot be null"; - assert !sentences.isEmpty() : "The list of sentences cannot be empty"; - assert bufferSize != null && bufferSize > 0 : "The buffer size cannot be null nor 0"; - assert bufferSize < sentences.size() : "The buffer size cannot be greater equal than the input length"; - - int n = sentences.size(); - int windowSize = bufferSize * 2 + 1; - int currentWindowSize = 0; - StringBuilder windowBuilder = new StringBuilder(); - - for (int i = 0; i <= Math.min(bufferSize, n - 1); i++) { - windowBuilder.append(sentences.get(i).getContent()).append(" "); - currentWindowSize++; - } - - windowBuilder.deleteCharAt(windowBuilder.length() - 1); - - for (int i = 0; i < n; ++i) { - sentences.get(i).setCombined(windowBuilder.toString()); - - if (currentWindowSize < windowSize && i + bufferSize + 1 < n) { - windowBuilder.append(" ").append(sentences.get(i + bufferSize + 1).getContent()); - currentWindowSize++; - } - else { - windowBuilder.delete(0, sentences.get(i - bufferSize).getContent().length() + 1); - if (i + bufferSize + 1 < n) { - windowBuilder.append(" ").append(sentences.get(i + bufferSize + 1).getContent()); - } - else { - currentWindowSize--; - } - } - } - - return sentences; - } - - /** - * Embed the sentences using the embedding model - * @param sentences the list of sentences - * @return the list of sentences with the embeddings - */ - public static List embedSentences(EmbeddingModel embeddingModel, List sentences) { - - List sentencesText = sentences.stream().map(Sentence::getContent).toList(); - - List embeddings = embeddingModel.embed(sentencesText); - - return IntStream.range(0, sentences.size()).mapToObj(i -> { - Sentence sentence = sentences.get(i); - sentence.setEmbedding(embeddings.get(i)); - return sentence; - }).toList(); - } - - /** - * Calculate the similarity between the sentences embeddings - * @param sentence1 the first sentence embedding - * @param sentence2 the second sentence embedding - * @return the cosine similarity between the sentences - */ - public static Double cosineSimilarity(float[] sentence1, float[] sentence2) { - assert sentence1 != null : "The first sentence embedding cannot be null"; - assert sentence2 != null : "The second sentence embedding cannot be null"; - assert sentence1.length == sentence2.length : "The sentence embeddings must have the same size"; - - INDArray arrayA = Nd4j.create(sentence1); - INDArray arrayB = Nd4j.create(sentence2); - - arrayA = arrayA.div(arrayA.norm2Number()); - arrayB = arrayB.div(arrayB.norm2Number()); - - return Nd4j.getBlasWrapper().dot(arrayA, arrayB); - } - - /** - * Calculate the similarity between the sentences embeddings - * @param sentences the list of sentences - * @return the list of similarities (List of double) - */ - public static List calculateSimilarities(List sentences) { - return IntStream.range(0, sentences.size() - 1).parallel().mapToObj(i -> { - Sentence sentence1 = sentences.get(i); - Sentence sentence2 = sentences.get(i + 1); - return cosineSimilarity(sentence1.getEmbedding(), sentence2.getEmbedding()); - }).toList(); - } - - /** - * Calculate the break points indices based on the similarities and the threshold - * @param distances the list of cosine similarities between the sentences - * @return the list of break points indices - */ - public static List calculateBreakPoints(List distances, Integer percentile) { - Assert.isTrue(distances != null, "The list of distances cannot be null"); - - double breakpointDistanceThreshold = calculatePercentile(distances, percentile); - - return IntStream.range(0, distances.size()) - .filter(i -> distances.get(i) >= breakpointDistanceThreshold) - .boxed() - .toList(); - } - - private static Double calculatePercentile(List distances, int percentile) { - Assert.isTrue(distances != null, "The list of distances cannot be null"); - Assert.isTrue(percentile > 0 && percentile < 100, "The percentile must be between 0 and 100"); - - distances = distances.stream().sorted().toList(); - - int rank = (int) Math.ceil(percentile / 100.0 * distances.size()); - return distances.get(rank - 1); - } - - /** - * Generate chunks combining the sentences based on the break points - * @param sentences the list of sentences - * @param breakPoints the list of break points indices - * @return the list of chunks - */ - public static List generateChunks(List sentences, List breakPoints) { - Assert.isTrue(sentences != null, "The list of sentences cannot be null"); - Assert.isTrue(!sentences.isEmpty(), "The list of sentences cannot be empty"); - Assert.isTrue(breakPoints != null, "The list of break points cannot be null"); - - AtomicInteger index = new AtomicInteger(0); - - return IntStream.range(0, breakPoints.size() + 1).mapToObj(i -> { - int start = i == 0 ? 0 : breakPoints.get(i - 1) + 1; - int end = i == breakPoints.size() ? sentences.size() : breakPoints.get(i) + 1; - String content = sentences.subList(start, end) - .stream() - .map(Sentence::getContent) - .collect(Collectors.joining(" ")); - return new Chunk(index.getAndIncrement(), content); - }).toList(); - } - -} diff --git a/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerIT.java b/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerIT.java index b4e7a61..c37b0ce 100644 --- a/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerIT.java +++ b/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerIT.java @@ -1,6 +1,7 @@ package jchunk.chunker.semantic; import jchunk.chunker.core.chunk.Chunk; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springframework.ai.embedding.EmbeddingModel; @@ -14,6 +15,8 @@ import org.springframework.core.io.DefaultResourceLoader; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; import java.util.List; @@ -23,6 +26,14 @@ @Disabled("Only for manual testing purposes.") class SemanticChunkerIT { + private static Method splitSentences; + + private static Method combineSentences; + + private static Method embedSentences; + + private static Method calculateSimilarities; + @Autowired private SemanticChunker semanticChunker; @@ -41,21 +52,48 @@ static String getText(String uri) { } } + static void initPrivateMethods() throws NoSuchMethodException { + splitSentences = SemanticChunker.class.getDeclaredMethod("splitSentences", String.class, + SentenceSplitingStrategy.class); + combineSentences = SemanticChunker.class.getDeclaredMethod("combineSentences", List.class, Integer.class); + embedSentences = SemanticChunker.class.getDeclaredMethod("embedSentences", EmbeddingModel.class, List.class); + calculateSimilarities = SemanticChunker.class.getDeclaredMethod("calculateSimilarities", List.class); + splitSentences.setAccessible(true); + combineSentences.setAccessible(true); + embedSentences.setAccessible(true); + calculateSimilarities.setAccessible(true); + } + + @BeforeAll + static void init() throws NoSuchMethodException { + initPrivateMethods(); + } + @Test void documentContentLoaded() { assertThat(mitContent).isNotBlank(); } @Test - void getSentences() { - List sentences = Utils.splitSentences(mitContent, SentenceSplitingStrategy.DEFAULT); + void getChunks() { + List chunks = this.semanticChunker.split(mitContent); + assertThat(chunks).isNotEmpty(); + } + + @Test + @SuppressWarnings("unchecked") + void getSentences() throws InvocationTargetException, IllegalAccessException { + List sentences = (List) splitSentences.invoke(semanticChunker, mitContent, + SentenceSplitingStrategy.DEFAULT); assertThat(sentences).isNotEmpty().hasSize(317); } @Test - void combineSentences() { - List sentences = Utils.splitSentences(mitContent, SentenceSplitingStrategy.DEFAULT); - List combined = Utils.combineSentences(sentences, 1); + @SuppressWarnings("unchecked") + void combineSentences() throws InvocationTargetException, IllegalAccessException { + List sentences = (List) splitSentences.invoke(semanticChunker, mitContent, + SentenceSplitingStrategy.DEFAULT); + List combined = (List) combineSentences.invoke(semanticChunker, sentences, 1); assertThat(combined).isNotEmpty(); assertThat(combined).hasSize(317); @@ -67,12 +105,15 @@ void combineSentences() { } @Test - void embedChunks() { + @SuppressWarnings("unchecked") + void embedChunks() throws InvocationTargetException, IllegalAccessException { + int EMBEDDING_MODEL_DIMENSION = 384; - List sentences = Utils.splitSentences(mitContent, SentenceSplitingStrategy.DEFAULT); - List combined = Utils.combineSentences(sentences, 1); - List embedded = Utils.embedSentences(embeddingModel, combined); + List sentences = (List) splitSentences.invoke(semanticChunker, mitContent, + SentenceSplitingStrategy.DEFAULT); + List combined = (List) combineSentences.invoke(semanticChunker, sentences, 1); + List embedded = (List) embedSentences.invoke(semanticChunker, embeddingModel, combined); assertThat(embedded).isNotEmpty().hasSize(317); @@ -84,21 +125,17 @@ void embedChunks() { } @Test - void getCosineDistancesArray() { - List sentences = Utils.splitSentences(mitContent, SentenceSplitingStrategy.DEFAULT); - List combined = Utils.combineSentences(sentences, 1); - List embedded = Utils.embedSentences(embeddingModel, combined); - List distances = Utils.calculateSimilarities(embedded); + @SuppressWarnings("unchecked") + void getCosineDistancesArray() throws InvocationTargetException, IllegalAccessException { + List sentences = (List) splitSentences.invoke(semanticChunker, mitContent, + SentenceSplitingStrategy.DEFAULT); + List combined = (List) combineSentences.invoke(semanticChunker, sentences, 1); + List embedded = (List) embedSentences.invoke(semanticChunker, embeddingModel, combined); + List distances = (List) calculateSimilarities.invoke(semanticChunker, embedded); assertThat(distances).hasSize(sentences.size() - 1); } - @Test - void getChunks() { - List chunks = this.semanticChunker.split(mitContent); - assertThat(chunks).isNotEmpty(); - } - @SpringBootConfiguration @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication { diff --git a/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerTest.java b/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerTest.java new file mode 100644 index 0000000..a864c39 --- /dev/null +++ b/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerTest.java @@ -0,0 +1,355 @@ +package jchunk.chunker.semantic; + +import jchunk.chunker.core.chunk.Chunk; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.ai.embedding.EmbeddingModel; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.within; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; + +public class SemanticChunkerTest { + + final static double MARGIN = 0.0001f; + + final EmbeddingModel embeddingModel; + + final SemanticChunker semanticChunker; + + SemanticChunkerTest() { + this.embeddingModel = Mockito.mock(EmbeddingModel.class); + this.semanticChunker = new SemanticChunker(embeddingModel); + } + + // @formatter:off + + @Test + void splitSentenceDefaultStrategyTest() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method splitSentences = SemanticChunker.class.getDeclaredMethod("splitSentences", String.class, SentenceSplitingStrategy.class); + splitSentences.setAccessible(true); + + List expectedResult = List.of( + Sentence.builder().content("This is a test sentence.").build(), + Sentence.builder().content("How are u?").build(), + Sentence.builder().content("I am fine thanks\nI am a test sentence!").build(), + Sentence.builder().content("sure").build() + ); + + String content = "This is a test sentence. How are u? I am fine thanks\nI am a test sentence! sure"; + + List result = (List) splitSentences.invoke(semanticChunker, content, SentenceSplitingStrategy.DEFAULT); + + assertThat(result).isNotNull().hasSize(expectedResult.size()); + + for (int i = 0; i < result.size(); i++) { + assertThat(result.get(i).getContent()).isEqualTo(expectedResult.get(i).getContent()); + } + } + + @Test + @SuppressWarnings("unchecked") + void splitSentenceStrategyTest() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method splitSentences = SemanticChunker.class.getDeclaredMethod("splitSentences", String.class, SentenceSplitingStrategy.class); + splitSentences.setAccessible(true); + + List expectedResult = List.of( + Sentence.builder().content("This is a test sentence. How are u? I am fine thanks").build(), + Sentence.builder().content("I am a test sentence! sure").build() + ); + + String content = "This is a test sentence. How are u? I am fine thanks\nI am a test sentence! sure"; + List result = (List) splitSentences.invoke(semanticChunker, content, SentenceSplitingStrategy.LINE_BREAK); + + assertThat(result).isNotNull().hasSize(expectedResult.size()); + + assertThat(result.get(0).getContent()).isEqualTo(expectedResult.get(0).getContent()); + assertThat(result.get(1).getContent()).isEqualTo(expectedResult.get(1).getContent()); + } + + @Test + @SuppressWarnings("unchecked") + void splitSentenceParagraphStrategyTest() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method splitSentences = SemanticChunker.class.getDeclaredMethod("splitSentences", String.class, SentenceSplitingStrategy.class); + splitSentences.setAccessible(true); + + List expectedResult = List.of(Sentence.builder().index(0).content("This is a test sentence.").build(), + Sentence.builder().index(1).content("How are u? I am fine thanks").build(), + Sentence.builder().index(2).content("I am a test sentence!\nsure").build()); + + String content = "This is a test sentence.\n\nHow are u? I am fine thanks\n\nI am a test sentence!\nsure"; + List result = (List) splitSentences.invoke(semanticChunker, content, SentenceSplitingStrategy.PARAGRAPH); + + assertThat(result).isNotNull().hasSize(expectedResult.size()); + + assertThat(result.get(0).getContent()).isEqualTo(expectedResult.get(0).getContent()); + assertThat(result.get(1).getContent()).isEqualTo(expectedResult.get(1).getContent()); + assertThat(result.get(2).getContent()).isEqualTo(expectedResult.get(2).getContent()); + } + + @Test + @SuppressWarnings("unchecked") + void combineSentencesSuccessTest() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method combineSentences = SemanticChunker.class.getDeclaredMethod("combineSentences", List.class, Integer.class); + combineSentences.setAccessible(true); + + Integer bufferSize = 2; + List input = List.of( + Sentence.builder().index(0).content("This").build(), + Sentence.builder().index(1).content("is").build(), Sentence.builder().index(2).content("a").build(), + Sentence.builder().index(3).content("sentence").build(), + Sentence.builder().index(4).content("for").build(), Sentence.builder().index(5).content("you").build(), + Sentence.builder().index(6).content("mate").build() + ); + + List expectedResult = List.of( + Sentence.builder().index(0).content("This").combined("This is a").build(), + Sentence.builder().index(1).content("is").combined("This is a sentence").build(), + Sentence.builder().index(2).content("a").combined("This is a sentence for").build(), + Sentence.builder().index(3).content("sentence").combined("is a sentence for you").build(), + Sentence.builder().index(4).content("for").combined("a sentence for you mate").build(), + Sentence.builder().index(5).content("you").combined("sentence for you mate").build(), + Sentence.builder().index(6).content("mate").combined("for you mate").build() + ); + + List result = (List) combineSentences.invoke(semanticChunker, input, bufferSize); + + assertThat(result).isNotNull(); + assertThat(result.size()).isEqualTo(expectedResult.size()); + + for (int i = 0; i < result.size(); i++) { + assertThat(result.get(i).getIndex()).isEqualTo(expectedResult.get(i).getIndex()); + assertThat(result.get(i).getContent()).isEqualTo(expectedResult.get(i).getContent()); + assertThat(result.get(i).getCombined()).isEqualTo(expectedResult.get(i).getCombined()); + } + } + + @Test + void combineSentencesWithBufferSizeEqualZeroTest() throws NoSuchMethodException { + Method combineSentences = SemanticChunker.class.getDeclaredMethod("combineSentences", List.class, Integer.class); + combineSentences.setAccessible(true); + + Integer bufferSize = 0; + List input = List.of(Sentence.builder().content("This").build()); + + assertThatThrownBy(() -> { + try { combineSentences.invoke(semanticChunker, input, bufferSize); } + catch (InvocationTargetException e) { throw e.getCause(); } + }) + .isInstanceOf(AssertionError.class) + .hasMessage("The buffer size cannot be null nor 0"); + } + + @Test + void combineSentencesWithBufferSizeIsNullTest() throws NoSuchMethodException { + Method combineSentences = SemanticChunker.class.getDeclaredMethod("combineSentences", List.class, Integer.class); + combineSentences.setAccessible(true); + + Integer bufferSize = null; + List input = List.of(Sentence.builder().content("This").build()); + + assertThatThrownBy(() -> { + try { combineSentences.invoke(semanticChunker, input, bufferSize); } + catch (InvocationTargetException e) { throw e.getCause(); } + }).isInstanceOf(AssertionError.class) + .hasMessage("The buffer size cannot be null nor 0"); + } + + @Test + void combineSentencesWithBufferSizeGreaterThanInputLengthTest() throws NoSuchMethodException { + Method combineSentences = SemanticChunker.class.getDeclaredMethod("combineSentences", List.class, Integer.class); + combineSentences.setAccessible(true); + + Integer bufferSize = 1; + List input = List.of(Sentence.builder().content("This").build()); + + assertThatThrownBy(() -> { + try { combineSentences.invoke(semanticChunker, input, bufferSize); } + catch (InvocationTargetException e) { throw e.getCause(); } + }).isInstanceOf(AssertionError.class) + .hasMessage("The buffer size cannot be greater equal than the input length"); + } + + @Test + void combineSentencesWithInputIsNullTest() throws NoSuchMethodException { + Method combineSentences = SemanticChunker.class.getDeclaredMethod("combineSentences", List.class, Integer.class); + combineSentences.setAccessible(true); + + Integer bufferSize = 2; + List input = null; + + assertThatThrownBy(() -> { + try { combineSentences.invoke(semanticChunker, input, bufferSize); } + catch (InvocationTargetException e) { throw e.getCause(); } + }) + .isInstanceOf(AssertionError.class) + .hasMessage("The list of sentences cannot be null"); + } + + @Test + void combineSentencesWithInputIsEmptyTest() throws NoSuchMethodException { + Method combineSentences = SemanticChunker.class.getDeclaredMethod("combineSentences", List.class, Integer.class); + combineSentences.setAccessible(true); + + Integer bufferSize = 2; + List input = List.of(); + + assertThatThrownBy(() -> { + try { combineSentences.invoke(semanticChunker, input, bufferSize); } + catch (InvocationTargetException e) { throw e.getCause(); } + }) + .isInstanceOf(AssertionError.class) + .hasMessage("The list of sentences cannot be empty"); + } + + @Test + @SuppressWarnings("unchecked") + void embedSentencesTest() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method embedSentences = SemanticChunker.class.getDeclaredMethod("embedSentences", EmbeddingModel.class, List.class); + embedSentences.setAccessible(true); + + Mockito.when(embeddingModel.embed(Mockito.anyList())) + .thenReturn(List.of(new float[] { 1.0f, 2.0f, 3.0f }, new float[] { 4.0f, 5.0f, 6.0f })); + + List sentences = List.of( + Sentence.builder().combined("This is a test sentence.").build(), + Sentence.builder().combined("How are u?").build() + ); + + List expectedResult = List.of( + Sentence.builder() + .combined("This is a test sentence.") + .embedding(new float[] { 1.0f, 2.0f, 3.0f }) + .build(), + Sentence.builder().combined("How are u?").embedding(new float[] { 4.0f, 5.0f, 6.0f }).build() + ); + + List result = (List) embedSentences.invoke(semanticChunker, embeddingModel, sentences); + + assertThat(result).isNotNull(); + + for (int i = 0; i < result.size(); i++) { + assertThat(result.get(i).getCombined()).isEqualTo(expectedResult.get(i).getCombined()); + assertThat(result.get(i).getEmbedding()).isEqualTo(expectedResult.get(i).getEmbedding()); + } + } + + @Test + void testIdenticalVectors() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method cosineSimilarity = SemanticChunker.class.getDeclaredMethod("cosineSimilarity", float[].class, float[].class); + cosineSimilarity.setAccessible(true); + + float[] embedding1 = new float[] { 1.0f, 2.0f, 3.0f }; + float[] embedding2 = new float[] { 1.0f, 2.0f, 3.0f }; + + double result = (double) cosineSimilarity.invoke(semanticChunker, embedding1, embedding2); + + assertThat(result).isCloseTo(1.0, within(MARGIN)); + } + + @Test + void testOrthogonalVectors() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method cosineSimilarity = SemanticChunker.class.getDeclaredMethod("cosineSimilarity", float[].class, float[].class); + cosineSimilarity.setAccessible(true); + + float[] embedding1 = new float[] { 1.0f, 0.0f, 0.0f }; + float[] embedding2 = new float[] { 0.0f, 1.0f, 0.0f }; + + double result = (double) cosineSimilarity.invoke(semanticChunker, embedding1, embedding2); + + assertThat(result).isEqualTo(0.0); + } + + @Test + void testOppositeVectors() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method cosineSimilarity = SemanticChunker.class.getDeclaredMethod("cosineSimilarity", float[].class, float[].class); + cosineSimilarity.setAccessible(true); + + float[] embedding1 = new float[] { 1.0f, 2.0f, 3.0f }; + float[] embedding2 = new float[] { -1.0f, -2.0f, -3.0f }; + + double result = (double) cosineSimilarity.invoke(semanticChunker, embedding1, embedding2); + + assertThat(result).isCloseTo(-1.0, within(MARGIN)); + } + + @Test + void testDifferentMagnitudeVectors() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method cosineSimilarity = SemanticChunker.class.getDeclaredMethod("cosineSimilarity", float[].class, float[].class); + cosineSimilarity.setAccessible(true); + + float[] embedding1 = new float[] { 1.0f, 2.0f, 3.0f }; + float[] embedding2 = new float[] { 2.0f, 4.0f, 6.0f }; + + double result = (double) cosineSimilarity.invoke(semanticChunker, embedding1, embedding2); + + assertThat(result).isCloseTo(1.0, within(MARGIN)); + } + + @Test + void testZeroVectors() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method cosineSimilarity = SemanticChunker.class.getDeclaredMethod("cosineSimilarity", float[].class, float[].class); + cosineSimilarity.setAccessible(true); + + float[] embedding1 = new float[] { 0.0f, 0.0f, 0.0f }; + float[] embedding2 = new float[] { 0.0f, 0.0f, 0.0f }; + + double result = (double) cosineSimilarity.invoke(semanticChunker, embedding1, embedding2); + + assertThat(result).isNaN(); + } + + @Test + @SuppressWarnings("unchecked") + void testGetIndicesAboveThreshold() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method calculateBreakPoints = SemanticChunker.class.getDeclaredMethod("calculateBreakPoints", List.class, Integer.class); + calculateBreakPoints.setAccessible(true); + + Integer percentile = 95; + List distances = List.of(10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 65.0, 70.0, 75.0); + + List expectedIndices = List.of(13); + + List actualIndices = (List) calculateBreakPoints.invoke(semanticChunker, distances, percentile); + + assertThat(actualIndices).isEqualTo(expectedIndices); + } + + @Test + @SuppressWarnings("unchecked") + void testGenerateChunks() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Method generateChunks = SemanticChunker.class.getDeclaredMethod("generateChunks", List.class, List.class); + generateChunks.setAccessible(true); + + List sentences = List.of( + Sentence.builder().index(0).content("This").build(), + Sentence.builder().index(1).content("is").build(), Sentence.builder().index(2).content("a").build(), + Sentence.builder().index(3).content("test.").build(), Sentence.builder().index(4).content("We").build(), + Sentence.builder().index(5).content("are").build(), + Sentence.builder().index(6).content("writing").build(), + Sentence.builder().index(7).content("unit").build(), + Sentence.builder().index(8).content("tests.").build() + ); + + List breakPoints = List.of(2, 4, 6); + + List expectedChunks = List.of(new Chunk(0, "This is a"), new Chunk(1, "test. We"), + new Chunk(2, "are writing"), new Chunk(3, "unit tests.")); + + List actualChunks = (List) generateChunks.invoke(semanticChunker, sentences, breakPoints); + + assertThat(actualChunks).isNotNull().hasSize(expectedChunks.size()); + + for (int i = 0; i < actualChunks.size(); i++) { + assertThat(actualChunks.get(i).id()).isEqualTo(expectedChunks.get(i).id()); + assertThat(actualChunks.get(i).content()).isEqualTo(expectedChunks.get(i).content()); + } + } + + // @formatter:on + +} diff --git a/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerUtilsTest.java b/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerUtilsTest.java deleted file mode 100644 index 49ff4a5..0000000 --- a/jchunk-semantic/src/test/java/jchunk/chunker/semantic/SemanticChunkerUtilsTest.java +++ /dev/null @@ -1,261 +0,0 @@ -package jchunk.chunker.semantic; - -import jchunk.chunker.core.chunk.Chunk; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; -import org.springframework.ai.embedding.EmbeddingModel; - -import java.util.List; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.within; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; - -class SemanticChunkerUtilsTest { - - final static double MARGIN = 0.0001f; - - final EmbeddingModel embeddingModel; - - SemanticChunkerUtilsTest() { - this.embeddingModel = Mockito.mock(EmbeddingModel.class); - } - - @Test - void splitSentenceDefaultStrategyTest() { - List expectedResult = List.of(Sentence.builder().content("This is a test sentence.").build(), - Sentence.builder().content("How are u?").build(), - Sentence.builder().content("I am fine thanks\nI am a test sentence!").build(), - Sentence.builder().content("sure").build()); - - String content = "This is a test sentence. How are u? I am fine thanks\nI am a test sentence! sure"; - List result = Utils.splitSentences(content, SentenceSplitingStrategy.DEFAULT); - - assertThat(result).isNotNull().hasSize(expectedResult.size()); - - for (int i = 0; i < result.size(); i++) { - assertThat(result.get(i).getContent()).isEqualTo(expectedResult.get(i).getContent()); - } - } - - @Test - void splitSentenceStrategyTest() { - List expectedResult = List.of( - Sentence.builder().content("This is a test sentence. How are u? I am fine thanks").build(), - Sentence.builder().content("I am a test sentence! sure").build()); - - String content = "This is a test sentence. How are u? I am fine thanks\nI am a test sentence! sure"; - List result = Utils.splitSentences(content, SentenceSplitingStrategy.LINE_BREAK); - - assertThat(result).isNotNull().hasSize(expectedResult.size()); - - assertThat(result.get(0).getContent()).isEqualTo(expectedResult.get(0).getContent()); - assertThat(result.get(1).getContent()).isEqualTo(expectedResult.get(1).getContent()); - } - - @Test - void splitSentenceParagraphStrategyTest() { - List expectedResult = List.of(Sentence.builder().index(0).content("This is a test sentence.").build(), - Sentence.builder().index(1).content("How are u? I am fine thanks").build(), - Sentence.builder().index(2).content("I am a test sentence!\nsure").build()); - - String content = "This is a test sentence.\n\nHow are u? I am fine thanks\n\nI am a test sentence!\nsure"; - List result = Utils.splitSentences(content, SentenceSplitingStrategy.PARAGRAPH); - - assertThat(result).isNotNull().hasSize(expectedResult.size()); - - assertThat(result.get(0).getContent()).isEqualTo(expectedResult.get(0).getContent()); - assertThat(result.get(1).getContent()).isEqualTo(expectedResult.get(1).getContent()); - assertThat(result.get(2).getContent()).isEqualTo(expectedResult.get(2).getContent()); - } - - @Test - void combineSentencesSuccessTest() { - Integer bufferSize = 2; - List input = List.of(Sentence.builder().index(0).content("This").build(), - Sentence.builder().index(1).content("is").build(), Sentence.builder().index(2).content("a").build(), - Sentence.builder().index(3).content("sentence").build(), - Sentence.builder().index(4).content("for").build(), Sentence.builder().index(5).content("you").build(), - Sentence.builder().index(6).content("mate").build()); - - List expectedResult = List.of( - Sentence.builder().index(0).content("This").combined("This is a").build(), - Sentence.builder().index(1).content("is").combined("This is a sentence").build(), - Sentence.builder().index(2).content("a").combined("This is a sentence for").build(), - Sentence.builder().index(3).content("sentence").combined("is a sentence for you").build(), - Sentence.builder().index(4).content("for").combined("a sentence for you mate").build(), - Sentence.builder().index(5).content("you").combined("sentence for you mate").build(), - Sentence.builder().index(6).content("mate").combined("for you mate").build()); - - List result = Utils.combineSentences(input, bufferSize); - - assertThat(result).isNotNull(); - assertThat(result.size()).isEqualTo(expectedResult.size()); - - for (int i = 0; i < result.size(); i++) { - assertThat(result.get(i).getIndex()).isEqualTo(expectedResult.get(i).getIndex()); - assertThat(result.get(i).getContent()).isEqualTo(expectedResult.get(i).getContent()); - assertThat(result.get(i).getCombined()).isEqualTo(expectedResult.get(i).getCombined()); - } - } - - @Test - void combineSentencesWithBufferSizeEqualZeroTest() { - Integer bufferSize = 0; - List input = List.of(Sentence.builder().content("This").build()); - - assertThatThrownBy(() -> Utils.combineSentences(input, bufferSize)).isInstanceOf(AssertionError.class) - .hasMessage("The buffer size cannot be null nor 0"); - } - - @Test - void combineSentencesWithBufferSizeIsNullTest() { - Integer bufferSize = null; - List input = List.of(Sentence.builder().content("This").build()); - - assertThatThrownBy(() -> Utils.combineSentences(input, bufferSize)).isInstanceOf(AssertionError.class) - .hasMessage("The buffer size cannot be null nor 0"); - } - - @Test - void combineSentencesWithBufferSizeGreaterThanInputLengthTest() { - Integer bufferSize = 1; - List input = List.of(Sentence.builder().content("This").build()); - - assertThatThrownBy(() -> Utils.combineSentences(input, bufferSize)).isInstanceOf(AssertionError.class) - .hasMessage("The buffer size cannot be greater equal than the input length"); - } - - @Test - void combineSentencesWithInputIsNullTest() { - Integer bufferSize = 2; - List input = null; - - assertThatThrownBy(() -> Utils.combineSentences(input, bufferSize)).isInstanceOf(AssertionError.class) - .hasMessage("The list of sentences cannot be null"); - } - - @Test - void combineSentencesWithInputIsEmptyTest() { - Integer bufferSize = 2; - List input = List.of(); - - assertThatThrownBy(() -> Utils.combineSentences(input, bufferSize)).isInstanceOf(AssertionError.class) - .hasMessage("The list of sentences cannot be empty"); - } - - @Test - void embedSentencesTest() { - Mockito.when(embeddingModel.embed(Mockito.anyList())) - .thenReturn(List.of(new float[] { 1.0f, 2.0f, 3.0f }, new float[] { 4.0f, 5.0f, 6.0f })); - - List sentences = List.of(Sentence.builder().combined("This is a test sentence.").build(), - Sentence.builder().combined("How are u?").build()); - - List expectedResult = List.of( - Sentence.builder() - .combined("This is a test sentence.") - .embedding(new float[] { 1.0f, 2.0f, 3.0f }) - .build(), - Sentence.builder().combined("How are u?").embedding(new float[] { 4.0f, 5.0f, 6.0f }).build()); - - List result = Utils.embedSentences(embeddingModel, sentences); - - assertThat(result).isNotNull(); - - for (int i = 0; i < result.size(); i++) { - assertThat(result.get(i).getCombined()).isEqualTo(expectedResult.get(i).getCombined()); - assertThat(result.get(i).getEmbedding()).isEqualTo(expectedResult.get(i).getEmbedding()); - } - - } - - @Test - void testIdenticalVectors() { - float[] embedding1 = new float[] { 1.0f, 2.0f, 3.0f }; - float[] embedding2 = new float[] { 1.0f, 2.0f, 3.0f }; - - double result = Utils.cosineSimilarity(embedding1, embedding2); - - assertThat(result).isCloseTo(1.0, within(MARGIN)); - } - - @Test - void testOrthogonalVectors() { - float[] embedding1 = new float[] { 1.0f, 0.0f, 0.0f }; - float[] embedding2 = new float[] { 0.0f, 1.0f, 0.0f }; - - double result = Utils.cosineSimilarity(embedding1, embedding2); - - assertThat(result).isEqualTo(0.0); - } - - @Test - void testOppositeVectors() { - float[] embedding1 = new float[] { 1.0f, 2.0f, 3.0f }; - float[] embedding2 = new float[] { -1.0f, -2.0f, -3.0f }; - - double result = Utils.cosineSimilarity(embedding1, embedding2); - - assertThat(result).isCloseTo(-1.0, within(MARGIN)); - } - - @Test - void testDifferentMagnitudeVectors() { - float[] embedding1 = new float[] { 1.0f, 2.0f, 3.0f }; - float[] embedding2 = new float[] { 2.0f, 4.0f, 6.0f }; - - double result = Utils.cosineSimilarity(embedding1, embedding2); - - assertThat(result).isCloseTo(1.0, within(MARGIN)); - } - - @Test - void testZeroVectors() { - float[] embedding1 = new float[] { 0.0f, 0.0f, 0.0f }; - float[] embedding2 = new float[] { 0.0f, 0.0f, 0.0f }; - - double result = Utils.cosineSimilarity(embedding1, embedding2); - - assertThat(result).isNaN(); - } - - @Test - void testGetIndicesAboveThreshold() { - Integer percentile = 95; - List distances = List.of(10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 65.0, 70.0, - 75.0); - - List expectedIndices = List.of(13); - - List actualIndices = Utils.calculateBreakPoints(distances, percentile); - - assertThat(actualIndices).isEqualTo(expectedIndices); - } - - @Test - void testGenerateChunks() { - List sentences = List.of(Sentence.builder().index(0).content("This").build(), - Sentence.builder().index(1).content("is").build(), Sentence.builder().index(2).content("a").build(), - Sentence.builder().index(3).content("test.").build(), Sentence.builder().index(4).content("We").build(), - Sentence.builder().index(5).content("are").build(), - Sentence.builder().index(6).content("writing").build(), - Sentence.builder().index(7).content("unit").build(), - Sentence.builder().index(8).content("tests.").build()); - - List breakPoints = List.of(2, 4, 6); - - List expectedChunks = List.of(new Chunk(0, "This is a"), new Chunk(1, "test. We"), - new Chunk(2, "are writing"), new Chunk(3, "unit tests.")); - - List actualChunks = Utils.generateChunks(sentences, breakPoints); - - assertThat(actualChunks).isNotNull().hasSize(expectedChunks.size()); - - for (int i = 0; i < actualChunks.size(); i++) { - assertThat(actualChunks.get(i).id()).isEqualTo(expectedChunks.get(i).id()); - assertThat(actualChunks.get(i).content()).isEqualTo(expectedChunks.get(i).content()); - } - } - -}