Skip to content

Commit

Permalink
Add Batching strategy for embedding documents
Browse files Browse the repository at this point in the history
 - When embedding documents, allow batching the documents using some criteria.
 - `BatchingStrategy` interface with a `TokenCountBatchingStrategy` implementation that uses
   the openai max input token size of 8191 as the default.
 - Add a default method in EmbeddingModel to embed document using this new batching strategy.
 - Change `MilvusVectorStore` to make use of this new batching API.
 - Adding unit tests for `TokenCountBatchingStrategy`.
 - Adding openai integration test to call the embed API that uses batching.

Resolves spring-projects#1214

Other vector stores will be updated separately
  • Loading branch information
sobychacko authored and Stuart Charlton committed Aug 21, 2024
1 parent 97fb808 commit e44be0b
Show file tree
Hide file tree
Showing 12 changed files with 4,471 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,23 +18,33 @@
import org.junit.jupiter.api.Test;

import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
import org.springframework.ai.openai.OpenAiTestConfiguration;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.testutils.AbstractIT;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;

import java.nio.charset.StandardCharsets;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

@SpringBootTest(classes = OpenAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class EmbeddingIT extends AbstractIT {

private Resource resource = new DefaultResourceLoader().getResource("classpath:text_source.txt");

@Autowired
private OpenAiEmbeddingModel embeddingModel;

Expand All @@ -53,6 +63,28 @@ void defaultEmbedding() {
assertThat(embeddingModel.dimensions()).isEqualTo(1536);
}

@Test
void embeddingBatchDocuments() throws Exception {
	assertThat(this.embeddingModel).isNotNull();

	// Embed three short documents through the batching variant of embed(..).
	List<Document> documents = List.of(new Document("Hello world"), new Document("Hello Spring"),
			new Document("Hello Spring AI!"));
	OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder()
		.withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL)
		.build();

	List<float[]> vectors = this.embeddingModel.embed(documents, options, new TokenCountBatchingStrategy());

	// One vector per document, each with the dimensionality reported by the model.
	assertThat(vectors.size()).isEqualTo(3);
	for (float[] vector : vectors) {
		assertThat(vector.length).isEqualTo(this.embeddingModel.dimensions());
	}
}

@Test
void embeddingBatchDocumentsThatExceedTheLimit() throws Exception {
	assertThat(this.embeddingModel).isNotNull();

	// Content of the text_source.txt fixture; the test expects it to be rejected by
	// the token count based batching strategy.
	String oversizedContent = this.resource.getContentAsString(StandardCharsets.UTF_8);

	assertThatThrownBy(() -> {
		List<Document> documents = List.of(new Document("Hello World"), new Document(oversizedContent));
		OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder()
			.withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL)
			.build();
		this.embeddingModel.embed(documents, options, new TokenCountBatchingStrategy());
	}).isInstanceOf(IllegalArgumentException.class);
}

@Test
void embedding3Large() {

Expand Down
4,124 changes: 4,124 additions & 0 deletions models/spring-ai-openai/src/test/resources/text_source.txt

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.embedding;

import java.util.List;

import org.springframework.ai.document.Document;

/**
 * Contract for batching {@link Document} objects so that the call to embed them could be
 * optimized.
 *
 * @author Soby Chacko
 * @since 1.0.0
 */
public interface BatchingStrategy {

/**
 * {@link EmbeddingModel} implementations can call this method to optimize embedding
 * tokens. The incoming collection of {@link Document}s is split into sub-batches.
 * @param documents the {@link Document}s to batch
 * @return a list of sub-batches, each of which is a list of {@link Document}s
 */
List<List<Document>> batch(List<Document> documents);

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,18 @@
import org.springframework.ai.model.Model;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.List;

/**
* EmbeddingModel is a generic interface for embedding models.
*
* @author Mark Pollack
* @author Christian Tzolov
* @author Josh Long
* @author Soby Chacko
* @since 1.0.0
*
*/
public interface EmbeddingModel extends Model<EmbeddingRequest, EmbeddingResponse> {

Expand Down Expand Up @@ -61,6 +69,35 @@ default List<float[]> embed(List<String> texts) {
.toList();
}

/**
 * Embeds a batch of {@link Document}s into vectors based on a
 * {@link BatchingStrategy}.
 * <p>
 * The documents are split into sub-batches by the given strategy and one
 * {@link EmbeddingRequest} is issued per sub-batch. In addition to being returned,
 * each computed vector is stored on its document via
 * {@code Document#setEmbedding(float[])}.
 * @param documents list of {@link Document}s, must not be {@code null}.
 * @param options {@link EmbeddingOptions} passed along with every request.
 * @param batchingStrategy {@link BatchingStrategy}, must not be {@code null}.
 * @return a list of float[] that represents the vectors for the incoming
 * {@link Document}s, in the order in which the sub-batches list them.
 * @throws IllegalArgumentException if {@code documents} or {@code batchingStrategy}
 * is {@code null}.
 * @throws IllegalStateException if the response for a sub-batch does not contain
 * exactly one result per document of that sub-batch.
 */
default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
	Assert.notNull(documents, "Documents must not be null");
	Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
	List<float[]> embeddings = new ArrayList<>(documents.size());

	for (List<Document> subBatch : batchingStrategy.batch(documents)) {
		List<String> texts = subBatch.stream().map(Document::getContent).toList();
		EmbeddingResponse response = this.call(new EmbeddingRequest(texts, options));
		var results = response.getResults();
		// Verify the pairing up front so that a short response fails with a clear
		// message before any document of this sub-batch has been modified.
		Assert.state(results.size() == subBatch.size(),
				"Embeddings must have the same number as that of the documents");
		for (int i = 0; i < subBatch.size(); i++) {
			float[] output = results.get(i).getOutput();
			embeddings.add(output);
			subBatch.get(i).setEmbedding(output);
		}
	}
	return embeddings;
}

/**
* Embeds a batch of texts into vectors and returns the {@link EmbeddingResponse}.
* @param texts list of texts to embed.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.embedding;

import java.util.ArrayList;
import java.util.List;

import org.springframework.ai.document.ContentFormatter;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator;

import com.knuddels.jtokkit.api.EncodingType;

/**
 * Token count based strategy implementation for {@link BatchingStrategy}. Using openai
 * max input token as the default:
 * https://platform.openai.com/docs/guides/embeddings/embedding-models.
 * <p>
 * Documents are packed, in their incoming order, into sub-batches whose estimated
 * token total stays within the effective limit. The effective limit is the configured
 * maximum reduced by 10%.
 *
 * @author Soby Chacko
 * @since 1.0.0
 */
public class TokenCountBatchingStrategy implements BatchingStrategy {

	/**
	 * Using openai upper limit of input token count as the default.
	 */
	private static final int MAX_INPUT_TOKEN_COUNT = 8191;

	/**
	 * Fraction of the configured maximum that is held back when computing the
	 * effective limit. NOTE(review): presumably headroom for token estimation
	 * differences - confirm the intent.
	 */
	private static final double TOKEN_COUNT_RESERVE_PERCENTAGE = .1;

	private final TokenCountEstimator tokenCountEstimator;

	/** Effective per-batch limit, i.e. the configured maximum minus the reserve. */
	private final int maxInputTokenCount;

	private final ContentFormatter contentFormatter;

	private final MetadataMode metadataMode;

	/**
	 * Creates a strategy using {@link EncodingType#CL100K_BASE} and the default maximum
	 * input token count.
	 */
	public TokenCountBatchingStrategy() {
		this(EncodingType.CL100K_BASE, MAX_INPUT_TOKEN_COUNT);
	}

	/**
	 * Creates a strategy using {@link Document#DEFAULT_CONTENT_FORMATTER} and
	 * {@link MetadataMode#NONE} to format documents for token counting.
	 * @param encodingType {@link EncodingType}
	 * @param maxInputTokenCount upper limit for input tokens
	 */
	public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount) {
		this(encodingType, maxInputTokenCount, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
	}

	/**
	 * @param encodingType {@link EncodingType}
	 * @param maxInputTokenCount upper limit for input tokens; the effective limit used
	 * for batching is this value reduced by 10%
	 * @param contentFormatter {@link ContentFormatter} used to render a document before
	 * its tokens are estimated
	 * @param metadataMode {@link MetadataMode} used when rendering a document
	 */
	public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount,
			ContentFormatter contentFormatter, MetadataMode metadataMode) {
		this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType);
		this.maxInputTokenCount = (int) Math
			.round(maxInputTokenCount - (maxInputTokenCount * TOKEN_COUNT_RESERVE_PERCENTAGE));
		this.contentFormatter = contentFormatter;
		this.metadataMode = metadataMode;
	}

	/**
	 * Splits the documents into consecutive sub-batches whose estimated token totals do
	 * not exceed the effective limit.
	 * @param documents to batch
	 * @return the sub-batches, each backed by its own list; empty if there are no
	 * documents
	 * @throws IllegalArgumentException if a single document alone exceeds the effective
	 * limit
	 */
	@Override
	public List<List<Document>> batch(List<Document> documents) {
		List<List<Document>> batches = new ArrayList<>();
		List<Document> currentBatch = new ArrayList<>();
		int currentSize = 0;

		for (Document document : documents) {
			int tokenCount = this.tokenCountEstimator
				.estimate(document.getFormattedContent(this.contentFormatter, this.metadataMode));
			if (tokenCount > this.maxInputTokenCount) {
				throw new IllegalArgumentException(
						"Tokens in a single document exceeds the maximum number of allowed input tokens");
			}
			if (currentSize + tokenCount > this.maxInputTokenCount) {
				batches.add(currentBatch);
				// Start a new list instead of clearing the one just stored: clearing
				// would also empty the batch that was added to the result.
				currentBatch = new ArrayList<>();
				currentSize = 0;
			}
			currentBatch.add(document);
			currentSize += tokenCount;
		}
		if (!currentBatch.isEmpty()) {
			batches.add(currentBatch);
		}
		return batches;
	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,27 @@
import com.knuddels.jtokkit.api.EncodingType;

import org.springframework.ai.model.Media;
import org.springframework.ai.model.Content;
import org.springframework.ai.model.MediaContent;
import org.springframework.util.CollectionUtils;

/**
* Estimates the number of tokens in a given text or message using the JTokkit encoding
* library.
*
* @author Christian Tzolov
* @author Soby Chacko
* @since 1.0.0
*/
public class JTokkitTokenCountEstimator implements TokenCountEstimator {

private final Encoding estimator;

public JTokkitTokenCountEstimator() {
this.estimator = Encodings.newLazyEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
this(EncodingType.CL100K_BASE);
}

public JTokkitTokenCountEstimator(Encoding tokenEncoding) {
this.estimator = tokenEncoding;
public JTokkitTokenCountEstimator(EncodingType tokenEncodingType) {
this.estimator = Encodings.newLazyEncodingRegistry().getEncoding(tokenEncodingType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.embedding;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;

import org.junit.jupiter.api.Test;

import org.springframework.ai.document.Document;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;

/**
 * Unit tests for the batching behaviour of {@link TokenCountBatchingStrategy}.
 *
 * @author Soby Chacko
 */
public class TokenCountBatchingStrategyTests {

	/**
	 * Three short documents are expected to end up together in a single sub-batch.
	 */
	@Test
	void batchEmbeddingHappyPath() {
		TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy();
		List<Document> documents = List.of(new Document("Hello world"), new Document("Hello Spring"),
				new Document("Hello Spring AI!"));

		List<List<Document>> batches = strategy.batch(documents);

		assertThat(batches).hasSize(1);
		assertThat(batches.get(0)).hasSize(3);
	}

	/**
	 * A document built from the text_source.txt fixture is expected to be rejected
	 * with an {@link IllegalArgumentException}.
	 */
	@Test
	void batchEmbeddingWithLargeDocumentExceedsMaxTokenSize() throws IOException {
		Resource largeText = new DefaultResourceLoader().getResource("classpath:text_source.txt");
		String largeContent = largeText.getContentAsString(StandardCharsets.UTF_8);
		TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy();

		assertThatThrownBy(() -> strategy.batch(List.of(new Document(largeContent))))
			.isInstanceOf(IllegalArgumentException.class);
	}

}
Loading

0 comments on commit e44be0b

Please sign in to comment.