Skip to content

Commit

Permalink
Enhance TokenCountBatchingStrategy with reserve percentage
Browse files Browse the repository at this point in the history
- Precompute document token counts before batching into List<List<Document>>
- Introduce configurable reserve percentage for max input token count

Resolves #1260
  • Loading branch information
sobychacko authored and markpollack committed Sep 4, 2024
1 parent 3cab5bd commit 73d0b30
Showing 1 changed file with 41 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
package org.springframework.ai.embedding;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.springframework.ai.document.ContentFormatter;
import org.springframework.ai.document.Document;
Expand All @@ -31,7 +33,19 @@
* max input token as the default:
* https://platform.openai.com/docs/guides/embeddings/embedding-models.
*
* This strategy incorporates a reserve percentage to provide a buffer for potential
* overhead or unexpected increases in token count during processing. The actual max input
* token count used is calculated as: actualMaxInputTokenCount =
* originalMaxInputTokenCount * (1 - RESERVE_PERCENTAGE)
*
For example, with the default reserve percentage of 10% (0.1) and the default max input
token count of 8191, the actual max input token count used will be 7372.
*
* The strategy batches documents based on their token counts, ensuring that each batch
* does not exceed the calculated max input token count.
*
* @author Soby Chacko
* @author Mark Pollack
* @since 1.0.0
*/
public class TokenCountBatchingStrategy implements BatchingStrategy {
Expand All @@ -41,6 +55,12 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
*/
private static final int MAX_INPUT_TOKEN_COUNT = 8191;

/**
* The default percentage of tokens to reserve when calculating the actual max input
* token count.
*/
private static final double DEFAULT_TOKEN_COUNT_RESERVE_PERCENTAGE = 0.1;

private final TokenCountEstimator tokenCountEstimator;

private final int maxInputTokenCount;
Expand All @@ -50,27 +70,33 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
private final MetadataMode metadataMode;

/**
 * Creates a strategy using the CL100K_BASE encoding, the default max input token
 * count (8191) and the default reserve percentage (0.1).
 */
public TokenCountBatchingStrategy() {
	this(EncodingType.CL100K_BASE, MAX_INPUT_TOKEN_COUNT, DEFAULT_TOKEN_COUNT_RESERVE_PERCENTAGE);
}

/**
 * Creates a strategy with the given encoding, max input token count and reserve
 * percentage, using the default content formatter and {@link MetadataMode#NONE}.
 * @param encodingType the {@link EncodingType} to be used for token counting
 * @param maxInputTokenCount the initial upper limit for input tokens
 * @param reservePercentage the percentage of tokens to reserve from the max input
 * token count; this creates a buffer for potential token count increases during
 * processing
 */
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage) {
	this(encodingType, maxInputTokenCount, reservePercentage, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
}

/**
 * Creates a fully customised strategy.
 * @param encodingType the {@link EncodingType} to be used for token counting.
 * @param maxInputTokenCount the initial upper limit for input tokens.
 * @param reservePercentage the percentage of tokens to reserve from the max input
 * token count. This creates a buffer for potential token count increases during
 * processing. Must be in the range {@code [0, 1)}.
 * @param contentFormatter the {@link ContentFormatter} to be used for formatting
 * content.
 * @param metadataMode the {@link MetadataMode} to be used for handling metadata.
 * @throws IllegalArgumentException if {@code reservePercentage} is outside {@code [0, 1)}.
 */
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage,
		ContentFormatter contentFormatter, MetadataMode metadataMode) {
	if (reservePercentage < 0 || reservePercentage >= 1) {
		throw new IllegalArgumentException("reservePercentage must be in the range [0, 1), got: " + reservePercentage);
	}
	this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType);
	// Shrink the usable budget by the reserve: e.g. 8191 tokens at a 10% reserve -> 7372.
	this.maxInputTokenCount = (int) Math.round(maxInputTokenCount * (1 - reservePercentage));
	// NOTE(review): field name 'contentFormater' is misspelled; declared outside this view, so kept as-is.
	this.contentFormater = contentFormatter;
	this.metadataMode = metadataMode;
}
Expand All @@ -80,6 +106,7 @@ public List<List<Document>> batch(List<Document> documents) {
// Groups the documents into batches whose estimated token totals stay within maxInputTokenCount.
List<List<Document>> batches = new ArrayList<>();
int currentSize = 0;
List<Document> currentBatch = new ArrayList<>();
// NOTE(review): HashMap does not preserve insertion order, and the second loop below
// iterates keySet() — documents can be reordered arbitrarily across batches.
// Use a LinkedHashMap to keep the caller's document order.
Map<Document, Integer> documentTokens = new HashMap<>();

// First pass: precompute each document's token count and reject any single
// document that alone exceeds the allowed input token budget.
for (Document document : documents) {
int tokenCount = this.tokenCountEstimator
Expand All @@ -88,6 +115,11 @@ public List<List<Document>> batch(List<Document> documents) {
throw new IllegalArgumentException(
"Tokens in a single document exceeds the maximum number of allowed input tokens");
}
documentTokens.put(document, tokenCount);
}

// Second pass: greedily fill batches until adding a document would exceed the budget.
for (Document document : documentTokens.keySet()) {
Integer tokenCount = documentTokens.get(document);
if (currentSize + tokenCount > maxInputTokenCount) {
batches.add(currentBatch);
// NOTE(review): clear() empties the same list instance that was just added to
// 'batches', so the completed batch is lost; reassign currentBatch = new ArrayList<>() instead.
currentBatch.clear();
Expand Down

0 comments on commit 73d0b30

Please sign in to comment.