Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AWS Bedrock adds exponential backoff support. #759

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions models/spring-ai-bedrock/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-retry</artifactId>
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

/**
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic chat
* generative.
*
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
public class BedrockAnthropicChatModel implements ChatModel, StreamingChatModel {
Expand All @@ -45,6 +49,11 @@ public class BedrockAnthropicChatModel implements ChatModel, StreamingChatModel

private final AnthropicChatOptions defaultOptions;

/**
* The retry template used to retry the Bedrock API calls.
*/
private final RetryTemplate retryTemplate;

public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi) {
this(chatApi,
AnthropicChatOptions.builder()
Expand All @@ -56,35 +65,49 @@ public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi) {
}

public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi, AnthropicChatOptions options) {
this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi, AnthropicChatOptions options,
RetryTemplate retryTemplate) {
Assert.notNull(chatApi, "AnthropicChatBedrockApi must not be null");
Assert.notNull(options, "DefaultOptions must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");

this.anthropicChatApi = chatApi;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}

@Override
public ChatResponse call(Prompt prompt) {

AnthropicChatRequest request = createRequest(prompt);

AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);
return this.retryTemplate.execute(ctx -> {
AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);

return new ChatResponse(List.of(new Generation(response.completion())));
return new ChatResponse(List.of(new Generation(response.completion())));
});
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {

AnthropicChatRequest request = createRequest(prompt);

Flux<AnthropicChatResponse> fluxResponse = this.anthropicChatApi.chatCompletionStream(request);

return fluxResponse.map(response -> {
String stopReason = response.stopReason() != null ? response.stopReason() : null;
var generation = new Generation(response.completion());
if (response.amazonBedrockInvocationMetrics() != null) {
generation = generation.withGenerationMetadata(
ChatGenerationMetadata.from(stopReason, response.amazonBedrockInvocationMetrics()));
}
return new ChatResponse(List.of(generation));
return this.retryTemplate.execute(ctx -> {
Flux<AnthropicChatResponse> fluxResponse = this.anthropicChatApi.chatCompletionStream(request);

return fluxResponse.map(response -> {
String stopReason = response.stopReason() != null ? response.stopReason() : null;
var generation = new Generation(response.completion());
if (response.amazonBedrockInvocationMetrics() != null) {
generation = generation.withGenerationMetadata(
ChatGenerationMetadata.from(stopReason, response.amazonBedrockInvocationMetrics()));
}
return new ChatResponse(List.of(generation));
});
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
Expand All @@ -48,6 +51,7 @@
*
* @author Ben Middleton
* @author Christian Tzolov
* @author Wei Jiang
* @since 1.0.0
*/
public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel {
Expand All @@ -56,6 +60,11 @@ public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel

private final Anthropic3ChatOptions defaultOptions;

/**
* The retry template used to retry the Bedrock API calls.
*/
private final RetryTemplate retryTemplate;

public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi) {
this(chatApi,
Anthropic3ChatOptions.builder()
Expand All @@ -67,44 +76,59 @@ public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi) {
}

public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) {
this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options,
RetryTemplate retryTemplate) {
Assert.notNull(chatApi, "Anthropic3ChatBedrockApi must not be null");
Assert.notNull(options, "DefaultOptions must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");

this.anthropicChatApi = chatApi;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}

@Override
public ChatResponse call(Prompt prompt) {

AnthropicChatRequest request = createRequest(prompt);

AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);
return this.retryTemplate.execute(ctx -> {
AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);

return new ChatResponse(List.of(new Generation(response.content().get(0).text())));
return new ChatResponse(List.of(new Generation(response.content().get(0).text())));
});
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {

AnthropicChatRequest request = createRequest(prompt);

Flux<Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse> fluxResponse = this.anthropicChatApi
.chatCompletionStream(request);
return this.retryTemplate.execute(ctx -> {
Flux<Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse> fluxResponse = this.anthropicChatApi
.chatCompletionStream(request);

AtomicReference<Integer> inputTokens = new AtomicReference<>(0);
return fluxResponse.map(response -> {
if (response.type() == StreamingType.MESSAGE_START) {
inputTokens.set(response.message().usage().inputTokens());
}
String content = response.type() == StreamingType.CONTENT_BLOCK_DELTA ? response.delta().text() : "";
AtomicReference<Integer> inputTokens = new AtomicReference<>(0);
return fluxResponse.map(response -> {
if (response.type() == StreamingType.MESSAGE_START) {
inputTokens.set(response.message().usage().inputTokens());
}
String content = response.type() == StreamingType.CONTENT_BLOCK_DELTA ? response.delta().text() : "";

var generation = new Generation(content);
var generation = new Generation(content);

if (response.type() == StreamingType.MESSAGE_DELTA) {
generation = generation.withGenerationMetadata(ChatGenerationMetadata
.from(response.delta().stopReason(), new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(),
response.usage().outputTokens())));
}
if (response.type() == StreamingType.MESSAGE_DELTA) {
generation = generation
.withGenerationMetadata(ChatGenerationMetadata.from(response.delta().stopReason(),
new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(),
response.usage().outputTokens())));
}

return new ChatResponse(List.of(generation));
return new ChatResponse(List.of(generation));
});
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

/**
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
public class BedrockCohereChatModel implements ChatModel, StreamingChatModel {
Expand All @@ -45,38 +48,61 @@ public class BedrockCohereChatModel implements ChatModel, StreamingChatModel {

private final BedrockCohereChatOptions defaultOptions;

/**
* The retry template used to retry the Bedrock API calls.
*/
private final RetryTemplate retryTemplate;

public BedrockCohereChatModel(CohereChatBedrockApi chatApi) {
this(chatApi, BedrockCohereChatOptions.builder().build());
}

public BedrockCohereChatModel(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options) {
this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public BedrockCohereChatModel(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options,
RetryTemplate retryTemplate) {
Assert.notNull(chatApi, "CohereChatBedrockApi must not be null");
Assert.notNull(options, "BedrockCohereChatOptions must not be null");
Assert.notNull(options, "DefaultOptions must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");

this.chatApi = chatApi;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}

@Override
public ChatResponse call(Prompt prompt) {
CohereChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false));
List<Generation> generations = response.generations().stream().map(g -> {
return new Generation(g.text());
}).toList();

return new ChatResponse(generations);
CohereChatRequest request = this.createRequest(prompt, false);

return this.retryTemplate.execute(ctx -> {
CohereChatResponse response = this.chatApi.chatCompletion(request);

List<Generation> generations = response.generations().stream().map(g -> {
return new Generation(g.text());
}).toList();

return new ChatResponse(generations);
});
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return this.chatApi.chatCompletionStream(this.createRequest(prompt, true)).map(g -> {
if (g.isFinished()) {
String finishReason = g.finishReason().name();
Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics());
return new ChatResponse(List
.of(new Generation("").withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage))));
}
return new ChatResponse(List.of(new Generation(g.text())));

CohereChatRequest request = this.createRequest(prompt, true);

return this.retryTemplate.execute(ctx -> {
return this.chatApi.chatCompletionStream(request).map(g -> {
if (g.isFinished()) {
String finishReason = g.finishReason().name();
Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics());
return new ChatResponse(List.of(new Generation("")
.withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage))));
}
return new ChatResponse(List.of(new Generation(g.text())));
});
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

/**
Expand All @@ -36,6 +38,7 @@
* this API. If this change in the future we will add it as metadata.
*
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel {
Expand All @@ -44,6 +47,11 @@ public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel {

private final BedrockCohereEmbeddingOptions defaultOptions;

/**
* The retry template used to retry the Bedrock API calls.
*/
private final RetryTemplate retryTemplate;

// private CohereEmbeddingRequest.InputType inputType =
// CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT;

Expand All @@ -60,10 +68,18 @@ public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedr

public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi,
BedrockCohereEmbeddingOptions options) {
this(cohereEmbeddingBedrockApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi,
BedrockCohereEmbeddingOptions options, RetryTemplate retryTemplate) {
Assert.notNull(cohereEmbeddingBedrockApi, "CohereEmbeddingBedrockApi must not be null");
Assert.notNull(options, "BedrockCohereEmbeddingOptions must not be null");
Assert.notNull(options, "DefaultOptions must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");

this.embeddingApi = cohereEmbeddingBedrockApi;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}

// /**
Expand Down Expand Up @@ -104,13 +120,16 @@ public EmbeddingResponse call(EmbeddingRequest request) {

var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), optionsToUse.getInputType(),
optionsToUse.getTruncate());
CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest);
var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = apiResponse.embeddings()
.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
return new EmbeddingResponse(embeddings);

return this.retryTemplate.execute(ctx -> {
CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest);
var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = apiResponse.embeddings()
.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
return new EmbeddingResponse(embeddings);
});
}

/**
Expand Down
Loading