refactor: semantic chunker #26

Merged: 2 commits, Dec 25, 2024
VisibleForTesting.java (new file)
@@ -0,0 +1,16 @@
package jchunk.chunker.core.decorators;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* Annotation to indicate that the visibility of a method, field, constructor, or class
* is increased for testing purposes.
*/
@Retention(RetentionPolicy.CLASS)
@Target({ ElementType.METHOD, ElementType.FIELD, ElementType.CONSTRUCTOR, ElementType.TYPE })
public @interface VisibleForTesting {

}
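For reference, a hypothetical sketch of how the annotation is meant to be used: a helper is widened from private to package-private so tests can reach it, and the annotation records why. The Parser class and its normalize method below are illustrative only, not part of this PR.

import jchunk.chunker.core.decorators.VisibleForTesting;

public class Parser {

	// Package-private instead of private so unit tests in the same package can call it;
	// the annotation documents that the wider visibility exists only for testing.
	@VisibleForTesting
	String normalize(String line) {
		return line == null ? "" : line.trim();
	}

	public String parse(String line) {
		return normalize(line).toLowerCase();
	}

}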
60 changes: 41 additions & 19 deletions jchunk-fixed/src/main/java/jchunk/chunker/fixed/Utils.java
@@ -2,9 +2,7 @@

import jchunk.chunker.core.chunk.Chunk;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
import java.util.regex.Pattern;
@@ -50,7 +48,7 @@ public static List<String> splitIntoSentences(String content, Config config) {
private static List<String> splitWithDelimiter(String content, String delimiter, Config.Delimiter keepDelimiter) {

if (keepDelimiter == Config.Delimiter.NONE) {
return Arrays.stream(content.split(Pattern.quote(delimiter))).filter(s -> !s.isBlank()).toList();
return Arrays.stream(content.split(Pattern.quote(delimiter))).filter(s -> !s.isEmpty()).toList();
}

String withDelimiter = "((?<=%1$s)|(?=%1$s))";
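The lookaround pattern above matches the zero-width positions just before and just after the delimiter, so split() cuts around the delimiter without consuming it and the delimiter survives as its own token. A minimal, self-contained illustration (the demo class and sample input are made up):

import java.util.Arrays;
import java.util.regex.Pattern;

public class LookaroundSplitDemo {

	public static void main(String[] args) {
		String content = "one.two.three";
		String delimiter = ".";
		// Quote the delimiter so regex metacharacters like '.' are treated literally.
		String regex = String.format("((?<=%1$s)|(?=%1$s))", Pattern.quote(delimiter));
		System.out.println(Arrays.toString(content.split(regex)));
		// Prints: [one, ., two, ., three]
	}

}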
@@ -70,11 +68,11 @@ private static List<String> splitWithDelimiterStart(List<String> preSplits) {
List<String> splits = new ArrayList<>();

splits.add(preSplits.getFirst());
IntStream.range(1, preSplits.size())
IntStream.range(1, preSplits.size() - 1)
.filter(i -> i % 2 == 1)
.forEach(i -> splits.add(preSplits.get(i).concat(preSplits.get(i + 1))));

return splits.stream().filter(s -> !s.isBlank()).toList();
return splits.stream().filter(s -> !s.isEmpty()).toList();
}

/**
@@ -91,7 +89,7 @@ private static List<String> splitWithDelimiterEnd(List<String> preSplits) {
.forEach(i -> splits.add(preSplits.get(i).concat(preSplits.get(i + 1))));
splits.add(preSplits.getLast());

return splits.stream().filter(s -> !s.isBlank()).toList();
return splits.stream().filter(s -> !s.isEmpty()).toList();
}

/**
@@ -110,7 +108,7 @@ static List<Chunk> mergeSentences(List<String> sentences, Config config) {
int delimiterLen = delimiter.length();

List<Chunk> chunks = new ArrayList<>();
List<String> currentChunk = new ArrayList<>();
Deque<String> currentChunk = new LinkedList<>();

AtomicInteger chunkIndex = new AtomicInteger(0);

@@ -123,14 +121,8 @@ static List<Chunk> mergeSentences(List<String> sentences, Config config) {
}

if (!currentChunk.isEmpty()) {
String generatedSentence = joinSentences(currentChunk, delimiter, trimWhitespace);
chunks.add(new Chunk(chunkIndex.getAndIncrement(), generatedSentence));

while (currentLen > chunkOverlap
|| (currentLen + sentenceLength + (currentChunk.isEmpty() ? 0 : delimiterLen) > chunkSize
&& currentLen > 0)) {
currentLen -= currentChunk.removeFirst().length() + (currentChunk.isEmpty() ? 0 : delimiterLen);
}
addChunk(chunks, currentChunk, delimiter, trimWhitespace, chunkIndex);
currentLen = adjustCurrentChunkForOverlap(currentChunk, currentLen, chunkOverlap, delimiterLen);
}
}

@@ -139,21 +131,51 @@ static List<Chunk> mergeSentences(List<String> sentences, Config config) {
}

if (!currentChunk.isEmpty()) {
String generatedSentence = joinSentences(currentChunk, config.getDelimiter(), config.getTrimWhitespace());
chunks.add(new Chunk(chunkIndex.getAndIncrement(), generatedSentence));
addChunk(chunks, currentChunk, delimiter, trimWhitespace, chunkIndex);
}

return chunks;
}

/**
* Adds the chunk to the list of chunks.
* @param chunks the list of chunks
* @param currentChunk the current chunk
* @param delimiter the delimiter
* @param trimWhitespace whether to trim the whitespace
* @param index the index of the chunk
*/
private static void addChunk(List<Chunk> chunks, Deque<String> currentChunk, String delimiter,
boolean trimWhitespace, AtomicInteger index) {
String generatedSentence = joinSentences(currentChunk, delimiter, trimWhitespace);
Chunk chunk = Chunk.builder().id(index.getAndIncrement()).content(generatedSentence).build();
chunks.add(chunk);
}

/**
* Adjusts the current chunk for overlap.
* @param currentChunk the current chunk
* @param currentLen the current length of the chunk
* @param chunkOverlap the overlap between chunks
* @param delimiterLen the length of the delimiter
* @return the adjusted length of the chunk
*/
private static int adjustCurrentChunkForOverlap(Deque<String> currentChunk, int currentLen, int chunkOverlap,
int delimiterLen) {
while (currentLen > chunkOverlap && !currentChunk.isEmpty()) {
currentLen -= currentChunk.removeFirst().length() + (currentChunk.isEmpty() ? 0 : delimiterLen);
}
return currentLen;
}
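To make the overlap eviction concrete, here is a standalone sketch of the same loop with made-up numbers: three buffered sentences joined by a one-character delimiter (17 characters in total) and a 10-character overlap budget.

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

public class OverlapEvictionDemo {

	public static void main(String[] args) {
		Deque<String> currentChunk = new ArrayDeque<>(List.of("aaaaaa", "bbbbb", "cccc"));
		int delimiterLen = 1;               // e.g. a single-space delimiter
		int currentLen = 6 + 1 + 5 + 1 + 4; // 17 characters including delimiters
		int chunkOverlap = 10;

		// Drop sentences from the front until the retained tail fits the overlap budget.
		while (currentLen > chunkOverlap && !currentChunk.isEmpty()) {
			currentLen -= currentChunk.removeFirst().length() + (currentChunk.isEmpty() ? 0 : delimiterLen);
		}

		System.out.println(currentChunk + " len=" + currentLen);
		// Prints: [bbbbb, cccc] len=10 -- this tail is carried over into the next chunk
	}

}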

/**
* Joins the sentences into a single sentence.
* @param sentences the sentences to join
* @param delimiter the delimiter to join the sentences
* @param trimWhitespace whether to trim the whitespace
* @return the generated sentence
*/
private static String joinSentences(List<String> sentences, String delimiter, Boolean trimWhitespace) {
private static String joinSentences(Deque<String> sentences, String delimiter, Boolean trimWhitespace) {
String generatedSentence = String.join(delimiter, sentences);
if (trimWhitespace) {
generatedSentence = generatedSentence.trim();
SemanticChunker.java
@@ -2,9 +2,17 @@

import jchunk.chunker.core.chunk.Chunk;
import jchunk.chunker.core.chunk.IChunker;
import jchunk.chunker.core.decorators.VisibleForTesting;
import org.nd4j.common.io.Assert;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.ai.embedding.EmbeddingModel;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
* A semantic chunker that chunks the content based on the semantic meaning
@@ -32,12 +40,180 @@ public SemanticChunker(EmbeddingModel embeddingModel, Config config) {

@Override
public List<Chunk> split(String content) {
List<Sentence> sentences = Utils.splitSentences(content, config.getSentenceSplitingStrategy());
sentences = Utils.combineSentences(sentences, config.getBufferSize());
sentences = Utils.embedSentences(embeddingModel, sentences);
List<Double> similarities = Utils.calculateSimilarities(sentences);
List<Integer> breakPoints = Utils.calculateBreakPoints(similarities, config.getPercentile());
return Utils.generateChunks(sentences, breakPoints);
List<Sentence> sentences = splitSentences(content, config.getSentenceSplitingStrategy());
sentences = combineSentences(sentences, config.getBufferSize());
sentences = embedSentences(embeddingModel, sentences);
List<Double> similarities = calculateSimilarities(sentences);
List<Integer> breakPoints = calculateBreakPoints(similarities, config.getPercentile());
return generateChunks(sentences, breakPoints);
}
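For orientation, the whole pipeline is driven from this single split call. A rough usage sketch, assuming the caller already has an EmbeddingModel and a Config and that the helper below lives in the same package as SemanticChunker (the class and method names here are illustrative):

import jchunk.chunker.core.chunk.Chunk;
import org.springframework.ai.embedding.EmbeddingModel;

import java.util.List;

public class SemanticChunkerUsage {

	// Splits a document into semantically coherent chunks with the chunker from this PR.
	static List<Chunk> chunk(EmbeddingModel embeddingModel, Config config, String document) {
		SemanticChunker chunker = new SemanticChunker(embeddingModel, config);
		return chunker.split(document);
	}

}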

/**
* Split the content into sentences using the given splitting strategy
* @param content the content to split
* @param splitingStrategy the strategy used to split the content into sentences
* @return the list of sentences
*/
@VisibleForTesting
private List<Sentence> splitSentences(String content, SentenceSplitingStrategy splitingStrategy) {
AtomicInteger index = new AtomicInteger(0);
return Arrays.stream(content.split(splitingStrategy.getStrategy()))
.map(sentence -> Sentence.builder().content(sentence).index(index.getAndIncrement()).build())
.toList();
}

/**
* Combine the sentences based on the buffer size (append up to bufferSize sentences
* before and after the current sentence)
* <p>
* Uses a sliding window so the combined text is not rebuilt from scratch for every sentence
* @param sentences the list of sentences
* @param bufferSize the buffer size to use
* @return the list of combined sentences
*/
@VisibleForTesting
private List<Sentence> combineSentences(List<Sentence> sentences, Integer bufferSize) {
assert sentences != null : "The list of sentences cannot be null";
assert !sentences.isEmpty() : "The list of sentences cannot be empty";
assert bufferSize != null && bufferSize > 0 : "The buffer size cannot be null or zero";
assert bufferSize < sentences.size() : "The buffer size cannot be greater than or equal to the input length";

int n = sentences.size();
int windowSize = bufferSize * 2 + 1;
int currentWindowSize = 0;
StringBuilder windowBuilder = new StringBuilder();

for (int i = 0; i <= Math.min(bufferSize, n - 1); i++) {
windowBuilder.append(sentences.get(i).getContent()).append(" ");
currentWindowSize++;
}

windowBuilder.deleteCharAt(windowBuilder.length() - 1);

for (int i = 0; i < n; ++i) {
sentences.get(i).setCombined(windowBuilder.toString());

if (currentWindowSize < windowSize && i + bufferSize + 1 < n) {
windowBuilder.append(" ").append(sentences.get(i + bufferSize + 1).getContent());
currentWindowSize++;
}
else {
windowBuilder.delete(0, sentences.get(i - bufferSize).getContent().length() + 1);
if (i + bufferSize + 1 < n) {
windowBuilder.append(" ").append(sentences.get(i + bufferSize + 1).getContent());
}
else {
currentWindowSize--;
}
}
}

return sentences;
}
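The sliding window above is only an optimisation; the combined strings it is expected to produce can be checked against this simpler (but O(n * bufferSize)) sketch, using four placeholder sentences and bufferSize = 1:

import java.util.ArrayList;
import java.util.List;

public class CombineSentencesDemo {

	public static void main(String[] args) {
		List<String> sentences = List.of("s0", "s1", "s2", "s3");
		int bufferSize = 1;
		List<String> combined = new ArrayList<>();

		for (int i = 0; i < sentences.size(); i++) {
			int from = Math.max(0, i - bufferSize);
			int to = Math.min(sentences.size(), i + bufferSize + 1);
			// Join the current sentence with up to bufferSize neighbours on each side.
			combined.add(String.join(" ", sentences.subList(from, to)));
		}

		System.out.println(combined);
		// Prints: [s0 s1, s0 s1 s2, s1 s2 s3, s2 s3]
	}

}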

/**
* Embed the sentences using the embedding model
* @param embeddingModel the embedding model used to embed the sentences
* @param sentences the list of sentences
* @return the list of sentences with the embeddings
*/
@VisibleForTesting
private List<Sentence> embedSentences(EmbeddingModel embeddingModel, List<Sentence> sentences) {

List<String> sentencesText = sentences.stream().map(Sentence::getContent).toList();

List<float[]> embeddings = embeddingModel.embed(sentencesText);

return IntStream.range(0, sentences.size()).mapToObj(i -> {
Sentence sentence = sentences.get(i);
sentence.setEmbedding(embeddings.get(i));
return sentence;
}).toList();
}

/**
* Calculate the cosine similarity between two sentence embeddings
* @param sentence1 the first sentence embedding
* @param sentence2 the second sentence embedding
* @return the cosine similarity between the sentences
*/
@VisibleForTesting
private Double cosineSimilarity(float[] sentence1, float[] sentence2) {
assert sentence1 != null : "The first sentence embedding cannot be null";
assert sentence2 != null : "The second sentence embedding cannot be null";
assert sentence1.length == sentence2.length : "The sentence embeddings must have the same size";

INDArray arrayA = Nd4j.create(sentence1);
INDArray arrayB = Nd4j.create(sentence2);

arrayA = arrayA.div(arrayA.norm2Number());
arrayB = arrayB.div(arrayB.norm2Number());

return Nd4j.getBlasWrapper().dot(arrayA, arrayB);
}
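The Nd4j calls above normalise both vectors to unit length and take their dot product, i.e. the standard cosine similarity a·b / (||a|| ||b||). A dependency-free sketch of the same arithmetic (the sample vectors are arbitrary):

public class CosineSimilarityDemo {

	static double cosine(float[] a, float[] b) {
		double dot = 0, normA = 0, normB = 0;
		for (int i = 0; i < a.length; i++) {
			dot += a[i] * b[i];
			normA += a[i] * a[i];
			normB += b[i] * b[i];
		}
		return dot / (Math.sqrt(normA) * Math.sqrt(normB));
	}

	public static void main(String[] args) {
		System.out.println(cosine(new float[] { 1f, 0f }, new float[] { 1f, 1f }));
		// Prints roughly 0.7071, the cosine of the 45-degree angle between the vectors
	}

}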

/**
* Calculate the cosine similarity between each pair of consecutive sentence embeddings
* @param sentences the list of sentences
* @return the list of similarities (one value per consecutive pair)
*/
@VisibleForTesting
private List<Double> calculateSimilarities(List<Sentence> sentences) {
return IntStream.range(0, sentences.size() - 1).parallel().mapToObj(i -> {
Sentence sentence1 = sentences.get(i);
Sentence sentence2 = sentences.get(i + 1);
return cosineSimilarity(sentence1.getEmbedding(), sentence2.getEmbedding());
}).toList();
}

/**
* Calculate the break point indices based on the similarities and the percentile threshold
* @param distances the list of cosine similarities between consecutive sentences
* @param percentile the percentile used to derive the break point threshold
* @return the list of break point indices
*/
@VisibleForTesting
private List<Integer> calculateBreakPoints(List<Double> distances, Integer percentile) {
Assert.isTrue(distances != null, "The list of distances cannot be null");

double breakpointDistanceThreshold = calculatePercentile(distances, percentile);

return IntStream.range(0, distances.size())
.filter(i -> distances.get(i) >= breakpointDistanceThreshold)
.boxed()
.toList();
}

private static Double calculatePercentile(List<Double> distances, int percentile) {
Assert.isTrue(distances != null, "The list of distances cannot be null");
Assert.isTrue(percentile > 0 && percentile < 100, "The percentile must be between 0 and 100");

distances = distances.stream().sorted().toList();

int rank = (int) Math.ceil(percentile / 100.0 * distances.size());
return distances.get(rank - 1);
}
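A quick worked example of the nearest-rank percentile used here (numbers made up): with sorted distances [0.1, 0.4, 0.5, 0.8, 0.9] and percentile = 95, rank = ceil(0.95 * 5) = 5, so the threshold is the 5th smallest value, 0.9; calculateBreakPoints then marks every index whose value is at least 0.9 as a break point.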

/**
* Generate chunks combining the sentences based on the break points
* @param sentences the list of sentences
* @param breakPoints the list of break points indices
* @return the list of chunks
*/
@VisibleForTesting
private List<Chunk> generateChunks(List<Sentence> sentences, List<Integer> breakPoints) {
Assert.isTrue(sentences != null, "The list of sentences cannot be null");
Assert.isTrue(!sentences.isEmpty(), "The list of sentences cannot be empty");
Assert.isTrue(breakPoints != null, "The list of break points cannot be null");

AtomicInteger index = new AtomicInteger(0);

return IntStream.range(0, breakPoints.size() + 1).mapToObj(i -> {
int start = i == 0 ? 0 : breakPoints.get(i - 1) + 1;
int end = i == breakPoints.size() ? sentences.size() : breakPoints.get(i) + 1;
String content = sentences.subList(start, end)
.stream()
.map(Sentence::getContent)
.collect(Collectors.joining(" "));
return new Chunk(index.getAndIncrement(), content);
}).toList();
}
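As a concrete example of the partitioning (values made up): with eight sentences and breakPoints = [2, 5], the stream produces three chunks covering sentences 0-2, 3-5, and 6-7, so each break point index marks the last sentence of its chunk.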

}