diff --git a/models/spring-ai-bedrock/pom.xml b/models/spring-ai-bedrock/pom.xml index c7e7d34007..a07eeb0891 100644 --- a/models/spring-ai-bedrock/pom.xml +++ b/models/spring-ai-bedrock/pom.xml @@ -29,6 +29,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + org.springframework spring-web diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java index b5d62da9db..169c6b3704 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java @@ -31,12 +31,16 @@ import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; /** * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic chat * generative. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockAnthropicChatModel implements ChatModel, StreamingChatModel { @@ -45,6 +49,11 @@ public class BedrockAnthropicChatModel implements ChatModel, StreamingChatModel private final AnthropicChatOptions defaultOptions; + /** + * The retry template used to retry the Bedrock API calls. + */ + private final RetryTemplate retryTemplate; + public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi) { this(chatApi, AnthropicChatOptions.builder() @@ -56,8 +65,18 @@ public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi) { } public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi, AnthropicChatOptions options) { + this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi, AnthropicChatOptions options, + RetryTemplate retryTemplate) { + Assert.notNull(chatApi, "AnthropicChatBedrockApi must not be null"); + Assert.notNull(options, "DefaultOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.anthropicChatApi = chatApi; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } @Override @@ -65,9 +84,11 @@ public ChatResponse call(Prompt prompt) { AnthropicChatRequest request = createRequest(prompt); - AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); + return this.retryTemplate.execute(ctx -> { + AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); - return new ChatResponse(List.of(new Generation(response.completion()))); + return new ChatResponse(List.of(new Generation(response.completion()))); + }); } @Override @@ -75,16 +96,18 @@ public Flux stream(Prompt prompt) { AnthropicChatRequest request = createRequest(prompt); - Flux fluxResponse = this.anthropicChatApi.chatCompletionStream(request); - - return fluxResponse.map(response -> { - String stopReason = response.stopReason() != null ? response.stopReason() : null; - var generation = new Generation(response.completion()); - if (response.amazonBedrockInvocationMetrics() != null) { - generation = generation.withGenerationMetadata( - ChatGenerationMetadata.from(stopReason, response.amazonBedrockInvocationMetrics())); - } - return new ChatResponse(List.of(generation)); + return this.retryTemplate.execute(ctx -> { + Flux fluxResponse = this.anthropicChatApi.chatCompletionStream(request); + + return fluxResponse.map(response -> { + String stopReason = response.stopReason() != null ? response.stopReason() : null; + var generation = new Generation(response.completion()); + if (response.amazonBedrockInvocationMetrics() != null) { + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(stopReason, response.amazonBedrockInvocationMetrics())); + } + return new ChatResponse(List.of(generation)); + }); }); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java index 2ab9364e95..cab602fba2 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java @@ -40,6 +40,9 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; /** @@ -48,6 +51,7 @@ * * @author Ben Middleton * @author Christian Tzolov + * @author Wei Jiang * @since 1.0.0 */ public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel { @@ -56,6 +60,11 @@ public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel private final Anthropic3ChatOptions defaultOptions; + /** + * The retry template used to retry the Bedrock API calls. + */ + private final RetryTemplate retryTemplate; + public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi) { this(chatApi, Anthropic3ChatOptions.builder() @@ -67,8 +76,18 @@ public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi) { } public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) { + this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options, + RetryTemplate retryTemplate) { + Assert.notNull(chatApi, "Anthropic3ChatBedrockApi must not be null"); + Assert.notNull(options, "DefaultOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.anthropicChatApi = chatApi; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } @Override @@ -76,9 +95,11 @@ public ChatResponse call(Prompt prompt) { AnthropicChatRequest request = createRequest(prompt); - AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); + return this.retryTemplate.execute(ctx -> { + AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); - return new ChatResponse(List.of(new Generation(response.content().get(0).text()))); + return new ChatResponse(List.of(new Generation(response.content().get(0).text()))); + }); } @Override @@ -86,25 +107,28 @@ public Flux stream(Prompt prompt) { AnthropicChatRequest request = createRequest(prompt); - Flux fluxResponse = this.anthropicChatApi - .chatCompletionStream(request); + return this.retryTemplate.execute(ctx -> { + Flux fluxResponse = this.anthropicChatApi + .chatCompletionStream(request); - AtomicReference inputTokens = new AtomicReference<>(0); - return fluxResponse.map(response -> { - if (response.type() == StreamingType.MESSAGE_START) { - inputTokens.set(response.message().usage().inputTokens()); - } - String content = response.type() == StreamingType.CONTENT_BLOCK_DELTA ? response.delta().text() : ""; + AtomicReference inputTokens = new AtomicReference<>(0); + return fluxResponse.map(response -> { + if (response.type() == StreamingType.MESSAGE_START) { + inputTokens.set(response.message().usage().inputTokens()); + } + String content = response.type() == StreamingType.CONTENT_BLOCK_DELTA ? response.delta().text() : ""; - var generation = new Generation(content); + var generation = new Generation(content); - if (response.type() == StreamingType.MESSAGE_DELTA) { - generation = generation.withGenerationMetadata(ChatGenerationMetadata - .from(response.delta().stopReason(), new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(), - response.usage().outputTokens()))); - } + if (response.type() == StreamingType.MESSAGE_DELTA) { + generation = generation + .withGenerationMetadata(ChatGenerationMetadata.from(response.delta().stopReason(), + new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(), + response.usage().outputTokens()))); + } - return new ChatResponse(List.of(generation)); + return new ChatResponse(List.of(generation)); + }); }); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java index 456b7566c6..6b32f0acb2 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java @@ -33,10 +33,13 @@ import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockCohereChatModel implements ChatModel, StreamingChatModel { @@ -45,38 +48,61 @@ public class BedrockCohereChatModel implements ChatModel, StreamingChatModel { private final BedrockCohereChatOptions defaultOptions; + /** + * The retry template used to retry the Bedrock API calls. + */ + private final RetryTemplate retryTemplate; + public BedrockCohereChatModel(CohereChatBedrockApi chatApi) { this(chatApi, BedrockCohereChatOptions.builder().build()); } public BedrockCohereChatModel(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options) { + this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockCohereChatModel(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options, + RetryTemplate retryTemplate) { Assert.notNull(chatApi, "CohereChatBedrockApi must not be null"); - Assert.notNull(options, "BedrockCohereChatOptions must not be null"); + Assert.notNull(options, "DefaultOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); this.chatApi = chatApi; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } @Override public ChatResponse call(Prompt prompt) { - CohereChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false)); - List generations = response.generations().stream().map(g -> { - return new Generation(g.text()); - }).toList(); - return new ChatResponse(generations); + CohereChatRequest request = this.createRequest(prompt, false); + + return this.retryTemplate.execute(ctx -> { + CohereChatResponse response = this.chatApi.chatCompletion(request); + + List generations = response.generations().stream().map(g -> { + return new Generation(g.text()); + }).toList(); + + return new ChatResponse(generations); + }); } @Override public Flux stream(Prompt prompt) { - return this.chatApi.chatCompletionStream(this.createRequest(prompt, true)).map(g -> { - if (g.isFinished()) { - String finishReason = g.finishReason().name(); - Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics()); - return new ChatResponse(List - .of(new Generation("").withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage)))); - } - return new ChatResponse(List.of(new Generation(g.text()))); + + CohereChatRequest request = this.createRequest(prompt, true); + + return this.retryTemplate.execute(ctx -> { + return this.chatApi.chatCompletionStream(request).map(g -> { + if (g.isFinished()) { + String finishReason = g.finishReason().name(); + Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics()); + return new ChatResponse(List.of(new Generation("") + .withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage)))); + } + return new ChatResponse(List.of(new Generation(g.text()))); + }); }); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java index 25e3f35b43..15858e9919 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java @@ -28,6 +28,8 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** @@ -36,6 +38,7 @@ * this API. If this change in the future we will add it as metadata. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel { @@ -44,6 +47,11 @@ public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel { private final BedrockCohereEmbeddingOptions defaultOptions; + /** + * The retry template used to retry the Bedrock API calls. + */ + private final RetryTemplate retryTemplate; + // private CohereEmbeddingRequest.InputType inputType = // CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT; @@ -60,10 +68,18 @@ public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedr public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi, BedrockCohereEmbeddingOptions options) { + this(cohereEmbeddingBedrockApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi, + BedrockCohereEmbeddingOptions options, RetryTemplate retryTemplate) { Assert.notNull(cohereEmbeddingBedrockApi, "CohereEmbeddingBedrockApi must not be null"); - Assert.notNull(options, "BedrockCohereEmbeddingOptions must not be null"); + Assert.notNull(options, "DefaultOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.embeddingApi = cohereEmbeddingBedrockApi; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } // /** @@ -104,13 +120,16 @@ public EmbeddingResponse call(EmbeddingRequest request) { var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), optionsToUse.getInputType(), optionsToUse.getTruncate()); - CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest); - var indexCounter = new AtomicInteger(0); - List embeddings = apiResponse.embeddings() - .stream() - .map(e -> new Embedding(e, indexCounter.getAndIncrement())) - .toList(); - return new EmbeddingResponse(embeddings); + + return this.retryTemplate.execute(ctx -> { + CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest); + var indexCounter = new AtomicInteger(0); + List embeddings = apiResponse.embeddings() + .stream() + .map(e -> new Embedding(e, indexCounter.getAndIncrement())) + .toList(); + return new EmbeddingResponse(embeddings); + }); } /** diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java index 883e55b0e0..74f43d26eb 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java @@ -26,12 +26,15 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** * Java {@link ChatModel} for the Bedrock Jurassic2 chat generative model. * * @author Ahmed Yousri + * @author Wei Jiang * @since 1.0.0 */ public class BedrockAi21Jurassic2ChatModel implements ChatModel { @@ -40,13 +43,10 @@ public class BedrockAi21Jurassic2ChatModel implements ChatModel { private final BedrockAi21Jurassic2ChatOptions defaultOptions; - public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi, BedrockAi21Jurassic2ChatOptions options) { - Assert.notNull(chatApi, "Ai21Jurassic2ChatBedrockApi must not be null"); - Assert.notNull(options, "BedrockAi21Jurassic2ChatOptions must not be null"); - - this.chatApi = chatApi; - this.defaultOptions = options; - } + /** + * The retry template used to retry the Bedrock API calls. + */ + private final RetryTemplate retryTemplate; public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi) { this(chatApi, @@ -57,16 +57,34 @@ public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi) { .build()); } + public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi, BedrockAi21Jurassic2ChatOptions options) { + this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi, BedrockAi21Jurassic2ChatOptions options, + RetryTemplate retryTemplate) { + Assert.notNull(chatApi, "Ai21Jurassic2ChatBedrockApi must not be null"); + Assert.notNull(options, "DefaultOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + + this.chatApi = chatApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + @Override public ChatResponse call(Prompt prompt) { var request = createRequest(prompt); - var response = this.chatApi.chatCompletion(request); - return new ChatResponse(response.completions() - .stream() - .map(completion -> new Generation(completion.data().text()) - .withGenerationMetadata(ChatGenerationMetadata.from(completion.finishReason().reason(), null))) - .toList()); + return this.retryTemplate.execute(ctx -> { + var response = this.chatApi.chatCompletion(request); + + return new ChatResponse(response.completions() + .stream() + .map(completion -> new Generation(completion.data().text()) + .withGenerationMetadata(ChatGenerationMetadata.from(completion.finishReason().reason(), null))) + .toList()); + }); } private Ai21Jurassic2ChatRequest createRequest(Prompt prompt) { @@ -104,6 +122,8 @@ public static class Builder { private BedrockAi21Jurassic2ChatOptions options; + private RetryTemplate retryTemplate; + public Builder(Ai21Jurassic2ChatBedrockApi chatApi) { this.chatApi = chatApi; } @@ -113,9 +133,15 @@ public Builder withOptions(BedrockAi21Jurassic2ChatOptions options) { return this; } + public Builder withRetryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + public BedrockAi21Jurassic2ChatModel build() { return new BedrockAi21Jurassic2ChatModel(chatApi, - options != null ? options : BedrockAi21Jurassic2ChatOptions.builder().build()); + options != null ? options : BedrockAi21Jurassic2ChatOptions.builder().build(), + retryTemplate != null ? retryTemplate : RetryUtils.DEFAULT_RETRY_TEMPLATE); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java index 3c0634c534..519c4434b0 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java @@ -32,6 +32,8 @@ import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** @@ -48,17 +50,29 @@ public class BedrockLlamaChatModel implements ChatModel, StreamingChatModel { private final BedrockLlamaChatOptions defaultOptions; + /** + * The retry template used to retry the Bedrock API calls. + */ + private final RetryTemplate retryTemplate; + public BedrockLlamaChatModel(LlamaChatBedrockApi chatApi) { this(chatApi, BedrockLlamaChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build()); } public BedrockLlamaChatModel(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) { + this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockLlamaChatModel(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options, + RetryTemplate retryTemplate) { Assert.notNull(chatApi, "LlamaChatBedrockApi must not be null"); - Assert.notNull(options, "BedrockLlamaChatOptions must not be null"); + Assert.notNull(options, "DefaultOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); this.chatApi = chatApi; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } @Override @@ -66,10 +80,12 @@ public ChatResponse call(Prompt prompt) { var request = createRequest(prompt); - LlamaChatResponse response = this.chatApi.chatCompletion(request); + return this.retryTemplate.execute(ctx -> { + LlamaChatResponse response = this.chatApi.chatCompletion(request); - return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata( - ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response))))); + return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata( + ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response))))); + }); } @Override @@ -77,12 +93,14 @@ public Flux stream(Prompt prompt) { var request = createRequest(prompt); - Flux fluxResponse = this.chatApi.chatCompletionStream(request); + return this.retryTemplate.execute(ctx -> { + Flux fluxResponse = this.chatApi.chatCompletionStream(request); - return fluxResponse.map(response -> { - String stopReason = response.stopReason() != null ? response.stopReason().name() : null; - return new ChatResponse(List.of(new Generation(response.generation()) - .withGenerationMetadata(ChatGenerationMetadata.from(stopReason, extractUsage(response))))); + return fluxResponse.map(response -> { + String stopReason = response.stopReason() != null ? response.stopReason().name() : null; + return new ChatResponse(List.of(new Generation(response.generation()) + .withGenerationMetadata(ChatGenerationMetadata.from(stopReason, extractUsage(response))))); + }); }); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java index 5712101783..dd1d61664d 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java @@ -33,10 +33,13 @@ import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockTitanChatModel implements ChatModel, StreamingChatModel { @@ -45,45 +48,68 @@ public class BedrockTitanChatModel implements ChatModel, StreamingChatModel { private final BedrockTitanChatOptions defaultOptions; + /** + * The retry template used to retry the Bedrock API calls. + */ + private final RetryTemplate retryTemplate; + public BedrockTitanChatModel(TitanChatBedrockApi chatApi) { this(chatApi, BedrockTitanChatOptions.builder().withTemperature(0.8f).build()); } - public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOptions defaultOptions) { - Assert.notNull(chatApi, "ChatApi must not be null"); - Assert.notNull(defaultOptions, "DefaultOptions must not be null"); + public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOptions options) { + this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOptions options, + RetryTemplate retryTemplate) { + Assert.notNull(chatApi, "TitanChatBedrockApi must not be null"); + Assert.notNull(options, "DefaultOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.chatApi = chatApi; - this.defaultOptions = defaultOptions; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; } @Override public ChatResponse call(Prompt prompt) { - TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt)); - List generations = response.results().stream().map(result -> { - return new Generation(result.outputText()); - }).toList(); - return new ChatResponse(generations); + TitanChatRequest request = this.createRequest(prompt); + + return this.retryTemplate.execute(ctx -> { + TitanChatResponse response = this.chatApi.chatCompletion(request); + List generations = response.results().stream().map(result -> { + return new Generation(result.outputText()); + }).toList(); + + return new ChatResponse(generations); + }); } @Override public Flux stream(Prompt prompt) { - return this.chatApi.chatCompletionStream(this.createRequest(prompt)).map(chunk -> { - Generation generation = new Generation(chunk.outputText()); + TitanChatRequest request = this.createRequest(prompt); - if (chunk.amazonBedrockInvocationMetrics() != null) { - String completionReason = chunk.completionReason().name(); - generation = generation.withGenerationMetadata( - ChatGenerationMetadata.from(completionReason, chunk.amazonBedrockInvocationMetrics())); - } - else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount() != null) { - String completionReason = chunk.completionReason().name(); - generation = generation - .withGenerationMetadata(ChatGenerationMetadata.from(completionReason, extractUsage(chunk))); + return this.retryTemplate.execute(ctx -> { + return this.chatApi.chatCompletionStream(request).map(chunk -> { - } - return new ChatResponse(List.of(generation)); + Generation generation = new Generation(chunk.outputText()); + + if (chunk.amazonBedrockInvocationMetrics() != null) { + String completionReason = chunk.completionReason().name(); + generation = generation.withGenerationMetadata( + ChatGenerationMetadata.from(completionReason, chunk.amazonBedrockInvocationMetrics())); + } + else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount() != null) { + String completionReason = chunk.completionReason().name(); + generation = generation + .withGenerationMetadata(ChatGenerationMetadata.from(completionReason, extractUsage(chunk))); + + } + return new ChatResponse(List.of(generation)); + }); }); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java index e3089eec5d..979e4c82b6 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java @@ -31,6 +31,8 @@ import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** @@ -50,6 +52,11 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel { private final TitanEmbeddingBedrockApi embeddingApi; + /** + * The retry template used to retry the Bedrock API calls. + */ + private final RetryTemplate retryTemplate; + public enum InputType { TEXT, IMAGE @@ -62,7 +69,15 @@ public enum InputType { private InputType inputType = InputType.TEXT; public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) { + this(titanEmbeddingBedrockApi, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi, RetryTemplate retryTemplate) { + Assert.notNull(titanEmbeddingBedrockApi, "TitanEmbeddingBedrockApi must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.embeddingApi = titanEmbeddingBedrockApi; + this.retryTemplate = retryTemplate; } /** @@ -87,17 +102,19 @@ public EmbeddingResponse call(EmbeddingRequest request) { "Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); } - List> embeddingList = new ArrayList<>(); - for (String inputContent : request.getInstructions()) { - var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions()); - TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest); - embeddingList.add(response.embedding()); - } - var indexCounter = new AtomicInteger(0); - List embeddings = embeddingList.stream() - .map(e -> new Embedding(e, indexCounter.getAndIncrement())) - .toList(); - return new EmbeddingResponse(embeddings); + return this.retryTemplate.execute(ctx -> { + List> embeddingList = new ArrayList<>(); + for (String inputContent : request.getInstructions()) { + var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions()); + TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest); + embeddingList.add(response.embedding()); + } + var indexCounter = new AtomicInteger(0); + List embeddings = embeddingList.stream() + .map(e -> new Embedding(e, indexCounter.getAndIncrement())) + .toList(); + return new EmbeddingResponse(embeddings); + }); } private TitanEmbeddingRequest createTitanEmbeddingRequest(String inputContent, EmbeddingOptions requestOptions) { diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java index 016ec4306b..d32dcd2259 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java @@ -157,7 +157,7 @@ public enum TitanEmbeddingModel { /** * amazon.titan-embed-text-v2 */ - TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0");; + TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0"); private final String id; diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java index 334efa48ff..3346a941b4 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java @@ -32,7 +32,7 @@ import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; -import static org.assertj.core.api.Assertions.assertThat;; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java index 540a6bd2bf..68d2e4a322 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java @@ -30,7 +30,7 @@ import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse.Generation.FinishReason; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy;; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * @author Christian Tzolov diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java index 3e30324546..c06c0546e7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java @@ -28,6 +28,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; + import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.providers.AwsRegionProvider; @@ -60,8 +62,8 @@ public AnthropicChatBedrockApi anthropicApi(AwsCredentialsProvider credentialsPr @Bean @ConditionalOnBean(AnthropicChatBedrockApi.class) public BedrockAnthropicChatModel anthropicChatModel(AnthropicChatBedrockApi anthropicApi, - BedrockAnthropicChatProperties properties) { - return new BedrockAnthropicChatModel(anthropicApi, properties.getOptions()); + BedrockAnthropicChatProperties properties, RetryTemplate retryTemplate) { + return new BedrockAnthropicChatModel(anthropicApi, properties.getOptions(), retryTemplate); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java index 3e53f026b2..3dfa34dd1a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java @@ -28,6 +28,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; + import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.providers.AwsRegionProvider; @@ -60,8 +62,8 @@ public Anthropic3ChatBedrockApi anthropic3Api(AwsCredentialsProvider credentials @Bean @ConditionalOnBean(Anthropic3ChatBedrockApi.class) public BedrockAnthropic3ChatModel anthropic3ChatModel(Anthropic3ChatBedrockApi anthropicApi, - BedrockAnthropic3ChatProperties properties) { - return new BedrockAnthropic3ChatModel(anthropicApi, properties.getOptions()); + BedrockAnthropic3ChatProperties properties, RetryTemplate retryTemplate) { + return new BedrockAnthropic3ChatModel(anthropicApi, properties.getOptions(), retryTemplate); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java index 896078e5bc..7a4c98c235 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java @@ -28,6 +28,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; + import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.providers.AwsRegionProvider; @@ -58,9 +60,9 @@ public CohereChatBedrockApi cohereChatApi(AwsCredentialsProvider credentialsProv @Bean @ConditionalOnBean(CohereChatBedrockApi.class) public BedrockCohereChatModel cohereChatModel(CohereChatBedrockApi cohereChatApi, - BedrockCohereChatProperties properties) { + BedrockCohereChatProperties properties, RetryTemplate retryTemplate) { - return new BedrockCohereChatModel(cohereChatApi, properties.getOptions()); + return new BedrockCohereChatModel(cohereChatApi, properties.getOptions(), retryTemplate); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java index 82b6292de4..f36f3cceef 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java @@ -31,6 +31,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Embedding Client. @@ -60,9 +61,9 @@ public CohereEmbeddingBedrockApi cohereEmbeddingApi(AwsCredentialsProvider crede @ConditionalOnMissingBean @ConditionalOnBean(CohereEmbeddingBedrockApi.class) public BedrockCohereEmbeddingModel cohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingApi, - BedrockCohereEmbeddingProperties properties) { + BedrockCohereEmbeddingProperties properties, RetryTemplate retryTemplate) { - return new BedrockCohereEmbeddingModel(cohereEmbeddingApi, properties.getOptions()); + return new BedrockCohereEmbeddingModel(cohereEmbeddingApi, properties.getOptions(), retryTemplate); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java index 8ad3c0bb1a..ffb6b7d67d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java @@ -29,6 +29,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; + import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.providers.AwsRegionProvider; @@ -60,10 +62,11 @@ public Ai21Jurassic2ChatBedrockApi ai21Jurassic2ChatBedrockApi(AwsCredentialsPro @Bean @ConditionalOnBean(Ai21Jurassic2ChatBedrockApi.class) public BedrockAi21Jurassic2ChatModel jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi ai21Jurassic2ChatBedrockApi, - BedrockAi21Jurassic2ChatProperties properties) { + BedrockAi21Jurassic2ChatProperties properties, RetryTemplate retryTemplate) { return BedrockAi21Jurassic2ChatModel.builder(ai21Jurassic2ChatBedrockApi) .withOptions(properties.getOptions()) + .withRetryTemplate(retryTemplate) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java index 6e105b8f26..761b63f3fc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java @@ -31,6 +31,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Llama Chat Client. @@ -59,9 +60,10 @@ public LlamaChatBedrockApi llamaApi(AwsCredentialsProvider credentialsProvider, @Bean @ConditionalOnBean(LlamaChatBedrockApi.class) - public BedrockLlamaChatModel llamaChatModel(LlamaChatBedrockApi llamaApi, BedrockLlamaChatProperties properties) { + public BedrockLlamaChatModel llamaChatModel(LlamaChatBedrockApi llamaApi, BedrockLlamaChatProperties properties, + RetryTemplate retryTemplate) { - return new BedrockLlamaChatModel(llamaApi, properties.getOptions()); + return new BedrockLlamaChatModel(llamaApi, properties.getOptions(), retryTemplate); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java index 0115967fe5..8295713bea 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java @@ -28,6 +28,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; + import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.providers.AwsRegionProvider; @@ -57,10 +59,10 @@ public TitanChatBedrockApi titanChatBedrockApi(AwsCredentialsProvider credential @Bean @ConditionalOnBean(TitanChatBedrockApi.class) - public BedrockTitanChatModel titanChatModel(TitanChatBedrockApi titanChatApi, - BedrockTitanChatProperties properties) { + public BedrockTitanChatModel titanChatModel(TitanChatBedrockApi titanChatApi, BedrockTitanChatProperties properties, + RetryTemplate retryTemplate) { - return new BedrockTitanChatModel(titanChatApi, properties.getOptions()); + return new BedrockTitanChatModel(titanChatApi, properties.getOptions(), retryTemplate); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java index b019dc1c6b..89945bb31d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java @@ -31,6 +31,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Titan Embedding Client. @@ -60,9 +61,10 @@ public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(AwsCredentialsProvider @ConditionalOnMissingBean @ConditionalOnBean(TitanEmbeddingBedrockApi.class) public BedrockTitanEmbeddingModel titanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingApi, - BedrockTitanEmbeddingProperties properties) { + BedrockTitanEmbeddingProperties properties, RetryTemplate retryTemplate) { - return new BedrockTitanEmbeddingModel(titanEmbeddingApi).withInputType(properties.getInputType()); + return new BedrockTitanEmbeddingModel(titanEmbeddingApi, retryTemplate) + .withInputType(properties.getInputType()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java index 014ba672b4..501fd4d2e2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java @@ -27,6 +27,7 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; @@ -41,6 +42,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -54,7 +56,8 @@ public class BedrockAnthropicChatAutoConfigurationIT { "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.anthropic.chat.model=" + AnthropicChatModel.CLAUDE_V2.id(), "spring.ai.bedrock.anthropic.chat.options.temperature=0.5") - .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockAnthropicChatAutoConfiguration.class)); private final Message systemMessage = new SystemPromptTemplate(""" You are a helpful AI assistant. Your name is {name}. @@ -106,7 +109,8 @@ public void propertiesTest() { "spring.ai.bedrock.anthropic.chat.model=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.anthropic.chat.options.temperature=0.55") - .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockAnthropicChatAutoConfiguration.class)) .run(context -> { var anthropicChatProperties = context.getBean(BedrockAnthropicChatProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -135,7 +139,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.anthropic.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockAnthropicChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAnthropicChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockAnthropicChatModel.class)).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java index 8cf6d91b07..5d413fa083 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java @@ -27,6 +27,7 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; @@ -41,6 +42,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 1.0.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -54,7 +56,8 @@ public class BedrockAnthropic3ChatAutoConfigurationIT { "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), "spring.ai.bedrock.anthropic3.chat.model=" + AnthropicChatModel.CLAUDE_V3_SONNET.id(), "spring.ai.bedrock.anthropic3.chat.options.temperature=0.5") - .withConfiguration(AutoConfigurations.of(BedrockAnthropic3ChatAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockAnthropic3ChatAutoConfiguration.class)); private final Message systemMessage = new SystemPromptTemplate(""" You are a helpful AI assistant. Your name is {name}. @@ -106,7 +109,8 @@ public void propertiesTest() { "spring.ai.bedrock.anthropic3.chat.model=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.anthropic3.chat.options.temperature=0.55") - .withConfiguration(AutoConfigurations.of(BedrockAnthropic3ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockAnthropic3ChatAutoConfiguration.class)) .run(context -> { var anthropicChatProperties = context.getBean(BedrockAnthropic3ChatProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -135,7 +139,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.anthropic3.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockAnthropic3ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockAnthropic3ChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAnthropic3ChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockAnthropic3ChatModel.class)).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java index 8150305558..b17a6bbc21 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java @@ -28,6 +28,7 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.ReturnLikelihoods; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; @@ -43,6 +44,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -57,7 +59,8 @@ public class BedrockCohereChatAutoConfigurationIT { "spring.ai.bedrock.cohere.chat.model=" + CohereChatModel.COHERE_COMMAND_V14.id(), "spring.ai.bedrock.cohere.chat.options.temperature=0.5", "spring.ai.bedrock.cohere.chat.options.maxTokens=500") - .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)); + .withConfiguration( + AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, BedrockCohereChatAutoConfiguration.class)); private final Message systemMessage = new SystemPromptTemplate(""" You are a helpful AI assistant. Your name is {name}. @@ -115,7 +118,8 @@ public void propertiesTest() { "spring.ai.bedrock.cohere.chat.options.numGenerations=3", "spring.ai.bedrock.cohere.chat.options.truncate=START", "spring.ai.bedrock.cohere.chat.options.maxTokens=123") - .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereChatAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(BedrockCohereChatProperties.class); var aswProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -151,7 +155,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.cohere.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockCohereChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockCohereChatModel.class)).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java index 14d3889551..8709557df9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingModel; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingModel; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest; @@ -33,6 +34,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -47,7 +49,8 @@ public class BedrockCohereEmbeddingAutoConfigurationIT { "spring.ai.bedrock.cohere.embedding.model=" + CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id(), "spring.ai.bedrock.cohere.embedding.options.inputType=SEARCH_DOCUMENT", "spring.ai.bedrock.cohere.embedding.options.truncate=NONE") - .withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereEmbeddingAutoConfiguration.class)); @Test public void singleEmbedding() { @@ -91,7 +94,8 @@ public void propertiesTest() { "spring.ai.bedrock.cohere.embedding.model=MODEL_XYZ", "spring.ai.bedrock.cohere.embedding.options.inputType=CLASSIFICATION", "spring.ai.bedrock.cohere.embedding.options.truncate=START") - .withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereEmbeddingAutoConfiguration.class)) .run(context -> { var properties = context.getBean(BedrockCohereEmbeddingProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -121,7 +125,8 @@ public void embeddingDisabled() { // Explicitly enable the embedding auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.cohere.embedding.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereEmbeddingAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockCohereEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockCohereEmbeddingModel.class)).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java index 5e92411818..1ebb70ebf9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java @@ -28,6 +28,7 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.prompt.Prompt; @@ -56,7 +57,8 @@ public class BedrockLlamaChatAutoConfigurationIT { "spring.ai.bedrock.llama.chat.model=" + LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), "spring.ai.bedrock.llama.chat.options.temperature=0.5", "spring.ai.bedrock.llama.chat.options.maxGenLen=500") - .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)); + .withConfiguration( + AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, BedrockLlamaChatAutoConfiguration.class)); private final Message systemMessage = new SystemPromptTemplate(""" You are a helpful AI assistant. Your name is {name}. @@ -109,7 +111,8 @@ public void propertiesTest() { "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.llama.chat.options.temperature=0.55", "spring.ai.bedrock.llama.chat.options.maxGenLen=123") - .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockLlamaChatAutoConfiguration.class)) .run(context -> { var llamaChatProperties = context.getBean(BedrockLlamaChatProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -138,7 +141,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockLlamaChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockLlamaChatModel.class)).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java index 0783d81494..fb145cec01 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java @@ -27,6 +27,7 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.titan.BedrockTitanChatModel; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; import org.springframework.ai.chat.Generation; @@ -41,6 +42,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -55,7 +57,8 @@ public class BedrockTitanChatAutoConfigurationIT { "spring.ai.bedrock.titan.chat.model=" + TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), "spring.ai.bedrock.titan.chat.options.temperature=0.5", "spring.ai.bedrock.titan.chat.options.maxTokenCount=500") - .withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)); + .withConfiguration( + AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, BedrockTitanChatAutoConfiguration.class)); private final Message systemMessage = new SystemPromptTemplate(""" You are a helpful AI assistant. Your name is {name}. @@ -110,7 +113,8 @@ public void propertiesTest() { "spring.ai.bedrock.titan.chat.options.topP=0.55", "spring.ai.bedrock.titan.chat.options.stopSequences=END1,END2", "spring.ai.bedrock.titan.chat.options.maxTokenCount=123") - .withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanChatAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(BedrockTitanChatProperties.class); var aswProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -142,7 +146,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.titan.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockTitanChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockTitanChatModel.class)).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java index 5a5a2ad4c1..f3777a218c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java @@ -23,6 +23,7 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel; import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel.InputType; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel; @@ -35,6 +36,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -47,7 +49,8 @@ public class BedrockTitanEmbeddingAutoConfigurationIT { "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), "spring.ai.bedrock.titan.embedding.model=" + TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id()) - .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanEmbeddingAutoConfiguration.class)); @Test public void singleTextEmbedding() { @@ -87,7 +90,8 @@ public void propertiesTest() { "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.titan.embedding.model=MODEL_XYZ", "spring.ai.bedrock.titan.embedding.inputType=TEXT") - .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanEmbeddingAutoConfiguration.class)) .run(context -> { var properties = context.getBean(BedrockTitanEmbeddingProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -116,7 +120,8 @@ public void embeddingDisabled() { // Explicitly enable the embedding auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.titan.embedding.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanEmbeddingAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockTitanEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockTitanEmbeddingModel.class)).isNotEmpty();