[ML] Sentence Chunker (elastic#110334)
The Sentence chunker splits long text into smaller chunks on sentence boundaries.
davidkyle authored Jul 2, 2024
1 parent cdbe092 commit 6b64389
Showing 15 changed files with 384 additions and 27 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/110334.yaml
@@ -0,0 +1,5 @@
pr: 110334
summary: Sentence Chunker
area: Machine Learning
type: enhancement
issues: []
@@ -5,7 +5,7 @@
* 2.0.
*/

-package org.elasticsearch.xpack.inference.common;
+package org.elasticsearch.xpack.inference.chunking;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
@@ -0,0 +1,138 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

/**
* Split text into chunks aligned on sentence boundaries.
* The maximum chunk size is measured in words and controlled
* by {@code maxNumberWordsPerChunk}. Sentences are combined
* greedily until adding the next sentence would exceed
* {@code maxNumberWordsPerChunk}, at which point a new chunk
* is created. If an individual sentence is longer than
* {@code maxNumberWordsPerChunk}, it is split on a word boundary
* with overlap.
*/
public class SentenceBoundaryChunker {

private final BreakIterator sentenceIterator;
private final BreakIterator wordIterator;

public SentenceBoundaryChunker() {
sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT);
wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
}

/**
* Break the input text into small chunks on sentence boundaries.
*
* @param input Text to chunk
* @param maxNumberWordsPerChunk Maximum number of words per chunk
* @return The input text chunked
*/
public List<String> chunk(String input, int maxNumberWordsPerChunk) {
var chunks = new ArrayList<String>();

sentenceIterator.setText(input);
wordIterator.setText(input);

int chunkStart = 0;
int chunkEnd = 0;
int sentenceStart = 0;
int chunkWordCount = 0;

int boundary = sentenceIterator.next();

while (boundary != BreakIterator.DONE) {
int sentenceEnd = sentenceIterator.current();
int countWordsInSentence = countWords(sentenceStart, sentenceEnd);

if (chunkWordCount + countWordsInSentence > maxNumberWordsPerChunk) {
// adding this sentence would exceed the max chunk size,
// so close the chunk at the previous sentence boundary

if (chunkWordCount > 0) {
// add a new chunk containing all the input up to this sentence
chunks.add(input.substring(chunkStart, chunkEnd));
chunkStart = chunkEnd;
chunkWordCount = countWordsInSentence; // the next chunk will contain this sentence
}

if (countWordsInSentence > maxNumberWordsPerChunk) {
// This sentence is bigger than the max chunk size.
// Split the sentence on the word boundary
var sentenceSplits = splitLongSentence(
input.substring(chunkStart, sentenceEnd),
maxNumberWordsPerChunk,
overlapForChunkSize(maxNumberWordsPerChunk)
);

int i = 0;
for (; i < sentenceSplits.size() - 1; i++) {
// Because the substring was passed to splitLongSentence()
// the returned positions need to be offset by chunkStart
chunks.add(input.substring(chunkStart + sentenceSplits.get(i).start(), chunkStart + sentenceSplits.get(i).end()));
}
// The final split may be only partially filled.
// Set the next chunk start to the beginning of the
// final split of the long sentence.
chunkStart = chunkStart + sentenceSplits.get(i).start(); // start pos needs to be offset by chunkStart
chunkWordCount = sentenceSplits.get(i).wordCount();
}
} else {
chunkWordCount += countWordsInSentence;
}

sentenceStart = sentenceEnd;
chunkEnd = sentenceEnd;

boundary = sentenceIterator.next();
}

if (chunkWordCount > 0) {
chunks.add(input.substring(chunkStart));
}

return chunks;
}

static List<WordBoundaryChunker.ChunkPosition> splitLongSentence(String text, int maxNumberOfWords, int overlap) {
return new WordBoundaryChunker().chunkPositions(text, maxNumberOfWords, overlap);
}

private int countWords(int start, int end) {
return countWords(start, end, this.wordIterator);
}

// Exposed for testing. setText() must have been called on
// wordIterator before this function is used.
static int countWords(int start, int end, BreakIterator wordIterator) {
assert start < end;
wordIterator.preceding(start); // start of the current word

int boundary = wordIterator.current();
int wordCount = 0;
while (boundary != BreakIterator.DONE && boundary <= end) {
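// getRuleStatus() classifies the text leading up to the current
// boundary; WORD_NONE means whitespace or punctuation, not a word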
int wordStatus = wordIterator.getRuleStatus();
if (wordStatus != BreakIterator.WORD_NONE) {
wordCount++;
}
boundary = wordIterator.next();
}

return wordCount;
}

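// The overlap used when splitting an oversized sentence: just under half
// the chunk size, which always satisfies the overlap <= chunkSize / 2
// validation in WordBoundaryChunker#chunkPositions()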
private static int overlapForChunkSize(int chunkSize) {
return (chunkSize - 1) / 2;
}
}
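A minimal usage sketch for the new chunker (illustrative only, not part of this commit; the example class name and sample text are invented, and the caller is assumed to have access to the org.elasticsearch.xpack.inference.chunking package):

import java.util.List;

public class SentenceChunkerExample {
    public static void main(String[] args) {
        var chunker = new SentenceBoundaryChunker();
        String text = "Sentence one. Sentence two is a little longer. Sentence three.";
        // With a budget of 7 words, no two of these sentences fit together,
        // so each sentence becomes its own chunk (the ICU sentence iterator
        // keeps the trailing space with each sentence).
        List<String> chunks = chunker.chunk(text, 7);
        chunks.forEach(System.out::println);
    }
}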
@@ -5,7 +5,7 @@
* 2.0.
*/

-package org.elasticsearch.xpack.inference.common;
+package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

@@ -32,6 +32,8 @@ public WordBoundaryChunker() {
wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
}

record ChunkPosition(int start, int end, int wordCount) {}

/**
* Break the input text into small chunks as dictated
* by the chunking parameters
@@ -42,6 +44,29 @@ public WordBoundaryChunker() {
* @return List of chunked text
*/
public List<String> chunk(String input, int chunkSize, int overlap) {

if (input.isEmpty()) {
return List.of("");
}

var chunkPositions = chunkPositions(input, chunkSize, overlap);
var chunks = new ArrayList<String>(chunkPositions.size());
for (var pos : chunkPositions) {
chunks.add(input.substring(pos.start(), pos.end()));
}
return chunks;
}

/**
* Chunk using the same strategy as {@link #chunk(String, int, int)}
* but return the chunk start and end offsets in the {@code input} string
* @param input Text to chunk
* @param chunkSize The number of words in each chunk
* @param overlap The number of words to overlap each chunk.
* Can be 0 but must be non-negative.
* @return List of chunked text positions
*/
List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
if (overlap > 0 && overlap > chunkSize / 2) {
throw new IllegalArgumentException(
"Invalid chunking parameters, overlap ["
@@ -59,10 +84,10 @@ public List<String> chunk(String input, int chunkSize, int overlap) {
}

if (input.isEmpty()) {
-return List.of("");
+return List.of();
}

-var chunks = new ArrayList<String>();
+var chunkPositions = new ArrayList<ChunkPosition>();

// This position in the chunk is where the next overlapping chunk will start
final int chunkSizeLessOverlap = chunkSize - overlap;
@@ -81,7 +106,7 @@ public List<String> chunk(String input, int chunkSize, int overlap) {
wordsSinceStartWindowWasMarked++;

if (wordsInChunkCountIncludingOverlap >= chunkSize) {
-chunks.add(input.substring(windowStart, boundary));
+chunkPositions.add(new ChunkPosition(windowStart, boundary, wordsInChunkCountIncludingOverlap));
wordsInChunkCountIncludingOverlap = overlap;

if (overlap == 0) {
@@ -102,10 +127,10 @@ public List<String> chunk(String input, int chunkSize, int overlap) {
// Get the last chunk that was shorter than the required chunk size.
// If it ends on a boundary then the count should equal the overlap, in which case
// we can ignore it, unless this is the first chunk, in which case we want to add it
-if (wordsInChunkCountIncludingOverlap > overlap || chunks.isEmpty()) {
-chunks.add(input.substring(windowStart));
+if (wordsInChunkCountIncludingOverlap > overlap || chunkPositions.isEmpty()) {
+chunkPositions.add(new ChunkPosition(windowStart, input.length(), wordsInChunkCountIncludingOverlap));
}

-return chunks;
+return chunkPositions;
}
}
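For comparison, a sketch of the word-based splitter that splitLongSentence() delegates to (again illustrative only; the class name and sample text are invented). The package-private chunkPositions() variant returns (start, end, wordCount) offsets into the input rather than substrings, which is what lets SentenceBoundaryChunker remap the splits of a long sentence back into the full input string.

import java.util.List;

public class WordChunkerExample {
    public static void main(String[] args) {
        var chunker = new WordBoundaryChunker();
        String text = "a b c d e f g h i j k l";
        // 6-word chunks, each repeating the last 3 words of its predecessor;
        // an overlap greater than chunkSize / 2 would be rejected.
        List<String> chunks = chunker.chunk(text, 6, 3);
        // Roughly: ["a b c d e f", "d e f g h i", "g h i j k l"]
        chunks.forEach(System.out::println);
    }
}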
@@ -24,7 +24,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -28,7 +28,7 @@
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
-import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -23,7 +23,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -23,7 +23,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -23,7 +23,7 @@
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -20,7 +20,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -22,7 +22,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.mistral.MistralActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -24,7 +24,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -5,7 +5,7 @@
* 2.0.
*/

-package org.elasticsearch.xpack.inference.common;
+package org.elasticsearch.xpack.inference.chunking;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;