Document BatchingStrategy and enhance TokenCountBatching
This commit adds comprehensive documentation for the BatchingStrategy
in vector stores and enhances the TokenCountBatchingStrategy class.

Key changes:
- Explain batching necessity due to embedding model thresholds
- Describe BatchingStrategy interface and its purpose
- Detail TokenCountBatchingStrategy default implementation
- Provide guidance on using and customizing batching strategies
- Note pre-configured vector stores with default strategy
- Add new constructor for custom TokenCountEstimator in
  TokenCountBatchingStrategy
- Implement null checks with Spring's Assert utility
- Update docs with new customization options and code examples
sobychacko authored and Mark Pollack committed Sep 28, 2024
1 parent e29d38d commit 4c8a6ee
Showing 2 changed files with 131 additions and 1 deletion.
23 changes: 23 additions & 0 deletions TokenCountBatchingStrategy.java
@@ -27,6 +27,7 @@
import org.springframework.ai.tokenizer.TokenCountEstimator;

import com.knuddels.jtokkit.api.EncodingType;
import org.springframework.util.Assert;

/**
 * Token count based strategy implementation for {@link BatchingStrategy}. Using openai
@@ -96,12 +97,34 @@ public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCo
     */
    public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage,
            ContentFormatter contentFormatter, MetadataMode metadataMode) {
        Assert.notNull(encodingType, "EncodingType must not be null");
        Assert.notNull(contentFormatter, "ContentFormatter must not be null");
        Assert.notNull(metadataMode, "MetadataMode must not be null");
        this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType);
        this.maxInputTokenCount = (int) Math.round(maxInputTokenCount * (1 - reservePercentage));
        this.contentFormater = contentFormatter;
        this.metadataMode = metadataMode;
    }

    /**
     * Constructs a TokenCountBatchingStrategy with the specified parameters.
     * @param tokenCountEstimator the TokenCountEstimator to be used for estimating token
     * counts.
     * @param maxInputTokenCount the initial upper limit for input tokens.
     * @param reservePercentage the percentage of tokens to reserve from the max input
     * token count to create a buffer.
     * @param contentFormatter the ContentFormatter to be used for formatting content.
     * @param metadataMode the MetadataMode to be used for handling metadata.
     */
    public TokenCountBatchingStrategy(TokenCountEstimator tokenCountEstimator, int maxInputTokenCount,
            double reservePercentage, ContentFormatter contentFormatter, MetadataMode metadataMode) {
        Assert.notNull(tokenCountEstimator, "TokenCountEstimator must not be null");
        this.tokenCountEstimator = tokenCountEstimator;
        this.maxInputTokenCount = (int) Math.round(maxInputTokenCount * (1 - reservePercentage));
        this.contentFormater = contentFormatter;
        this.metadataMode = metadataMode;
    }

    @Override
    public List<List<Document>> batch(List<Document> documents) {
        List<List<Document>> batches = new ArrayList<>();
109 changes: 108 additions & 1 deletion spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc
@@ -91,7 +91,114 @@ It will not be initialized for you by default.
You must opt-in, by passing a `boolean` for the appropriate constructor argument or, if using Spring Boot, setting the appropriate `initialize-schema` property to `true` in `application.properties` or `application.yml`.
Check the documentation for the vector store you are using for the specific property name.

== Available Implementations
== Batching Strategy

When working with vector stores, it's often necessary to embed large numbers of documents.
While it might seem straightforward to make a single call to embed all documents at once, this approach can lead to issues.
Embedding models process text as tokens and have a maximum token limit, often referred to as the context window size.
This limit restricts the amount of text that can be processed in a single embedding request.
Attempting to embed too many tokens in one call can result in errors or truncated embeddings.

To address this token limit, Spring AI implements a batching strategy.
This approach breaks down large sets of documents into smaller batches that fit within the embedding model's maximum context window.
Batching not only solves the token limit issue but can also lead to improved performance and more efficient use of API rate limits.

Spring AI provides this functionality through the `BatchingStrategy` interface, which allows for processing documents in sub-batches based on their token counts.

The core `BatchingStrategy` interface is defined as follows:

[source,java]
----
public interface BatchingStrategy {

    List<List<Document>> batch(List<Document> documents);

}
----

This interface defines a single method, `batch`, which takes a list of documents and returns a list of document batches.
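A consumer simply iterates over the returned batches. For illustration (`process` here is a hypothetical stand-in for whatever call embeds one batch in your application):

[source,java]
----
List<List<Document>> batches = batchingStrategy.batch(documents);
for (List<Document> batch : batches) {
    process(batch); // each batch fits within the model's token limit
}
----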

=== Default Implementation

Spring AI provides a default implementation called `TokenCountBatchingStrategy`.
This strategy batches documents based on their token counts, ensuring that each batch does not exceed a calculated maximum input token count.

Key features of `TokenCountBatchingStrategy`:

1. Uses https://platform.openai.com/docs/guides/embeddings/embedding-models[OpenAI's max input token count] (8191) as the default upper limit.
2. Incorporates a reserve percentage (default 10%) to provide a buffer for potential overhead.
3. Calculates the actual max input token count as: `actualMaxInputTokenCount = originalMaxInputTokenCount * (1 - RESERVE_PERCENTAGE)` (with the defaults, `8191 * (1 - 0.10)`, rounded to `7372` tokens per batch).

The strategy estimates the token count for each document, groups them into batches without exceeding the max input token count, and throws an exception if a single document exceeds this limit.
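In outline, the grouping works like the following sketch. This is simplified: the actual implementation also formats each document with the configured `ContentFormatter` and `MetadataMode` before estimating, and reading the text via `getContent()` is an assumption here.

[source,java]
----
// Simplified sketch of the token-count grouping, not the actual implementation.
public List<List<Document>> batch(List<Document> documents) {
    List<List<Document>> batches = new ArrayList<>();
    List<Document> currentBatch = new ArrayList<>();
    int currentTokenCount = 0;
    for (Document document : documents) {
        int tokenCount = this.tokenCountEstimator.estimate(document.getContent());
        if (tokenCount > this.maxInputTokenCount) {
            throw new IllegalArgumentException(
                    "Tokens in a single document exceed the maximum input token count");
        }
        if (currentTokenCount + tokenCount > this.maxInputTokenCount) {
            batches.add(currentBatch);        // close the full batch
            currentBatch = new ArrayList<>(); // and start a new one
            currentTokenCount = 0;
        }
        currentBatch.add(document);
        currentTokenCount += tokenCount;
    }
    if (!currentBatch.isEmpty()) {
        batches.add(currentBatch);
    }
    return batches;
}
----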

You can also customize the `TokenCountBatchingStrategy` to better suit your specific requirements. This can be done by creating a new instance with custom parameters in a Spring Boot `@Configuration` class.

Here's an example of how to create a custom `TokenCountBatchingStrategy` bean:

[source,java]
----
@Configuration
public class EmbeddingConfig {

    @Bean
    public BatchingStrategy customTokenCountBatchingStrategy() {
        return new TokenCountBatchingStrategy(
            EncodingType.CL100K_BASE, // Specify the encoding type
            8000,                     // Set the maximum input token count
            0.1                       // Set the reserve percentage
        );
    }
}
----

In this configuration:

1. `EncodingType.CL100K_BASE`: Specifies the encoding type used for tokenization. This encoding type is used by the `JTokkitTokenCountEstimator` to accurately estimate token counts.
2. `8000`: Sets the maximum input token count. This value should be less than or equal to the maximum context window size of your embedding model.
3. `0.1`: Sets the reserve percentage. This fraction of the maximum input token count is held back as a buffer, so each batch is filled to at most 90% of the limit (`8000 * (1 - 0.1) = 7200` tokens).

By default, this constructor uses `Document.DEFAULT_CONTENT_FORMATTER` for content formatting and `MetadataMode.NONE` for metadata handling. If you need to customize these parameters, you can use the full constructor with additional parameters.
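For example, a sketch of the same bean with those defaults made explicit, using the five-argument constructor:

[source,java]
----
@Bean
public BatchingStrategy customTokenCountBatchingStrategy() {
    return new TokenCountBatchingStrategy(
        EncodingType.CL100K_BASE,           // encoding used for token estimation
        8000,                               // maximum input token count
        0.1,                                // reserve percentage (10% buffer)
        Document.DEFAULT_CONTENT_FORMATTER, // content formatting
        MetadataMode.NONE                   // metadata handling
    );
}
----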

Once defined, this custom `TokenCountBatchingStrategy` bean will be automatically used by the `EmbeddingModel` implementations in your application, replacing the default strategy.

The `TokenCountBatchingStrategy` internally uses a `TokenCountEstimator` (specifically, `JTokkitTokenCountEstimator`) to calculate token counts for efficient batching. This ensures accurate token estimation based on the specified encoding type.


Additionally, `TokenCountBatchingStrategy` provides flexibility by allowing you to pass in your own implementation of the `TokenCountEstimator` interface. This feature enables you to use custom token counting strategies tailored to your specific needs. For example:

[source,java]
----
TokenCountEstimator customEstimator = new YourCustomTokenCountEstimator();
TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy(
    customEstimator,
    8000, // maxInputTokenCount
    0.1,  // reservePercentage
    Document.DEFAULT_CONTENT_FORMATTER,
    MetadataMode.NONE
);
----
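What such an estimator might look like, assuming `estimate(String)` is the method to implement (check the actual `TokenCountEstimator` interface for any additional methods you would also need to provide):

[source,java]
----
// Hypothetical estimator that approximates tokens as roughly four characters each.
public class YourCustomTokenCountEstimator implements TokenCountEstimator {

    @Override
    public int estimate(String text) {
        return text == null ? 0 : (text.length() + 3) / 4; // ceiling of length / 4
    }
}
----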

=== Custom Implementation

While `TokenCountBatchingStrategy` provides a robust default implementation, you can supply a custom batching strategy to fit your specific needs.
Spring Boot's auto-configuration will pick up your bean in place of the default.

To customize the batching strategy, define a `BatchingStrategy` bean in your Spring Boot application:

[source,java]
----
@Configuration
public class EmbeddingConfig {

    @Bean
    public BatchingStrategy customBatchingStrategy() {
        return new CustomBatchingStrategy();
    }
}
----

This custom `BatchingStrategy` will then be automatically used by the `EmbeddingModel` implementations in your application.
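Here, `CustomBatchingStrategy` stands for your own implementation. As a minimal illustration of the contract, a hypothetical strategy that ignores token counts and batches documents in fixed-size groups:

[source,java]
----
// Hypothetical strategy that batches documents in fixed-size groups of ten,
// shown only to illustrate the BatchingStrategy contract.
public class CustomBatchingStrategy implements BatchingStrategy {

    private static final int BATCH_SIZE = 10;

    @Override
    public List<List<Document>> batch(List<Document> documents) {
        List<List<Document>> batches = new ArrayList<>();
        for (int i = 0; i < documents.size(); i += BATCH_SIZE) {
            batches.add(documents.subList(i, Math.min(i + BATCH_SIZE, documents.size())));
        }
        return batches;
    }
}
----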

NOTE: Vector stores supported by Spring AI are configured to use the default `TokenCountBatchingStrategy`.
SAP Hana vector store is not currently configured for batching.

== VectorStore Implementations

These are the available implementations of the `VectorStore` interface:
