From 4c8a6ee8b0deec797901f1164cb63dc5340f0b84 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Thu, 26 Sep 2024 14:48:11 -0400 Subject: [PATCH] Document BatchingStrategy and enhance TokenCountBatching This commit adds comprehensive documentation for the BatchingStrategy in vector stores and enhances the TokenCountBatchingStrategy class. Key changes: - Explain batching necessity due to embedding model thresholds - Describe BatchingStrategy interface and its purpose - Detail TokenCountBatchingStrategy default implementation - Provide guidance on using and customizing batching strategies - Note pre-configured vector stores with default strategy - Add new constructor for custom TokenCountEstimator in TokenCountBatchingStrategy - Implement null checks with Spring's Assert utility - Update docs with new customization options and code examples --- .../embedding/TokenCountBatchingStrategy.java | 23 ++++ .../modules/ROOT/pages/api/vectordbs.adoc | 109 +++++++++++++++++- 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java index d790a9de49..2ff2dce0e9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java @@ -27,6 +27,7 @@ import org.springframework.ai.tokenizer.TokenCountEstimator; import com.knuddels.jtokkit.api.EncodingType; +import org.springframework.util.Assert; /** * Token count based strategy implementation for {@link BatchingStrategy}. Using openai @@ -96,12 +97,34 @@ public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCo */ public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage, ContentFormatter contentFormatter, MetadataMode metadataMode) { + Assert.notNull(encodingType, "EncodingType must not be null"); + Assert.notNull(contentFormatter, "ContentFormatter must not be null"); + Assert.notNull(metadataMode, "MetadataMode must not be null"); this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType); this.maxInputTokenCount = (int) Math.round(maxInputTokenCount * (1 - reservePercentage)); this.contentFormater = contentFormatter; this.metadataMode = metadataMode; } + /** + * Constructs a TokenCountBatchingStrategy with the specified parameters. + * @param tokenCountEstimator the TokenCountEstimator to be used for estimating token + * counts. + * @param maxInputTokenCount the initial upper limit for input tokens. + * @param reservePercentage the percentage of tokens to reserve from the max input + * token count to create a buffer. + * @param contentFormatter the ContentFormatter to be used for formatting content. + * @param metadataMode the MetadataMode to be used for handling metadata. + */ + public TokenCountBatchingStrategy(TokenCountEstimator tokenCountEstimator, int maxInputTokenCount, + double reservePercentage, ContentFormatter contentFormatter, MetadataMode metadataMode) { + Assert.notNull(tokenCountEstimator, "TokenCountEstimator must not be null"); + this.tokenCountEstimator = tokenCountEstimator; + this.maxInputTokenCount = (int) Math.round(maxInputTokenCount * (1 - reservePercentage)); + this.contentFormater = contentFormatter; + this.metadataMode = metadataMode; + } + @Override public List> batch(List documents) { List> batches = new ArrayList<>(); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc index 6e7e0e6697..2230f79695 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc @@ -91,7 +91,114 @@ It will not be initialized for you by default. You must opt-in, by passing a `boolean` for the appropriate constructor argument or, if using Spring Boot, setting the appropriate `initialize-schema` property to `true` in `application.properties` or `application.yml`. Check the documentation for the vector store you are using for the specific property name. -== Available Implementations +== Batching Strategy + +When working with vector stores, it's often necessary to embed large numbers of documents. +While it might seem straightforward to make a single call to embed all documents at once, this approach can lead to issues. +Embedding models process text as tokens and have a maximum token limit, often referred to as the context window size. +This limit restricts the amount of text that can be processed in a single embedding request. +Attempting to embed too many tokens in one call can result in errors or truncated embeddings. + +To address this token limit, Spring AI implements a batching strategy. +This approach breaks down large sets of documents into smaller batches that fit within the embedding model's maximum context window. +Batching not only solves the token limit issue but can also lead to improved performance and more efficient use of API rate limits. + +Spring AI provides this functionality through the `BatchingStrategy` interface, which allows for processing documents in sub-batches based on their token counts. + +The core `BatchingStrategy` interface is defined as follows: + +[source,java] +---- +public interface BatchingStrategy { + List> batch(List documents); +} +---- + +This interface defines a single method, `batch`, which takes a list of documents and returns a list of document batches. + +=== Default Implementation + +Spring AI provides a default implementation called `TokenCountBatchingStrategy`. +This strategy batches documents based on their token counts, ensuring that each batch does not exceed a calculated maximum input token count. + +Key features of `TokenCountBatchingStrategy`: + +1. Uses https://platform.openai.com/docs/guides/embeddings/embedding-models[OpenAI's max input token count] (8191) as the default upper limit. +2. Incorporates a reserve percentage (default 10%) to provide a buffer for potential overhead. +3. Calculates the actual max input token count as: `actualMaxInputTokenCount = originalMaxInputTokenCount * (1 - RESERVE_PERCENTAGE)` + +The strategy estimates the token count for each document, groups them into batches without exceeding the max input token count, and throws an exception if a single document exceeds this limit. + +You can also customize the `TokenCountBatchingStrategy` to better suit your specific requirements. This can be done by creating a new instance with custom parameters in a Spring Boot `@Configuration` class. + +Here's an example of how to create a custom `TokenCountBatchingStrategy` bean: + +[source,java] +---- +@Configuration +public class EmbeddingConfig { + @Bean + public BatchingStrategy customTokenCountBatchingStrategy() { + return new TokenCountBatchingStrategy( + EncodingType.CL100K_BASE, // Specify the encoding type + 8000, // Set the maximum input token count + 0.9 // Set the threshold factor + ); + } +} +---- + +In this configuration: + +1. `EncodingType.CL100K_BASE`: Specifies the encoding type used for tokenization. This encoding type is used by the `JTokkitTokenCountEstimator` to accurately estimate token counts. +2. `8000`: Sets the maximum input token count. This value should be less than or equal to the maximum context window size of your embedding model. +3. `0.9`: Sets the threshold factor. This factor determines how full a batch can be before starting a new one. A value of 0.9 means each batch will be filled up to 90% of the maximum input token count. + +By default, this constructor uses `Document.DEFAULT_CONTENT_FORMATTER` for content formatting and `MetadataMode.NONE` for metadata handling. If you need to customize these parameters, you can use the full constructor with additional parameters. + +Once defined, this custom `TokenCountBatchingStrategy` bean will be automatically used by the `EmbeddingModel` implementations in your application, replacing the default strategy. + +The `TokenCountBatchingStrategy` internally uses a `TokenCountEstimator` (specifically, `JTokkitTokenCountEstimator`) to calculate token counts for efficient batching. This ensures accurate token estimation based on the specified encoding type. + + +Additionally, `TokenCountBatchingStrategy` provides flexibility by allowing you to pass in your own implementation of the `TokenCountEstimator` interface. This feature enables you to use custom token counting strategies tailored to your specific needs. For example: + +[source,java] +---- +TokenCountEstimator customEstimator = new YourCustomTokenCountEstimator(); +TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy( + customEstimator, + 8000, // maxInputTokenCount + 0.1, // reservePercentage + Document.DEFAULT_CONTENT_FORMATTER, + MetadataMode.NONE +); +---- + +=== Custom Implementation + +While `TokenCountBatchingStrategy` provides a robust default implementation, you can customize the batching strategy to fit your specific needs. +This can be done through Spring Boot's auto-configuration. + +To customize the batching strategy, define a `BatchingStrategy` bean in your Spring Boot application: + +[source,java] +---- +@Configuration +public class EmbeddingConfig { + @Bean + public BatchingStrategy customBatchingStrategy() { + return new CustomBatchingStrategy(); + } +} +---- + +This custom `BatchingStrategy` will then be automatically used by the `EmbeddingModel` implementations in your application. + +NOTE: Vector stores supported by Spring AI are configured to use the default `TokenCountBatchingStrategy`. +SAP Hana vector store is not currently configured for batching. + +== VectorStore Implementations These are the available implementations of the `VectorStore` interface: