diff --git a/README.md b/README.md index cfe1218382..44c5499dee 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i Spring AI supports many AI models. For an overview see here. Specific models currently supported are * OpenAI * Azure OpenAI -* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2) +* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2, Mistral) * Hugging Face * Google VertexAI (PaLM2, Gemini) * Mistral AI diff --git a/models/spring-ai-bedrock/README.md b/models/spring-ai-bedrock/README.md index 19e48518a6..782af7a853 100644 --- a/models/spring-ai-bedrock/README.md +++ b/models/spring-ai-bedrock/README.md @@ -8,4 +8,5 @@ - [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html) - [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html) - [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html) +- [Mistral Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-mistral.html) diff --git a/models/spring-ai-bedrock/pom.xml b/models/spring-ai-bedrock/pom.xml index e3b79d30bd..ae7913822a 100644 --- a/models/spring-ai-bedrock/pom.xml +++ b/models/spring-ai-bedrock/pom.xml @@ -29,6 +29,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + org.springframework spring-web diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockChatResponseMetadata.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockChatResponseMetadata.java new file mode 100644 index 0000000000..27e745cbcf --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockChatResponseMetadata.java 
@@ -0,0 +1,88 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock; + +import java.util.HashMap; + +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.metadata.Usage; + +import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetrics; + +/** + * {@link ChatResponseMetadata} implementation for {@literal Amazon Bedrock}. 
+ * + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockChatResponseMetadata extends HashMap implements ChatResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, latency: %4$sms}"; + + private final String id; + + private final Usage usage; + + private final Long latencyMs; + + public static BedrockChatResponseMetadata from(ConverseResponse response) { + String requestId = response.responseMetadata().requestId(); + + BedrockUsage usage = BedrockUsage.from(response.usage()); + + ConverseMetrics metrics = response.metrics(); + + return new BedrockChatResponseMetadata(requestId, usage, metrics.latencyMs()); + } + + public static BedrockChatResponseMetadata from(ConverseStreamMetadataEvent converseStreamMetadataEvent) { + BedrockUsage usage = BedrockUsage.from(converseStreamMetadataEvent.usage()); + + ConverseStreamMetrics metrics = converseStreamMetadataEvent.metrics(); + + return new BedrockChatResponseMetadata(null, usage, metrics.latencyMs()); + } + + protected BedrockChatResponseMetadata(String id, BedrockUsage usage, Long latencyMs) { + this.id = id; + this.usage = usage; + this.latencyMs = latencyMs; + } + + public String getId() { + return this.id; + } + + public Long getLatencyMs() { + return latencyMs; + } + + @Override + public Usage getUsage() { + Usage usage = this.usage; + return usage != null ? 
usage : new EmptyUsage(); + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getUsage(), getLatencyMs()); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockConverseChatGenerationMetadata.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockConverseChatGenerationMetadata.java new file mode 100644 index 0000000000..f1088a644c --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockConverseChatGenerationMetadata.java @@ -0,0 +1,86 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock; + +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; + +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent; + +/** + * Amazon Bedrock Chat model converse interface generation metadata, encapsulating + * information on the completion. 
+ * + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockConverseChatGenerationMetadata implements ChatGenerationMetadata { + + private String stopReason; + + private Message message; + + private ConverseStreamOutput event; + + public BedrockConverseChatGenerationMetadata(String stopReason, ConverseStreamOutput event) { + super(); + + this.stopReason = stopReason; + this.event = event; + } + + public BedrockConverseChatGenerationMetadata(String stopReason, Message message) { + super(); + + this.stopReason = stopReason; + this.message = message; + } + + public static BedrockConverseChatGenerationMetadata from(ConverseResponse response, Message message) { + return new BedrockConverseChatGenerationMetadata(response.stopReasonAsString(), message); + } + + public static BedrockConverseChatGenerationMetadata from(ConverseStreamOutput event) { + String stopReason = null; + + if (event instanceof MessageStopEvent messageStopEvent) { + stopReason = messageStopEvent.stopReasonAsString(); + } + + return new BedrockConverseChatGenerationMetadata(stopReason, event); + } + + @Override + public T getContentFilterMetadata() { + return null; + } + + @Override + public String getFinishReason() { + return stopReason; + } + + public Message getMessage() { + return message; + } + + public ConverseStreamOutput getEvent() { + return event; + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java index 6eb31aa28e..6271e12715 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java @@ -19,42 +19,49 @@ import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; +import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage; + /** * {@link Usage} implementation for 
Bedrock API. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockUsage implements Usage { public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) { - return new BedrockUsage(usage); + return new BedrockUsage(usage.inputTokenCount().longValue(), usage.outputTokenCount().longValue()); } - private final AmazonBedrockInvocationMetrics usage; + public static BedrockUsage from(TokenUsage usage) { + Assert.notNull(usage, "'TokenUsage' must not be null."); - protected BedrockUsage(AmazonBedrockInvocationMetrics usage) { - Assert.notNull(usage, "OpenAI Usage must not be null"); - this.usage = usage; + return new BedrockUsage(usage.inputTokens().longValue(), usage.outputTokens().longValue()); } - protected AmazonBedrockInvocationMetrics getUsage() { - return this.usage; + private final Long inputTokens; + + private final Long outputTokens; + + protected BedrockUsage(Long inputTokens, Long outputTokens) { + this.inputTokens = inputTokens; + this.outputTokens = outputTokens; } @Override public Long getPromptTokens() { - return getUsage().inputTokenCount().longValue(); + return inputTokens; } @Override public Long getGenerationTokens() { - return getUsage().outputTokenCount().longValue(); + return outputTokens; } @Override public String toString() { - return getUsage().toString(); + return "BedrockUsage [inputTokens=" + inputTokens + ", outputTokens=" + outputTokens + "]"; } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java index d0d5a5a2ce..57b6f4ccb3 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java @@ -25,7 +25,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; 
/** + * Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html + * * @author Christian Tzolov + * @author Wei Jiang */ @JsonInclude(Include.NON_NULL) public class AnthropicChatOptions implements ChatOptions { @@ -44,7 +48,7 @@ public class AnthropicChatOptions implements ChatOptions { * reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. We * recommend a limit of 4,000 tokens for optimal performance. */ - private @JsonProperty("max_tokens_to_sample") Integer maxTokensToSample; + private @JsonProperty("max_tokens") Integer maxTokens; /** * Specify the number of token choices the generative uses to generate the next token. @@ -62,11 +66,6 @@ public class AnthropicChatOptions implements ChatOptions { * generating further tokens. The returned text doesn't contain the stop sequence. */ private @JsonProperty("stop_sequences") List stopSequences; - - /** - * The version of the generative to use. The default value is bedrock-2023-05-31. 
- */ - private @JsonProperty("anthropic_version") String anthropicVersion; // @formatter:on public static Builder builder() { @@ -82,8 +81,8 @@ public Builder withTemperature(Float temperature) { return this; } - public Builder withMaxTokensToSample(Integer maxTokensToSample) { - this.options.setMaxTokensToSample(maxTokensToSample); + public Builder withMaxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); return this; } @@ -102,11 +101,6 @@ public Builder withStopSequences(List stopSequences) { return this; } - public Builder withAnthropicVersion(String anthropicVersion) { - this.options.setAnthropicVersion(anthropicVersion); - return this; - } - public AnthropicChatOptions build() { return this.options; } @@ -122,12 +116,12 @@ public void setTemperature(Float temperature) { this.temperature = temperature; } - public Integer getMaxTokensToSample() { - return this.maxTokensToSample; + public Integer getMaxTokens() { + return maxTokens; } - public void setMaxTokensToSample(Integer maxTokensToSample) { - this.maxTokensToSample = maxTokensToSample; + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; } @Override @@ -156,21 +150,12 @@ public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } - public String getAnthropicVersion() { - return this.anthropicVersion; - } - - public void setAnthropicVersion(String anthropicVersion) { - this.anthropicVersion = anthropicVersion; - } - public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) { return builder().withTemperature(fromOptions.getTemperature()) - .withMaxTokensToSample(fromOptions.getMaxTokensToSample()) + .withMaxTokens(fromOptions.getMaxTokens()) .withTopK(fromOptions.getTopK()) .withTopP(fromOptions.getTopP()) .withStopSequences(fromOptions.getStopSequences()) - .withAnthropicVersion(fromOptions.getAnthropicVersion()) .build(); } diff --git 
a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java index d6a44d9bf6..7c561349b3 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java @@ -15,105 +15,117 @@ */ package org.springframework.ai.bedrock.anthropic; -import java.util.List; - import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import reactor.core.publisher.Flux; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; -import org.springframework.ai.bedrock.MessageToPromptConverter; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.api.BedrockConverseApiUtils; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.ModelDescription; +import org.springframework.util.Assert; /** * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic chat - * generative. + * generative model. 
* * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockAnthropicChatModel implements ChatModel, StreamingChatModel { - private final AnthropicChatBedrockApi anthropicChatApi; + private final String modelId; + + private final BedrockConverseApi converseApi; private final AnthropicChatOptions defaultOptions; - public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi) { - this(chatApi, - AnthropicChatOptions.builder() - .withTemperature(0.8f) - .withMaxTokensToSample(500) - .withTopK(10) - .withAnthropicVersion(AnthropicChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) - .build()); + public BedrockAnthropicChatModel(BedrockConverseApi converseApi) { + this(converseApi, AnthropicChatOptions.builder().withTemperature(0.8f).withTopK(10).build()); } - public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi, AnthropicChatOptions options) { - this.anthropicChatApi = chatApi; + public BedrockAnthropicChatModel(BedrockConverseApi converseApi, AnthropicChatOptions options) { + this(AnthropicChatModel.CLAUDE_V2.id(), converseApi, options); + } + + public BedrockAnthropicChatModel(String modelId, BedrockConverseApi converseApi, AnthropicChatOptions options) { + Assert.notNull(modelId, "modelId must not be null."); + Assert.notNull(converseApi, "BedrockConverseApi must not be null."); + Assert.notNull(options, "AnthropicChatOptions must not be null."); + + this.modelId = modelId; + this.converseApi = converseApi; this.defaultOptions = options; } @Override public ChatResponse call(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); - AnthropicChatRequest request = createRequest(prompt); + var request = BedrockConverseApiUtils.createConverseRequest(modelId, prompt, defaultOptions); - AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); + ConverseResponse response = this.converseApi.converse(request); - return new ChatResponse(List.of(new Generation(response.completion()))); + return 
BedrockConverseApiUtils.convertConverseResponse(response); } @Override public Flux stream(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); - AnthropicChatRequest request = createRequest(prompt); + var request = BedrockConverseApiUtils.createConverseStreamRequest(modelId, prompt, defaultOptions); - Flux fluxResponse = this.anthropicChatApi.chatCompletionStream(request); + Flux fluxResponse = this.converseApi.converseStream(request); - return fluxResponse.map(response -> { - String stopReason = response.stopReason() != null ? response.stopReason() : null; - var generation = new Generation(response.completion()); - if (response.amazonBedrockInvocationMetrics() != null) { - generation = generation.withGenerationMetadata( - ChatGenerationMetadata.from(stopReason, response.amazonBedrockInvocationMetrics())); - } - return new ChatResponse(List.of(generation)); - }); + return fluxResponse.map(output -> BedrockConverseApiUtils.convertConverseStreamOutput(output)); + } + + @Override + public ChatOptions getDefaultOptions() { + return AnthropicChatOptions.fromOptions(this.defaultOptions); } /** - * Accessible for testing. + * Anthropic models version. 
*/ - AnthropicChatRequest createRequest(Prompt prompt) { - - // Related to: https://github.com/spring-projects/spring-ai/issues/404 - final String promptValue = MessageToPromptConverter.create("\n").toPrompt(prompt.getInstructions()); - - AnthropicChatRequest request = AnthropicChatRequest.builder(promptValue).build(); - - if (this.defaultOptions != null) { - request = ModelOptionsUtils.merge(request, this.defaultOptions, AnthropicChatRequest.class); + public enum AnthropicChatModel implements ModelDescription { + + /** + * anthropic.claude-instant-v1 + */ + CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), + /** + * anthropic.claude-v2 + */ + CLAUDE_V2("anthropic.claude-v2"), + /** + * anthropic.claude-v2:1 + */ + CLAUDE_V21("anthropic.claude-v2:1"); + + private final String id; + + /** + * @return The model id. + */ + public String id() { + return id; } - if (prompt.getOptions() != null) { - AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), - ChatOptions.class, AnthropicChatOptions.class); - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class); + AnthropicChatModel(String value) { + this.id = value; } - return request; - } + @Override + public String getModelName() { + return this.id; + } - @Override - public ChatOptions getDefaultOptions() { - return AnthropicChatOptions.fromOptions(this.defaultOptions); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java deleted file mode 100644 index cd85fb7a5e..0000000000 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.bedrock.anthropic.api; - -import java.time.Duration; -import java.util.List; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.model.ModelDescription; -import org.springframework.util.Assert; - -/** - * @author Christian Tzolov - * @author Wei Jiang - * @since 0.8.0 - */ -// @formatter:off -public class AnthropicChatBedrockApi extends - AbstractBedrockApi { - - public static final String PROMPT_TEMPLATE = "\n\nHuman:%s\n\nAssistant:"; - - /** - * Default version of the Anthropic chat model. - */ - public static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"; - - - /** - * Create a new AnthropicChatBedrockApi instance using the default credentials provider chain, the default object. - * @param modelId The model id to use. 
See the {@link AnthropicChatModel} for the supported models. - * @param region The AWS region to use. - */ - public AnthropicChatBedrockApi(String modelId, String region) { - super(modelId, region); - } - - /** - * Create a new AnthropicChatBedrockApi instance using the default credentials provider chain, the default object. - * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. - * @param region The AWS region to use. - * @param timeout The timeout to use. - */ - public AnthropicChatBedrockApi(String modelId, String region, Duration timeout) { - super(modelId, region, timeout); - } - - /** - * Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - */ - public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper) { - super(modelId, credentialsProvider, region, objectMapper); - } - - /** - * Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. 
- */ - public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - // https://github.com/build-on-aws/amazon-bedrock-java-examples/blob/main/example_code/bedrock-runtime/src/main/java/aws/community/examples/InvokeBedrockStreamingAsync.java - - // https://docs.anthropic.com/claude/reference/complete_post - - // https://docs.aws.amazon.com/bedrock/latest/userguide/br-product-ids.html - - // Anthropic Claude models: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html - - /** - * AnthropicChatRequest encapsulates the request parameters for the Anthropic chat model. - * https://docs.anthropic.com/claude/reference/complete_post - * - * @param prompt The prompt to use for the chat. - * @param temperature (default 0.5) The temperature to use for the chat. You should either alter temperature or - * top_p, but not both. - * @param maxTokensToSample (default 200) Specify the maximum number of tokens to use in the generated response. - * Note that the models may stop before reaching this maximum. 
This parameter only specifies the absolute maximum - * number of tokens to generate. We recommend a limit of 4,000 tokens for optimal performance. - * @param topK (default 250) Specify the number of token choices the model uses to generate the next token. - * @param topP (default 1) Nucleus sampling to specify the cumulative probability of the next token in range [0,1]. - * In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in - * decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You - * should either alter temperature or top_p, but not both. - * @param stopSequences (defaults to "\n\nHuman:") Configure up to four sequences that the model recognizes. After a - * stop sequence, the model stops generating further tokens. The returned text doesn't contain the stop sequence. - * @param anthropicVersion The version of the model to use. The default value is bedrock-2023-05-31. - */ - @JsonInclude(Include.NON_NULL) - public record AnthropicChatRequest( - @JsonProperty("prompt") String prompt, - @JsonProperty("temperature") Float temperature, - @JsonProperty("max_tokens_to_sample") Integer maxTokensToSample, - @JsonProperty("top_k") Integer topK, - @JsonProperty("top_p") Float topP, - @JsonProperty("stop_sequences") List stopSequences, - @JsonProperty("anthropic_version") String anthropicVersion) { - - public static Builder builder(String prompt) { - return new Builder(prompt); - } - - public static class Builder { - private final String prompt; - private Float temperature;// = 0.7f; - private Integer maxTokensToSample;// = 500; - private Integer topK;// = 10; - private Float topP; - private List stopSequences; - private String anthropicVersion; - - private Builder(String prompt) { - this.prompt = prompt; - } - - public Builder withTemperature(Float temperature) { - this.temperature = temperature; - return this; - } - - public Builder withMaxTokensToSample(Integer 
maxTokensToSample) { - this.maxTokensToSample = maxTokensToSample; - return this; - } - - public Builder withTopK(Integer topK) { - this.topK = topK; - return this; - } - - public Builder withTopP(Float tpoP) { - this.topP = tpoP; - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.stopSequences = stopSequences; - return this; - } - - public Builder withAnthropicVersion(String anthropicVersion) { - this.anthropicVersion = anthropicVersion; - return this; - } - - public AnthropicChatRequest build() { - return new AnthropicChatRequest( - prompt, - temperature, - maxTokensToSample, - topK, - topP, - stopSequences, - anthropicVersion - ); - } - } - } - - /** - * AnthropicChatResponse encapsulates the response parameters for the Anthropic chat model. - * - * @param completion The generated text. - * @param stopReason The reason the model stopped generating text. - * @param stop The stop sequence that caused the model to stop generating text. - * @param amazonBedrockInvocationMetrics Metrics about the model invocation. - */ - @JsonInclude(Include.NON_NULL) - public record AnthropicChatResponse( - @JsonProperty("type") String type, - @JsonProperty("completion") String completion, - @JsonProperty("stop_reason") String stopReason, - @JsonProperty("stop") String stop, - @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { - } - - /** - * Anthropic models version. - */ - public enum AnthropicChatModel implements ModelDescription { - /** - * anthropic.claude-instant-v1 - */ - CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), - /** - * anthropic.claude-v2 - */ - CLAUDE_V2("anthropic.claude-v2"), - /** - * anthropic.claude-v2:1 - */ - CLAUDE_V21("anthropic.claude-v2:1"); - - private final String id; - - /** - * @return The model id. 
- */ - public String id() { - return id; - } - - AnthropicChatModel(String value) { - this.id = value; - } - - @Override - public String getModelName() { - return this.id; - } - } - - @Override - public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); - } - - @Override - public Flux chatCompletionStream(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocationStream(anthropicRequest, AnthropicChatResponse.class); - } - -} -// @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java index b4995683a3..208925f6d6 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java @@ -15,19 +15,31 @@ */ package org.springframework.ai.bedrock.anthropic3; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; /** + * Java {@link ChatOptions} for the Bedrock 
Anthropic chat generative model chat options. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html + * * @author Ben Middleton + * @author Wei Jiang * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) -public class Anthropic3ChatOptions implements ChatOptions { +public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions { // @formatter:off /** @@ -63,9 +75,30 @@ public class Anthropic3ChatOptions implements ChatOptions { private @JsonProperty("stop_sequences") List stopSequences; /** - * The version of the generative to use. The default value is bedrock-2023-05-31. + * Tool Function Callbacks to register with the ChatModel. For Prompt + * Options the functionCallbacks are automatically enabled for the duration of the + * prompt execution. For Default Options the functionCallbacks are registered but + * disabled by default. Use the enableFunctions to set the functions from the registry + * to be used by the ChatModel chat completion requests. */ - private @JsonProperty("anthropic_version") String anthropicVersion; + @NestedConfigurationProperty + @JsonIgnore + private List functionCallbacks = new ArrayList<>(); + + /** + * List of functions, identified by their names, to configure for function calling in + * the chat completion requests. Functions with those names must exist in the + * functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions + * are automatically enabled for the duration of the prompt execution. + * + * Note that function enabled with the default options are enabled for all chat + * completion requests. This could impact the token count and the billing. If the + * functions is set in a prompt options, then the enabled functions are only active + * for the duration of this prompt execution. 
+ */ + @NestedConfigurationProperty + @JsonIgnore + private Set functions = new HashSet<>(); // @formatter:on public static Builder builder() { @@ -101,8 +134,20 @@ public Builder withStopSequences(List stopSequences) { return this; } - public Builder withAnthropicVersion(String anthropicVersion) { - this.options.setAnthropicVersion(anthropicVersion); + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); return this; } @@ -155,12 +200,26 @@ public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } - public String getAnthropicVersion() { - return this.anthropicVersion; + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallbacks = functionCallbacks; } - public void setAnthropicVersion(String anthropicVersion) { - this.anthropicVersion = anthropicVersion; + @Override + public Set getFunctions() { + return this.functions; + } + + @Override + public void setFunctions(Set functions) { + Assert.notNull(functions, "Function must not be null"); + this.functions = functions; } public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) { @@ -169,7 +228,8 @@ public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOption .withTopK(fromOptions.getTopK()) .withTopP(fromOptions.getTopP()) .withStopSequences(fromOptions.getStopSequences()) - 
.withAnthropicVersion(fromOptions.getAnthropicVersion()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) .build(); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java index 0b42d4266e..70b8dbdcac 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java @@ -15,175 +15,278 @@ */ package org.springframework.ai.bedrock.anthropic3; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.StopReason; +import software.amazon.awssdk.services.bedrockruntime.model.Tool; +import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultStatus; +import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification; +import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type; +import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; + import java.util.ArrayList; -import java.util.Base64; +import java.util.HashSet; import java.util.List; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; 
+import java.util.Set; -import reactor.core.publisher.Flux; - -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; +import org.springframework.ai.bedrock.BedrockConverseChatGenerationMetadata; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest; +import org.springframework.ai.bedrock.api.BedrockConverseApiUtils; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelDescription; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; /** - * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock 
Anthropic chat - * generative. + * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic3 chat + * generative model. * * @author Ben Middleton * @author Christian Tzolov + * @author Wei Jiang * @since 1.0.0 */ -public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel { +public class BedrockAnthropic3ChatModel + extends AbstractFunctionCallSupport + implements ChatModel, StreamingChatModel { + + private final String modelId; - private final Anthropic3ChatBedrockApi anthropicChatApi; + private final BedrockConverseApi converseApi; private final Anthropic3ChatOptions defaultOptions; - public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi) { - this(chatApi, - Anthropic3ChatOptions.builder() - .withTemperature(0.8f) - .withMaxTokens(500) - .withTopK(10) - .withAnthropicVersion(Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) - .build()); + public BedrockAnthropic3ChatModel(BedrockConverseApi converseApi) { + this(converseApi, + Anthropic3ChatOptions.builder().withTemperature(0.8f).withMaxTokens(500).withTopK(10).build()); } - public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) { - this.anthropicChatApi = chatApi; + public BedrockAnthropic3ChatModel(BedrockConverseApi converseApi, Anthropic3ChatOptions options) { + this(Anthropic3ChatModel.CLAUDE_V3_SONNET.id(), converseApi, options); + } + + public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options) { + this(modelId, converseApi, options, null); + } + + public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options, + FunctionCallbackContext functionCallbackContext) { + super(functionCallbackContext); + + Assert.notNull(modelId, "modelId must not be null."); + Assert.notNull(converseApi, "BedrockConverseApi must not be null."); + Assert.notNull(options, "Anthropic3ChatOptions must not be null."); + + 
this.modelId = modelId; + this.converseApi = converseApi; this.defaultOptions = options; } @Override public ChatResponse call(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); - AnthropicChatRequest request = createRequest(prompt); - - AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); + var request = createBedrockConverseRequest(prompt); - return new ChatResponse(List.of(new Generation(response.content().get(0).text()))); + return this.callWithFunctionSupport(request); } @Override public Flux stream(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); - AnthropicChatRequest request = createRequest(prompt); + var request = createBedrockConverseRequest(prompt); - Flux fluxResponse = this.anthropicChatApi - .chatCompletionStream(request); + return converseApi.converseStream(request); + } - AtomicReference inputTokens = new AtomicReference<>(0); - return fluxResponse.map(response -> { - if (response.type() == StreamingType.MESSAGE_START) { - inputTokens.set(response.message().usage().inputTokens()); - } - String content = response.type() == StreamingType.CONTENT_BLOCK_DELTA ? response.delta().text() : ""; + private BedrockConverseRequest createBedrockConverseRequest(Prompt prompt) { + var request = BedrockConverseApiUtils.createBedrockConverseRequest(modelId, prompt, defaultOptions); - var generation = new Generation(content); + ToolConfiguration toolConfiguration = createToolConfiguration(prompt); - if (response.type() == StreamingType.MESSAGE_DELTA) { - generation = generation.withGenerationMetadata(ChatGenerationMetadata - .from(response.delta().stopReason(), new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(), - response.usage().outputTokens()))); - } - - return new ChatResponse(List.of(generation)); - }); + return BedrockConverseRequest.from(request).withToolConfiguration(toolConfiguration).build(); } - /** - * Accessible for testing. 
- */ - AnthropicChatRequest createRequest(Prompt prompt) { - - AnthropicChatRequest request = AnthropicChatRequest.builder(toAnthropicMessages(prompt)) - .withSystem(toAnthropicSystemContext(prompt)) - .build(); + private ToolConfiguration createToolConfiguration(Prompt prompt) { + Set functionsForThisRequest = new HashSet<>(); if (this.defaultOptions != null) { - request = ModelOptionsUtils.merge(request, this.defaultOptions, AnthropicChatRequest.class); + Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, + !IS_RUNTIME_CALL); + functionsForThisRequest.addAll(promptEnabledFunctions); } if (prompt.getOptions() != null) { Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, Anthropic3ChatOptions.class); - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class); + + Set defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, + IS_RUNTIME_CALL); + functionsForThisRequest.addAll(defaultEnabledFunctions); } - return request; + if (!CollectionUtils.isEmpty(functionsForThisRequest)) { + return ToolConfiguration.builder().tools(getFunctionTools(functionsForThisRequest)).build(); + } + + return null; } - /** - * Extracts system context from prompt. - * @param prompt The prompt. - * @return The system context. 
- */ - private String toAnthropicSystemContext(Prompt prompt) { + private List getFunctionTools(Set functionNames) { + return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { + var description = functionCallback.getDescription(); + var name = functionCallback.getName(); + String inputSchema = functionCallback.getInputTypeSchema(); - return prompt.getInstructions() - .stream() - .filter(m -> m.getMessageType() == MessageType.SYSTEM) - .map(Message::getContent) - .collect(Collectors.joining(System.lineSeparator())); + return Tool.builder() + .toolSpec(ToolSpecification.builder() + .name(name) + .description(description) + .inputSchema(ToolInputSchema.builder() + .json(BedrockConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema))) + .build()) + .build()) + .build(); + }).toList(); } - /** - * Extracts list of messages from prompt. - * @param prompt The prompt. - * @return The list of {@link ChatCompletionMessage}. - */ - private List toAnthropicMessages(Prompt prompt) { + @Override + public ChatOptions getDefaultOptions() { + return Anthropic3ChatOptions.fromOptions(this.defaultOptions); + } - return prompt.getInstructions() + @Override + protected BedrockConverseRequest doCreateToolResponseRequest(BedrockConverseRequest previousRequest, + Message responseMessage, List conversationHistory) { + List toolToUseList = responseMessage.content() .stream() - .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) - .map(message -> { - List contents = new ArrayList<>(List.of(new MediaContent(message.getContent()))); - if (!CollectionUtils.isEmpty(message.getMedia())) { - List mediaContent = message.getMedia() - .stream() - .map(media -> new MediaContent(media.getMimeType().toString(), - this.fromMediaData(media.getData()))) - .toList(); - contents.addAll(mediaContent); - } - return new ChatCompletionMessage(contents, Role.valueOf(message.getMessageType().name())); - }) + 
.filter(content -> content.type() == Type.TOOL_USE) + .map(content -> content.toolUse()) .toList(); + + List toolResults = new ArrayList<>(); + + for (ToolUseBlock toolToUse : toolToUseList) { + var functionCallId = toolToUse.toolUseId(); + var functionName = toolToUse.name(); + var functionArguments = toolToUse.input().unwrap(); + + if (!this.functionCallbackRegister.containsKey(functionName)) { + throw new IllegalStateException("No function callback found for function name: " + functionName); + } + + String functionResponse = this.functionCallbackRegister.get(functionName) + .call(ModelOptionsUtils.toJsonString(functionArguments)); + + toolResults.add(ToolResultBlock.builder() + .toolUseId(functionCallId) + .status(ToolResultStatus.SUCCESS) + .content(ToolResultContentBlock.builder().text(functionResponse).build()) + .build()); + } + + // Add the function response to the conversation. + Message toolResultMessage = Message.builder() + .content(toolResults.stream().map(toolResult -> ContentBlock.fromToolResult(toolResult)).toList()) + .role(ConversationRole.USER) + .build(); + conversationHistory.add(toolResultMessage); + + // Recursively call chatCompletionWithTools until the model doesn't call a + // functions anymore. 
+ return BedrockConverseRequest.from(previousRequest).withMessages(conversationHistory).build(); + } + + @Override + protected List doGetUserMessages(BedrockConverseRequest request) { + return request.messages(); + } + + @Override + protected Message doGetToolResponseMessage(ChatResponse response) { + Generation result = response.getResult(); + + var metadata = (BedrockConverseChatGenerationMetadata) result.getMetadata(); + + return metadata.getMessage(); + } + + @Override + protected ChatResponse doChatCompletion(BedrockConverseRequest request) { + return converseApi.converse(request); + } + + @Override + protected Flux doChatCompletionStream(BedrockConverseRequest request) { + throw new UnsupportedOperationException("Streaming function calling is not supported."); } - private String fromMediaData(Object mediaData) { - if (mediaData instanceof byte[] bytes) { - return Base64.getEncoder().encodeToString(bytes); + @Override + protected boolean isToolFunctionCall(ChatResponse response) { + Generation result = response.getResult(); + if (result == null) { + return false; + } + + return StopReason.fromValue(result.getMetadata().getFinishReason()) == StopReason.TOOL_USE; + } + + /** + * Anthropic3 models version. + */ + public enum Anthropic3ChatModel implements ModelDescription { + + /** + * anthropic.claude-3-sonnet-20240229-v1:0 + */ + CLAUDE_V3_SONNET("anthropic.claude-3-sonnet-20240229-v1:0"), + /** + * anthropic.claude-3-haiku-20240307-v1:0 + */ + CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0"), + /** + * anthropic.claude-3-opus-20240229-v1:0 + */ + CLAUDE_V3_OPUS("anthropic.claude-3-opus-20240229-v1:0"), + /** + * anthropic.claude-3-5-sonnet-20240620-v1:0 + */ + CLAUDE_V3_5_SONNET("anthropic.claude-3-5-sonnet-20240620-v1:0"); + + private final String id; + + /** + * @return The model id. 
+ */ + public String id() { + return id; } - else if (mediaData instanceof String text) { - return text; + + Anthropic3ChatModel(String value) { + this.id = value; } - else { - throw new IllegalArgumentException("Unsupported media data type: " + mediaData.getClass().getSimpleName()); + + @Override + public String getModelName() { + return this.id; } - } - @Override - public ChatOptions getDefaultOptions() { - return Anthropic3ChatOptions.fromOptions(this.defaultOptions); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java deleted file mode 100644 index 8b5b29ed1e..0000000000 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java +++ /dev/null @@ -1,500 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.bedrock.anthropic3.api; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse; -import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.model.ModelDescription; -import org.springframework.util.Assert; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import java.time.Duration; -import java.util.List; - -/** - * Based on Bedrock's Anthropic - * Claude Messages API. - * - * It is meant to replace the previous Chat API, which is now deprecated. - * - * @author Ben Middleton - * @author Christian Tzolov - * @author Wei Jiang - * @since 1.0.0 - */ -// @formatter:off -public class Anthropic3ChatBedrockApi extends - AbstractBedrockApi { - - /** - * Default version of the Anthropic chat model. - */ - public static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"; - - /** - * Create a new AnthropicChatBedrockApi instance using the default credentials provider chain, the default object. - * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. - * @param region The AWS region to use. - */ - public Anthropic3ChatBedrockApi(String modelId, String region) { - super(modelId, region); - } - - /** - * Create a new AnthropicChatBedrockApi instance using the default credentials provider chain, the default object. - * @param modelId The model id to use. 
See the {@link AnthropicChatModel} for the supported models. - * @param region The AWS region to use. - * @param timeout The timeout to use. - */ - public Anthropic3ChatBedrockApi(String modelId, String region, Duration timeout) { - super(modelId, region, timeout); - } - - /** - * Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - */ - public Anthropic3ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper) { - super(modelId, credentialsProvider, region, objectMapper); - } - - /** - * Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public Anthropic3ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. 
- * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public Anthropic3ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - // https://github.com/build-on-aws/amazon-bedrock-java-examples/blob/main/example_code/bedrock-runtime/src/main/java/aws/community/examples/InvokeBedrockStreamingAsync.java - - // https://docs.anthropic.com/claude/reference/complete_post - - // https://docs.aws.amazon.com/bedrock/latest/userguide/br-product-ids.html - - // Anthropic Claude models: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html - - /** - * AnthropicChatRequest encapsulates the request parameters for the Anthropic messages model. - * https://docs.anthropic.com/claude/reference/messages_post - * - * @param messages A list of messages comprising the conversation so far. - * @param system A system prompt, providing context and instructions to Claude, such as specifying a particular goal - * or role. - * @param temperature (default 0.5) The temperature to use for the chat. You should either alter temperature or - * top_p, but not both. - * @param maxTokens (default 200) Specify the maximum number of tokens to use in the generated response. - * Note that the models may stop before reaching this maximum. This parameter only specifies the absolute maximum - * number of tokens to generate. We recommend a limit of 4,000 tokens for optimal performance. - * @param topK (default 250) Specify the number of token choices the model uses to generate the next token. - * @param topP (default 1) Nucleus sampling to specify the cumulative probability of the next token in range [0,1]. 
- * In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in - * decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You - * should either alter temperature or top_p, but not both. - * @param stopSequences (defaults to "\n\nHuman:") Configure up to four sequences that the model recognizes. After a - * stop sequence, the model stops generating further tokens. The returned text doesn't contain the stop sequence. - * @param anthropicVersion The version of the model to use. The default value is bedrock-2023-05-31. - */ - @JsonInclude(Include.NON_NULL) - public record AnthropicChatRequest( - @JsonProperty("messages") List messages, - @JsonProperty("system") String system, - @JsonProperty("temperature") Float temperature, - @JsonProperty("max_tokens") Integer maxTokens, - @JsonProperty("top_k") Integer topK, - @JsonProperty("top_p") Float topP, - @JsonProperty("stop_sequences") List stopSequences, - @JsonProperty("anthropic_version") String anthropicVersion) { - - public static Builder builder(List messages) { - return new Builder(messages); - } - - public static class Builder { - private final List messages; - private String system; - private Float temperature;// = 0.7f; - private Integer maxTokens;// = 500; - private Integer topK;// = 10; - private Float topP; - private List stopSequences; - private String anthropicVersion; - - private Builder(List messages) { - this.messages = messages; - } - - public Builder withSystem(String system) { - this.system = system; - return this; - } - public Builder withTemperature(Float temperature) { - this.temperature = temperature; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - public Builder withTopK(Integer topK) { - this.topK = topK; - return this; - } - - public Builder withTopP(Float tpoP) { - this.topP = tpoP; - return this; - } - - public Builder 
withStopSequences(List stopSequences) { - this.stopSequences = stopSequences; - return this; - } - - public Builder withAnthropicVersion(String anthropicVersion) { - this.anthropicVersion = anthropicVersion; - return this; - } - - public AnthropicChatRequest build() { - return new AnthropicChatRequest( - messages, - system, - temperature, - maxTokens, - topK, - topP, - stopSequences, - anthropicVersion - ); - } - } - } - - /** - * @param type the content type can be "text" or "image". - * @param source The source of the media content. Applicable for "image" types only. - * @param text The text of the message. Applicable for "text" types only. - * @param index The index of the content block. Applicable only for streaming - * responses. - */ - @JsonInclude(Include.NON_NULL) - public record MediaContent( // @formatter:off - @JsonProperty("type") Type type, - @JsonProperty("source") Source source, - @JsonProperty("text") String text, - @JsonProperty("index") Integer index // applicable only for streaming responses. - ) { - // @formatter:on - - public MediaContent(String mediaType, String data) { - this(new Source(mediaType, data)); - } - - public MediaContent(Source source) { - this(Type.IMAGE, source, null, null); - } - - public MediaContent(String text) { - this(Type.TEXT, null, text, null); - } - - /** - * The type of this message. - */ - public enum Type { - - /** - * Text message. - */ - @JsonProperty("text") - TEXT, - /** - * Image message. - */ - @JsonProperty("image") - IMAGE - - } - - /** - * The source of the media content. (Applicable for "image" types only) - * - * @param type The type of the media content. Only "base64" is supported at the - * moment. - * @param mediaType The media type of the content. For example, "image/png" or - * "image/jpeg". - * @param data The base64-encoded data of the content. 
- */ - @JsonInclude(Include.NON_NULL) - public record Source( // @formatter:off - @JsonProperty("type") String type, - @JsonProperty("media_type") String mediaType, - @JsonProperty("data") String data) { - // @formatter:on - - public Source(String mediaType, String data) { - this("base64", mediaType, data); - } - } - } - - /** - * Message comprising the conversation. - * - * @param content The contents of the message. - * @param role The role of the messages author. Could be one of the {@link Role} - * types. - */ - @JsonInclude(Include.NON_NULL) - public record ChatCompletionMessage(@JsonProperty("content") List content, - @JsonProperty("role") Role role) { - - /** - * The role of the author of this message. - */ - public enum Role { - - /** - * User message. - */ - @JsonProperty("user") - USER, - /** - * Assistant message. - */ - @JsonProperty("assistant") - ASSISTANT - - } - } - - /** - * Encapsulates the metrics about the model invocation. - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#model-parameters-anthropic-claude-messages-request-response - * - * @param inputTokens The number of tokens in the input prompt. - * @param outputTokens The number of tokens in the generated text. - */ - @JsonInclude(Include.NON_NULL) - public record AnthropicUsage(@JsonProperty("input_tokens") Integer inputTokens, - @JsonProperty("output_tokens") Integer outputTokens) { - } - - /** - * AnthropicChatResponse encapsulates the response parameters for the Anthropic - * messages model. - * - * @param id The unique response identifier. - * @param model The ID for the Anthropic Claude model that made the request. - * @param type The type of the response. - * @param role The role of the response. - * @param content The list of generated text. - * @param stopReason The reason the model stopped generating text: end_turn – The - * model reached a natural stopping point. 
max_tokens – The generated text exceeded - * the value of the max_tokens input field or exceeded the maximum number of tokens - * that the model supports. stop_sequence – The model generated one of the stop - * sequences that you specified in the stop_sequences input field. - * @param stopSequence The stop sequence that caused the model to stop generating - * text. - * @param usage Metrics about the model invocation. - */ - @JsonInclude(Include.NON_NULL) - public record AnthropicChatResponse(// formatter:off - @JsonProperty("id") String id, @JsonProperty("model") String model, @JsonProperty("type") String type, - @JsonProperty("role") String role, @JsonProperty("content") List content, - @JsonProperty("stop_reason") String stopReason, @JsonProperty("stop_sequence") String stopSequence, - @JsonProperty("usage") AnthropicUsage usage, - @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { // formatter:on - } - - /** - * AnthropicChatStreamingResponse encapsulates the streaming response parameters for - * the Anthropic messages model. - * https://docs.anthropic.com/claude/reference/messages-streaming - * - * @param type The streaming type. - * @param message The message details that made the request. - * @param index The delta index. - * @param contentBlock The generated text. - * @param delta The delta. - * @param usage The usage data. - */ - @JsonInclude(Include.NON_NULL) - public record AnthropicChatStreamingResponse(// formatter:off - @JsonProperty("type") StreamingType type, @JsonProperty("message") AnthropicChatResponse message, - @JsonProperty("index") Integer index, @JsonProperty("content_block") MediaContent contentBlock, - @JsonProperty("delta") Delta delta, @JsonProperty("usage") AnthropicUsage usage, - @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { // formatter:on - - /** - * The streaming type of this message. 
- */ - public enum StreamingType { - - /** - * Message start. - */ - @JsonProperty("message_start") - MESSAGE_START, - /** - * Content block start. - */ - @JsonProperty("content_block_start") - CONTENT_BLOCK_START, - /** - * Ping. - */ - @JsonProperty("ping") - PING, - /** - * Content block delta. - */ - @JsonProperty("content_block_delta") - CONTENT_BLOCK_DELTA, - /** - * Content block stop. - */ - @JsonProperty("content_block_stop") - CONTENT_BLOCK_STOP, - /** - * Message delta. - */ - @JsonProperty("message_delta") - MESSAGE_DELTA, - /** - * Message stop. - */ - @JsonProperty("message_stop") - MESSAGE_STOP - - } - - /** - * Encapsulates a delta. - * https://docs.anthropic.com/claude/reference/messages-streaming * - * - * @param type The type of the message. - * @param text The text message. - * @param stopReason The stop reason. - * @param stopSequence The stop sequence. - */ - @JsonInclude(Include.NON_NULL) - public record Delta(@JsonProperty("type") String type, @JsonProperty("text") String text, - @JsonProperty("stop_reason") String stopReason, @JsonProperty("stop_sequence") String stopSequence) { - } - } - - /** - * Anthropic models version. - */ - public enum AnthropicChatModel implements ModelDescription { - - /** - * anthropic.claude-instant-v1 - */ - CLAUDE_INSTANT_V1("anthropic.claude-instant-v1"), - /** - * anthropic.claude-v2 - */ - CLAUDE_V2("anthropic.claude-v2"), - /** - * anthropic.claude-v2:1 - */ - CLAUDE_V21("anthropic.claude-v2:1"), - /** - * anthropic.claude-3-sonnet-20240229-v1:0 - */ - CLAUDE_V3_SONNET("anthropic.claude-3-sonnet-20240229-v1:0"), - /** - * anthropic.claude-3-haiku-20240307-v1:0 - */ - CLAUDE_V3_HAIKU("anthropic.claude-3-haiku-20240307-v1:0"), - /** - * anthropic.claude-3-opus-20240229-v1:0 - */ - CLAUDE_V3_OPUS("anthropic.claude-3-opus-20240229-v1:0"); - - private final String id; - - /** - * @return The model id. 
- */ - public String id() { - return id; - } - - AnthropicChatModel(String value) { - this.id = value; - } - - @Override - public String getModelName() { - return this.id; - } - - } - - @Override - public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); - } - - @Override - public Flux chatCompletionStream(AnthropicChatRequest anthropicRequest) { - Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); - return this.internalInvocationStream(anthropicRequest, AnthropicChatStreamingResponse.class); - } - -} -// @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java index 7db24b3b8c..4317437494 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java @@ -16,20 +16,14 @@ package org.springframework.ai.bedrock.aot; import org.springframework.ai.bedrock.anthropic.AnthropicChatOptions; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; import org.springframework.ai.bedrock.anthropic3.Anthropic3ChatOptions; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.cohere.BedrockCohereChatOptions; import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingOptions; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; import 
org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions; import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingOptions; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; @@ -53,11 +47,7 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { var mcs = MemberCategory.values(); for (var tr : findJsonAnnotatedClassesInPackage(AbstractBedrockApi.class)) hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(Ai21Jurassic2ChatBedrockApi.class)) - hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(CohereChatBedrockApi.class)) - hints.reflection().registerType(tr, mcs); for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereChatOptions.class)) hints.reflection().registerType(tr, mcs); for (var tr : findJsonAnnotatedClassesInPackage(CohereEmbeddingBedrockApi.class)) @@ -65,13 +55,9 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class)) hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(LlamaChatBedrockApi.class)) - hints.reflection().registerType(tr, mcs); for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlamaChatOptions.class)) hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(TitanChatBedrockApi.class)) - hints.reflection().registerType(tr, mcs); for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanChatOptions.class)) hints.reflection().registerType(tr, mcs); for (var tr : 
findJsonAnnotatedClassesInPackage(BedrockTitanEmbeddingOptions.class)) @@ -79,13 +65,9 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { for (var tr : findJsonAnnotatedClassesInPackage(TitanEmbeddingBedrockApi.class)) hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(AnthropicChatBedrockApi.class)) - hints.reflection().registerType(tr, mcs); for (var tr : findJsonAnnotatedClassesInPackage(AnthropicChatOptions.class)) hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(Anthropic3ChatBedrockApi.class)) - hints.reflection().registerType(tr, mcs); for (var tr : findJsonAnnotatedClassesInPackage(Anthropic3ChatOptions.class)) hints.reflection().registerType(tr, mcs); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java index 24a383adac..823c3f1bd0 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java @@ -70,7 +70,6 @@ public abstract class AbstractBedrockApi { private final String modelId; private final ObjectMapper objectMapper; - private final Region region; private final BedrockRuntimeClient client; private final BedrockRuntimeAsyncClient clientStreaming; @@ -136,29 +135,36 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv */ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, ObjectMapper objectMapper, Duration timeout) { + this(modelId, BedrockRuntimeClient.builder() + .region(region) + .credentialsProvider(credentialsProvider) + .overrideConfiguration(c -> c.apiCallTimeout(timeout)) + .build(), BedrockRuntimeAsyncClient.builder() + .region(region) + 
.credentialsProvider(credentialsProvider) + .overrideConfiguration(c -> c.apiCallTimeout(timeout)) + .build(), objectMapper); + } + /** + * Create a new AbstractBedrockApi instance using the provided AWS Bedrock clients and object mapper. + * + * @param modelId The model id to use. + * @param bedrockRuntimeClient The AWS BedrockRuntimeClient instance. + * @param bedrockRuntimeAsyncClient The AWS BedrockRuntimeAsyncClient instance. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + */ + public AbstractBedrockApi(String modelId, BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, + ObjectMapper objectMapper) { Assert.hasText(modelId, "Model id must not be empty"); - Assert.notNull(credentialsProvider, "Credentials provider must not be null"); - Assert.notNull(region, "Region must not be empty"); + Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null"); + Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null"); Assert.notNull(objectMapper, "Object mapper must not be null"); - Assert.notNull(timeout, "Timeout must not be null"); this.modelId = modelId; + this.client = bedrockRuntimeClient; + this.clientStreaming = bedrockRuntimeAsyncClient; this.objectMapper = objectMapper; - this.region = region; - - - this.client = BedrockRuntimeClient.builder() - .region(this.region) - .credentialsProvider(credentialsProvider) - .overrideConfiguration(c -> c.apiCallTimeout(timeout)) - .build(); - - this.clientStreaming = BedrockRuntimeAsyncClient.builder() - .region(this.region) - .credentialsProvider(credentialsProvider) - .overrideConfiguration(c -> c.apiCallTimeout(timeout)) - .build(); } /** @@ -168,13 +174,6 @@ public String getModelId() { return this.modelId; } - /** - * @return The AWS region. - */ - public Region getRegion() { - return this.region; - } - /** * Encapsulates the metrics about the model invocation. 
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApi.java new file mode 100644 index 0000000000..dc1e5404e9 --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApi.java @@ -0,0 +1,370 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// @formatter:off +package org.springframework.ai.bedrock.api; + +import java.time.Duration; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; +import reactor.core.publisher.Sinks.EmitFailureHandler; +import reactor.core.publisher.Sinks.EmitResult; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; +import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; + +/** + * Amazon Bedrock Converse API, It provides the basic functionality to invoke the Bedrock + * AI model and receive the response for streaming and non-streaming requests. + * The Converse API doesn't support any embedding models (such as Titan Embeddings G1 - Text) + * or image generation models (such as Stability AI). + * + *

+ * https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html + *

+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + *

+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html + *

+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html + *

+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockConverseApi { + + private static final Logger logger = LoggerFactory.getLogger(BedrockConverseApi.class); + + private final BedrockRuntimeClient bedrockRuntimeClient; + + private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient; + + private final RetryTemplate retryTemplate; + + /** + * Create a new BedrockConverseApi instance using default credentials provider. + * + * @param region The AWS region to use. + */ + public BedrockConverseApi(String region) { + this(ProfileCredentialsProvider.builder().build(), region, Duration.ofMinutes(5)); + } + + /** + * Create a new BedrockConverseApi instance using default credentials provider. + * + * @param region The AWS region to use. + * @param timeout The timeout to use. + */ + public BedrockConverseApi(String region, Duration timeout) { + this(ProfileCredentialsProvider.builder().build(), region, timeout); + } + + /** + * Create a new BedrockConverseApi instance using the provided credentials provider, + * region. + * + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + */ + public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, String region) { + this(credentialsProvider, region, Duration.ofMinutes(5)); + } + + /** + * Create a new BedrockConverseApi instance using the provided credentials provider, + * region. + * + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param timeout Configure the amount of time to allow the client to complete the + * execution of an API call. This timeout covers the entire client execution except + * for marshalling. This includes request handler execution, all HTTP requests + * including retries, unmarshalling, etc. This value should always be positive, if + * present. 
+ */ + public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, String region, Duration timeout) { + this(credentialsProvider, Region.of(region), timeout); + } + + /** + * Create a new BedrockConverseApi instance using the provided credentials provider, + * region. + * + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param timeout Configure the amount of time to allow the client to complete the + * execution of an API call. This timeout covers the entire client execution except + * for marshalling. This includes request handler execution, all HTTP requests + * including retries, unmarshalling, etc. This value should always be positive, if + * present. + */ + public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, Region region, Duration timeout) { + this(credentialsProvider, region, timeout, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Create a new BedrockConverseApi instance using the provided credentials provider, + * region + * + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param timeout Configure the amount of time to allow the client to complete the + * execution of an API call. This timeout covers the entire client execution except + * for marshalling. This includes request handler execution, all HTTP requests + * including retries, unmarshalling, etc. This value should always be positive, if + * present. + * @param retryTemplate The retry template used to retry the Amazon Bedrock Converse + * API calls. 
+ */ + public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, Region region, Duration timeout, + RetryTemplate retryTemplate) { + this(BedrockRuntimeClient.builder() + .region(region) + .credentialsProvider(credentialsProvider) + .overrideConfiguration(c -> c.apiCallTimeout(timeout)) + .build(), BedrockRuntimeAsyncClient.builder() + .region(region) + .credentialsProvider(credentialsProvider) + .overrideConfiguration(c -> c.apiCallTimeout(timeout)) + .build(), retryTemplate); + } + + /** + * Create a new BedrockConverseApi instance using the provided AWS Bedrock clients and the RetryTemplate. + * + * @param bedrockRuntimeClient The AWS BedrockRuntimeClient instance. + * @param bedrockRuntimeAsyncClient The AWS BedrockRuntimeAsyncClient instance. + * @param retryTemplate The retry template used to retry the Amazon Bedrock Converse + * API calls. + */ + public BedrockConverseApi(BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, + RetryTemplate retryTemplate) { + Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null"); + Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + + this.bedrockRuntimeClient = bedrockRuntimeClient; + this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient; + this.retryTemplate = retryTemplate; + } + + /** + * Invoke the model and return the response. + * + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse + * @param bedrockConverseRequest Model invocation request. + * @return The model invocation response. 
+ */ + public ChatResponse converse(BedrockConverseRequest bedrockConverseRequest) { + Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null"); + + ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(bedrockConverseRequest); + + ConverseResponse converseResponse = converse(converseRequest); + + return BedrockConverseApiUtils.convertConverseResponse(converseResponse); + } + + /** + * Invoke the model and return the response. + * + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse + * @param converseRequest Model invocation request. + * @return The model invocation response. + */ + public ConverseResponse converse(ConverseRequest converseRequest) { + Assert.notNull(converseRequest, "'converseRequest' must not be null"); + + return this.retryTemplate.execute(ctx -> { + return bedrockRuntimeClient.converse(converseRequest); + }); + } + + /** + * Invoke the model and return the response stream. + * + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream + * @param bedrockConverseRequest Model invocation request. + * @return The model invocation response stream. 
+ */ + public Flux converseStream(BedrockConverseRequest bedrockConverseRequest) { + Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null"); + + ConverseStreamRequest converseStreamRequest = BedrockConverseApiUtils + .createConverseStreamRequest(bedrockConverseRequest); + + return converseStream(converseStreamRequest) + .map(output -> BedrockConverseApiUtils.convertConverseStreamOutput(output)); + + } + + /** + * Invoke the model and return the response stream. + * + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream + * @param converseStreamRequest Model invocation request. + * @return The model invocation response stream. + */ + public Flux converseStream(ConverseStreamRequest converseStreamRequest) { + Assert.notNull(converseStreamRequest, "'converseStreamRequest' must not be null"); + + return this.retryTemplate.execute(ctx -> { + Sinks.Many eventSink = Sinks.many().multicast().onBackpressureBuffer(); + + ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder() + .onDefault((output) -> { + logger.debug("Received converse stream output:{}", output); + eventSink.tryEmitNext(output); + }) + .build(); + + ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder() + .onEventStream(stream -> stream.subscribe((e) -> e.accept(visitor))) + .onComplete(() -> { + EmitResult emitResult = eventSink.tryEmitComplete(); + + while (!emitResult.isSuccess()) { + logger.debug("Emitting complete:{}", emitResult); + emitResult = eventSink.tryEmitComplete(); + } + + eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3))); + logger.debug("Completed streaming response."); + }) + .onError((error) -> { + logger.error("Error 
handling Bedrock converse stream response", error); + eventSink.tryEmitError(error); + }) + .build(); + + bedrockRuntimeAsyncClient.converseStream(converseStreamRequest, responseHandler); + + return eventSink.asFlux(); + }); + } + + /** + * BedrockConverseRequest encapsulates the request parameters for the Amazon Bedrock + * Converse Api. + * + * @param modelId The Amazon Bedrock Model Id. + * @param messages The messages that you want to send to the model. + * @param systemMessages The system prompt to pass to the model. + * @param additionalModelRequestFields Additional inference parameters that the model + * supports, beyond the base set of inference parameters that Converse supports in the + * inferenceConfig field. + * @param toolConfiguration Configuration information for the tools that the model can + * use when generating a response. + */ + public record BedrockConverseRequest(String modelId, List messages, + List systemMessages, Document additionalModelRequestFields, + ToolConfiguration toolConfiguration) { + + public BedrockConverseRequest(String modelId, List messages, List systemMessages, + Document additionalModelRequestFields) { + this(modelId, messages, systemMessages, additionalModelRequestFields, null); + } + + public static Builder from(BedrockConverseRequest request) { + return new Builder(request); + } + + public static class Builder { + + private String modelId; + + private List messages; + + private List systemMessages; + + private Document additionalModelRequestFields; + + private ToolConfiguration toolConfiguration; + + private Builder(BedrockConverseRequest request) { + this.modelId = request.modelId(); + this.messages = request.messages(); + this.systemMessages = request.systemMessages(); + this.additionalModelRequestFields = request.additionalModelRequestFields(); + this.toolConfiguration = request.toolConfiguration(); + } + + public Builder withModelId(String modelId) { + this.modelId = modelId; + return this; + } + + public Builder 
withMessages(List messages) { + this.messages = messages; + return this; + } + + public Builder withSystemMessages(List systemMessages) { + this.systemMessages = systemMessages; + return this; + } + + public Builder withAdditionalModelRequestFields(Document additionalModelRequestFields) { + this.additionalModelRequestFields = additionalModelRequestFields; + return this; + } + + public Builder withToolConfiguration(ToolConfiguration toolConfiguration) { + this.toolConfiguration = toolConfiguration; + return this; + } + + public BedrockConverseRequest build() { + return new BedrockConverseRequest(modelId, messages, systemMessages, additionalModelRequestFields, + toolConfiguration); + } + + } + + } + +} +//@formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtils.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtils.java new file mode 100644 index 0000000000..b7df79ec7d --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtils.java @@ -0,0 +1,368 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +//@formatter:off +package org.springframework.ai.bedrock.api; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.springframework.ai.bedrock.BedrockConverseChatGenerationMetadata; +import org.springframework.ai.bedrock.BedrockChatResponseMetadata; +import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptions; +import org.springframework.util.Assert; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ImageSource; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import 
software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent; +import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type; + +/** + * Amazon Bedrock Converse API utils. + * + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockConverseApiUtils { + private static final ObjectMapper objectMapper = new ObjectMapper(); + + /** + * Convert {@link Prompt} to {@link ConverseRequest} with model id and options. It + * will merge default options and runtime options to converse inference parameters. + * + * @param modelId The Amazon Bedrock Model Id. + * @param prompt The prompt that needs to convert. + * @param defaultOptions The default options needs to convert. + * @return Amazon Bedrock Converse encapsulates request. + */ + public static BedrockConverseRequest createBedrockConverseRequest(String modelId, Prompt prompt, + ChatOptions defaultOptions) { + Assert.notNull(modelId, "'modelId' must not be null."); + Assert.notNull(prompt, "'prompt' must not be null."); + + List messages = getInstructionsMessages(prompt.getInstructions()); + + List systemMessages = getPromptSystemContentBlocks(prompt); + + Document additionalModelRequestFields = getChatOptionsAdditionalModelRequestFields(defaultOptions, + prompt.getOptions()); + + return new BedrockConverseRequest(modelId, messages, systemMessages, additionalModelRequestFields); + } + + /** + * Convert {@link Prompt} to {@link ConverseRequest} with model id and options. It + * will merge default options and runtime options to converse inference parameters. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax + * + * @param modelId The Amazon Bedrock Model Id. + * @param prompt The prompt that needs to convert. + * @return Amazon Bedrock Converse request. 
+ */ + public static ConverseRequest createConverseRequest(String modelId, Prompt prompt) { + return createConverseRequest(modelId, prompt, null); + } + + /** + * Convert {@link Prompt} to {@link ConverseRequest} with model id and options. It + * will merge default options and runtime options to converse inference parameters. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax + * + * @param modelId The Amazon Bedrock Model Id. + * @param prompt The prompt that needs to convert. + * @param defaultOptions The default options needs to convert. + * @return Amazon Bedrock Converse request. + */ + public static ConverseRequest createConverseRequest(String modelId, Prompt prompt, ChatOptions defaultOptions) { + BedrockConverseRequest bedrockConverseRequest = createBedrockConverseRequest(modelId, prompt, defaultOptions); + + return createConverseRequest(bedrockConverseRequest); + } + + /** + * Convert {@link Prompt} to {@link ConverseRequest} with model id and options. It + * will merge default options and runtime options to converse inference parameters. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax + * + * @param bedrockConverseRequest The Amazon Bedrock Converse encapsulates request. + * @return Amazon Bedrock Converse request. 
+ */ + public static ConverseRequest createConverseRequest(BedrockConverseRequest bedrockConverseRequest) { + Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null."); + + return ConverseRequest.builder() + .modelId(bedrockConverseRequest.modelId()) + .messages(bedrockConverseRequest.messages()) + .system(bedrockConverseRequest.systemMessages()) + .additionalModelRequestFields(bedrockConverseRequest.additionalModelRequestFields()) + .toolConfig(bedrockConverseRequest.toolConfiguration()) + .build(); + } + + /** + * Convert {@link Prompt} to {@link ConverseStreamRequest} with model id and options. + * It will merge default options and runtime options to converse inference parameters. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html + * + * @param modelId The Amazon Bedrock Model Id. + * @param prompt The prompt that needs to convert. + * @return Amazon Bedrock Converse stream request. + */ + public static ConverseStreamRequest createConverseStreamRequest(String modelId, Prompt prompt) { + return createConverseStreamRequest(modelId, prompt, null); + } + + /** + * Convert {@link Prompt} to {@link ConverseStreamRequest} with model id and options. + * It will merge default options and runtime options to converse inference parameters. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html + * + * @param modelId The Amazon Bedrock Model Id. + * @param prompt The prompt that needs to convert. + * @param defaultOptions The default options needs to convert. + * @return Amazon Bedrock Converse stream request. 
+ */ + public static ConverseStreamRequest createConverseStreamRequest(String modelId, Prompt prompt, + ChatOptions defaultOptions) { + BedrockConverseRequest bedrockConverseRequest = createBedrockConverseRequest(modelId, prompt, defaultOptions); + + return createConverseStreamRequest(bedrockConverseRequest); + } + + /** + * Convert {@link Prompt} to {@link ConverseStreamRequest} with model id and options. + * It will merge default options and runtime options to converse inference parameters. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html + * + * @param bedrockConverseRequest The Amazon Bedrock Converse encapsulates request. + * @return Amazon Bedrock Converse stream request. + */ + public static ConverseStreamRequest createConverseStreamRequest(BedrockConverseRequest bedrockConverseRequest) { + Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null."); + + return ConverseStreamRequest.builder() + .modelId(bedrockConverseRequest.modelId()) + .messages(bedrockConverseRequest.messages()) + .system(bedrockConverseRequest.systemMessages()) + .additionalModelRequestFields(bedrockConverseRequest.additionalModelRequestFields()) + .toolConfig(bedrockConverseRequest.toolConfiguration()) + .build(); + } + + /** + * Convert {@link ConverseResponse} to {@link ChatResponse} includes model output, + * stopReason, usage, metrics etc. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax + * + * @param response The Bedrock Converse response. + * @return The ChatResponse entity. 
+ */ + public static ChatResponse convertConverseResponse(ConverseResponse response) { + Assert.notNull(response, "'response' must not be null."); + + Message message = response.output().message(); + + String text = getConverseResponseTextContent(message.content()); + + Generation generation = new Generation(text) + .withGenerationMetadata(BedrockConverseChatGenerationMetadata.from(response, message)); + + return new ChatResponse(List.of(generation), BedrockChatResponseMetadata.from(response)); + } + + /** + * Convert {@link ConverseStreamOutput} to {@link ChatResponse} includes model output, + * stopReason, usage, metrics etc. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax + * + * @param output The Bedrock Converse stream output. + * @return The ChatResponse entity. + */ + public static ChatResponse convertConverseStreamOutput(ConverseStreamOutput output) { + if (output instanceof ContentBlockDeltaEvent contentBlockDeltaEvent) { + Generation generation = new Generation(contentBlockDeltaEvent.delta().text()) + .withGenerationMetadata(BedrockConverseChatGenerationMetadata.from(contentBlockDeltaEvent)); + + return new ChatResponse(List.of(generation)); + } else if (output instanceof MessageStopEvent messageStopEvent) { + var metadata = BedrockConverseChatGenerationMetadata.from(messageStopEvent); + + return new ChatResponse(List.of(new Generation("").withGenerationMetadata(metadata))); + } else if (output instanceof ConverseStreamMetadataEvent converseStreamMetadataEvent) { + return new ChatResponse(List.of(), BedrockChatResponseMetadata.from(converseStreamMetadataEvent)); + } else { + Generation generation = new Generation("") + .withGenerationMetadata(BedrockConverseChatGenerationMetadata.from(output)); + + return new ChatResponse(List.of(generation)); + } + } + + private static String getConverseResponseTextContent(List contents) { + Optional optional = contents.stream().filter(content -> 
content.type() == Type.TEXT).findFirst(); + + return optional.isPresent() ? optional.get().text() : ""; + } + + private static List getPromptSystemContentBlocks(Prompt prompt) { + return prompt.getInstructions() + .stream() + .filter(message -> message.getMessageType() == MessageType.SYSTEM) + .map(instruction -> SystemContentBlock.builder().text(instruction.getContent()).build()) + .toList(); + } + + public static List getInstructionsMessages( + List instructions) { + return instructions.stream() + .filter(message -> message.getMessageType() == MessageType.USER + || message.getMessageType() == MessageType.ASSISTANT) + .map(instruction -> createMessage(getInstructionContents(instruction), + instruction.getMessageType() == MessageType.USER ? ConversationRole.USER + : ConversationRole.ASSISTANT)) + .toList(); + } + + public static Message createMessage(List contentBlocks, ConversationRole role) { + return Message.builder().content(contentBlocks).role(role).build(); + } + + private static List getInstructionContents(org.springframework.ai.chat.messages.Message instruction) { + List contents = new ArrayList<>(); + + ContentBlock textContentBlock = ContentBlock.builder().text(instruction.getContent()).build(); + + contents.add(textContentBlock); + + List mediaContentBlocks = instruction.getMedia() + .stream() + .map(media -> ContentBlock.builder() + .image(ImageBlock.builder() + .format(media.getMimeType().getSubtype()) + .source(ImageSource.fromBytes(SdkBytes.fromByteArray(getContentMediaData(media.getData())))) + .build()) + .build()) + .toList(); + + contents.addAll(mediaContentBlocks); + + return contents; + } + + private static byte[] getContentMediaData(Object mediaData) { + if (mediaData instanceof byte[] bytes) { + return bytes; + } else if (mediaData instanceof String text) { + return text.getBytes(); + } else { + throw new IllegalArgumentException("Unsupported media data type: " + mediaData.getClass().getSimpleName()); + } + } + + 
@SuppressWarnings("unchecked") + private static Document getChatOptionsAdditionalModelRequestFields(ChatOptions defaultOptions, + ModelOptions promptOptions) { + if (defaultOptions == null && promptOptions == null) { + return null; + } + + Map attributes = new HashMap<>(); + + if (defaultOptions != null) { + Map options = objectMapper.convertValue(defaultOptions, Map.class); + + attributes.putAll(options); + } + + if (promptOptions != null) { + if (promptOptions instanceof ChatOptions runtimeOptions) { + Map options = objectMapper.convertValue(runtimeOptions, Map.class); + + attributes.putAll(options); + } else { + throw new IllegalArgumentException( + "Prompt options are not of type ChatOptions:" + promptOptions.getClass().getSimpleName()); + } + } + + return convertObjectToDocument(attributes); + } + + @SuppressWarnings("unchecked") + public static Document convertObjectToDocument(Object value) { + if (value == null) { + return Document.fromNull(); + } else if (value instanceof String stringValue) { + return Document.fromString(stringValue); + } else if (value instanceof Boolean booleanValue) { + return Document.fromBoolean(booleanValue); + } else if (value instanceof Integer integerValue) { + return Document.fromNumber(integerValue); + } else if (value instanceof Long longValue) { + return Document.fromNumber(longValue); + } else if (value instanceof Float floatValue) { + return Document.fromNumber(floatValue); + } else if (value instanceof Double doubleValue) { + return Document.fromNumber(doubleValue); + } else if (value instanceof BigDecimal bigDecimalValue) { + return Document.fromNumber(bigDecimalValue); + } else if (value instanceof BigInteger bigIntegerValue) { + return Document.fromNumber(bigIntegerValue); + } else if (value instanceof List listValue) { + return Document.fromList(listValue.stream().map(v -> convertObjectToDocument(v)).toList()); + } else if (value instanceof Map mapValue) { + return convertMapToDocument(mapValue); + } else { + throw new 
IllegalArgumentException("Unsupported value type:" + value.getClass().getSimpleName()); + } + } + + private static Document convertMapToDocument(Map value) { + Map attr = value.entrySet() + .stream() + .collect(Collectors.toMap(e -> e.getKey(), e -> convertObjectToDocument(e.getValue()))); + + return Document.fromMap(attr); + } + +} +//@formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java index e9895fc1db..41019abb15 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java @@ -15,97 +15,74 @@ */ package org.springframework.ai.bedrock.cohere; -import java.util.List; - import reactor.core.publisher.Flux; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; -import org.springframework.ai.bedrock.BedrockUsage; -import org.springframework.ai.bedrock.MessageToPromptConverter; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.api.BedrockConverseApiUtils; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import 
org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.ModelDescription; import org.springframework.util.Assert; /** + * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Cohere chat + * generative model. + * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockCohereChatModel implements ChatModel, StreamingChatModel { - private final CohereChatBedrockApi chatApi; + private final String modelId; + + private final BedrockConverseApi converseApi; private final BedrockCohereChatOptions defaultOptions; - public BedrockCohereChatModel(CohereChatBedrockApi chatApi) { - this(chatApi, BedrockCohereChatOptions.builder().build()); + public BedrockCohereChatModel(BedrockConverseApi converseApi) { + this(converseApi, BedrockCohereChatOptions.builder().build()); } - public BedrockCohereChatModel(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options) { - Assert.notNull(chatApi, "CohereChatBedrockApi must not be null"); + public BedrockCohereChatModel(BedrockConverseApi converseApi, BedrockCohereChatOptions options) { + this(CohereChatModel.COHERE_COMMAND_V14.id(), converseApi, options); + } + + public BedrockCohereChatModel(String modelId, BedrockConverseApi converseApi, BedrockCohereChatOptions options) { + Assert.notNull(modelId, "modelId must not be null."); + Assert.notNull(converseApi, "BedrockConverseApi must not be null."); Assert.notNull(options, "BedrockCohereChatOptions must not be null"); - this.chatApi = chatApi; + this.modelId = modelId; + this.converseApi = converseApi; this.defaultOptions = options; } @Override public ChatResponse call(Prompt prompt) { - CohereChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false)); - List generations = response.generations().stream().map(g 
-> { - return new Generation(g.text()); - }).toList(); + Assert.notNull(prompt, "Prompt must not be null."); + + var request = BedrockConverseApiUtils.createConverseRequest(modelId, prompt, defaultOptions); + + ConverseResponse response = this.converseApi.converse(request); - return new ChatResponse(generations); + return BedrockConverseApiUtils.convertConverseResponse(response); } @Override public Flux stream(Prompt prompt) { - return this.chatApi.chatCompletionStream(this.createRequest(prompt, true)).map(g -> { - if (g.isFinished()) { - String finishReason = g.finishReason().name(); - Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics()); - return new ChatResponse(List - .of(new Generation("").withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage)))); - } - return new ChatResponse(List.of(new Generation(g.text()))); - }); - } + Assert.notNull(prompt, "Prompt must not be null."); - /** - * Test access. - */ - CohereChatRequest createRequest(Prompt prompt, boolean stream) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); - - var request = CohereChatRequest.builder(promptValue) - .withTemperature(this.defaultOptions.getTemperature()) - .withTopP(this.defaultOptions.getTopP()) - .withTopK(this.defaultOptions.getTopK()) - .withMaxTokens(this.defaultOptions.getMaxTokens()) - .withStopSequences(this.defaultOptions.getStopSequences()) - .withReturnLikelihoods(this.defaultOptions.getReturnLikelihoods()) - .withStream(stream) - .withNumGenerations(this.defaultOptions.getNumGenerations()) - .withLogitBias(this.defaultOptions.getLogitBias()) - .withTruncate(this.defaultOptions.getTruncate()) - .build(); - - if (prompt.getOptions() != null) { - BedrockCohereChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), - ChatOptions.class, BedrockCohereChatOptions.class); - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, CohereChatRequest.class); - } + 
var request = BedrockConverseApiUtils.createConverseStreamRequest(modelId, prompt, defaultOptions); + + Flux fluxResponse = this.converseApi.converseStream(request); - return request; + return fluxResponse.map(output -> BedrockConverseApiUtils.convertConverseStreamOutput(output)); } @Override @@ -113,4 +90,39 @@ public ChatOptions getDefaultOptions() { return BedrockCohereChatOptions.fromOptions(this.defaultOptions); } + /** + * Cohere models version. + */ + public enum CohereChatModel implements ModelDescription { + + /** + * cohere.command-light-text-v14 + */ + COHERE_COMMAND_LIGHT_V14("cohere.command-light-text-v14"), + + /** + * cohere.command-text-v14 + */ + COHERE_COMMAND_V14("cohere.command-text-v14"); + + private final String id; + + /** + * @return The model id. + */ + public String id() { + return id; + } + + CohereChatModel(String value) { + this.id = value; + } + + @Override + public String getModelName() { + return this.id; + } + + } + } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java index e0ab181cc9..f2d88efb24 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java @@ -21,13 +21,15 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.LogitBias; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.ReturnLikelihoods; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; import org.springframework.ai.chat.prompt.ChatOptions; /** + * Java {@link ChatOptions} for the 
Bedrock Cohere Command chat generative model chat + * options. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html + * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) @@ -213,6 +215,59 @@ public void setTruncate(Truncate truncate) { this.truncate = truncate; } + /** + * Specify how and if the token likelihoods are returned with the response. + */ + public static enum ReturnLikelihoods { + + /** + * Only return likelihoods for generated tokens. + */ + GENERATION, + /** + * Return likelihoods for all tokens. + */ + ALL, + /** + * (Default) Don't return any likelihoods. + */ + NONE + + } + + /** + * Specifies how the API handles inputs longer than the maximum token length. If you + * specify START or END, the model discards the input until the remaining input is + * exactly the maximum input token length for the model. + */ + public enum Truncate { + + /** + * Returns an error when the input exceeds the maximum input token length. + */ + NONE, + /** + * Discard the start of the input. + */ + START, + /** + * (Default) Discards the end of the input. + */ + END + + } + + /** + * Prevents the model from generating unwanted tokens or incentivize the model to + * include desired tokens. + * + * @param token The token likelihoods. + * @param bias A float between -10 and 10. 
+ */ + @JsonInclude(Include.NON_NULL) + public record LogitBias(@JsonProperty("token") Integer token, @JsonProperty("bias") Float bias) { + } + public static BedrockCohereChatOptions fromOptions(BedrockCohereChatOptions fromOptions) { return builder().withTemperature(fromOptions.getTemperature()) .withTopP(fromOptions.getTopP()) diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java new file mode 100644 index 0000000000..6395e09c92 --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java @@ -0,0 +1,283 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.cohere; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.springframework.ai.bedrock.BedrockConverseChatGenerationMetadata; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.api.BedrockConverseApiUtils; +import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +import reactor.core.publisher.Flux; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.StopReason; +import software.amazon.awssdk.services.bedrockruntime.model.Tool; +import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultStatus; +import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification; +import 
software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type; + +/** + * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Command R chat + * generative model. + * + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockCohereCommandRChatModel + extends AbstractFunctionCallSupport + implements ChatModel, StreamingChatModel { + + private final String modelId; + + private final BedrockConverseApi converseApi; + + private final BedrockCohereCommandRChatOptions defaultOptions; + + public BedrockCohereCommandRChatModel(BedrockConverseApi converseApi) { + this(converseApi, BedrockCohereCommandRChatOptions.builder().build()); + } + + public BedrockCohereCommandRChatModel(BedrockConverseApi converseApi, BedrockCohereCommandRChatOptions options) { + this(CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), converseApi, options); + } + + public BedrockCohereCommandRChatModel(String modelId, BedrockConverseApi converseApi, + BedrockCohereCommandRChatOptions options) { + this(modelId, converseApi, options, null); + } + + public BedrockCohereCommandRChatModel(String modelId, BedrockConverseApi converseApi, + BedrockCohereCommandRChatOptions options, FunctionCallbackContext functionCallbackContext) { + super(functionCallbackContext); + + Assert.notNull(modelId, "modelId must not be null."); + Assert.notNull(converseApi, "BedrockConverseApi must not be null."); + Assert.notNull(options, "BedrockCohereCommandRChatOptions must not be null."); + + this.modelId = modelId; + this.converseApi = converseApi; + this.defaultOptions = options; + } + + @Override + public ChatResponse call(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); + + var request = createBedrockConverseRequest(prompt); + + return this.callWithFunctionSupport(request); + } + + @Override + public Flux stream(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); + + 
var request = createBedrockConverseRequest(prompt); + + return converseApi.converseStream(request); + } + + private BedrockConverseRequest createBedrockConverseRequest(Prompt prompt) { + var request = BedrockConverseApiUtils.createBedrockConverseRequest(modelId, prompt, defaultOptions); + + ToolConfiguration toolConfiguration = createToolConfiguration(prompt); + + return BedrockConverseRequest.from(request).withToolConfiguration(toolConfiguration).build(); + } + + private ToolConfiguration createToolConfiguration(Prompt prompt) { + Set functionsForThisRequest = new HashSet<>(); + + if (this.defaultOptions != null) { + Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, + !IS_RUNTIME_CALL); + functionsForThisRequest.addAll(promptEnabledFunctions); + } + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { + BedrockCohereCommandRChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, + ChatOptions.class, BedrockCohereCommandRChatOptions.class); + + Set defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, + IS_RUNTIME_CALL); + functionsForThisRequest.addAll(defaultEnabledFunctions); + } + else { + throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + + prompt.getOptions().getClass().getSimpleName()); + } + } + + if (!CollectionUtils.isEmpty(functionsForThisRequest)) { + return ToolConfiguration.builder().tools(getFunctionTools(functionsForThisRequest)).build(); + } + + return null; + } + + private List getFunctionTools(Set functionNames) { + return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { + var description = functionCallback.getDescription(); + var name = functionCallback.getName(); + String inputSchema = functionCallback.getInputTypeSchema(); + + return Tool.builder() + .toolSpec(ToolSpecification.builder() + .name(name) + 
.description(description) + .inputSchema(ToolInputSchema.builder() + .json(BedrockConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema))) + .build()) + .build()) + .build(); + }).toList(); + } + + @Override + public ChatOptions getDefaultOptions() { + return BedrockCohereCommandRChatOptions.fromOptions(defaultOptions); + } + + @Override + protected BedrockConverseRequest doCreateToolResponseRequest(BedrockConverseRequest previousRequest, + Message responseMessage, List conversationHistory) { + List toolToUseList = responseMessage.content() + .stream() + .filter(content -> content.type() == Type.TOOL_USE) + .map(content -> content.toolUse()) + .toList(); + + List toolResults = new ArrayList<>(); + + for (ToolUseBlock toolToUse : toolToUseList) { + var functionCallId = toolToUse.toolUseId(); + var functionName = toolToUse.name(); + var functionArguments = toolToUse.input().unwrap(); + + if (!this.functionCallbackRegister.containsKey(functionName)) { + throw new IllegalStateException("No function callback found for function name: " + functionName); + } + + String functionResponse = this.functionCallbackRegister.get(functionName) + .call(ModelOptionsUtils.toJsonString(functionArguments)); + + toolResults.add(ToolResultBlock.builder() + .toolUseId(functionCallId) + .status(ToolResultStatus.SUCCESS) + .content(ToolResultContentBlock.builder().text(functionResponse).build()) + .build()); + } + + // Add the function response to the conversation. + Message toolResultMessage = Message.builder() + .content(toolResults.stream().map(toolResult -> ContentBlock.fromToolResult(toolResult)).toList()) + .role(ConversationRole.USER) + .build(); + conversationHistory.add(toolResultMessage); + + // Recursively call chatCompletionWithTools until the model doesn't call a + // functions anymore. 
+ return BedrockConverseRequest.from(previousRequest).withMessages(conversationHistory).build(); + } + + @Override + protected List doGetUserMessages(BedrockConverseRequest request) { + return request.messages(); + } + + @Override + protected Message doGetToolResponseMessage(ChatResponse response) { + Generation result = response.getResult(); + + var metadata = (BedrockConverseChatGenerationMetadata) result.getMetadata(); + + return metadata.getMessage(); + } + + @Override + protected ChatResponse doChatCompletion(BedrockConverseRequest request) { + return converseApi.converse(request); + } + + @Override + protected Flux doChatCompletionStream(BedrockConverseRequest request) { + throw new UnsupportedOperationException("Streaming function calling is not supported."); + } + + @Override + protected boolean isToolFunctionCall(ChatResponse response) { + Generation result = response.getResult(); + if (result == null) { + return false; + } + + return StopReason.fromValue(result.getMetadata().getFinishReason()) == StopReason.TOOL_USE; + } + + /** + * Cohere command R models version. + */ + public enum CohereCommandRChatModel { + + /** + * cohere.command-r-v1:0 + */ + COHERE_COMMAND_R_V1("cohere.command-r-v1:0"), + + /** + * cohere.command-r-plus-v1:0 + */ + COHERE_COMMAND_R_PLUS_V1("cohere.command-r-plus-v1:0"); + + private final String id; + + /** + * @return The model id. 
+ */ + public String id() { + return id; + } + + CohereCommandRChatModel(String value) { + this.id = value; + } + + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java new file mode 100644 index 0000000000..9aa18849ea --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java @@ -0,0 +1,388 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.cohere; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonInclude.Include; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + +/** + * Java {@link ChatOptions} for the Bedrock Cohere Command R chat generative model chat + * options. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html + * + * @author Wei Jiang + * @since 1.0.0 + */ +@JsonInclude(Include.NON_NULL) +public class BedrockCohereCommandRChatOptions implements ChatOptions, FunctionCallingOptions { + + // @formatter:off + /** + * (optional) When enabled, it will only generate potential search queries without performing + * searches or providing a response. + */ + @JsonProperty("search_queries_only") Boolean searchQueriesOnly; + /** + * (optional) Overrides the default preamble for search query generation. + */ + @JsonProperty("preamble") String preamble; + /** + * (optional) Specify the maximum number of tokens to use in the generated response. + */ + @JsonProperty("max_tokens") Integer maxTokens; + /** + * (optional) Use a lower value to decrease randomness in the response. + */ + @JsonProperty("temperature") Float temperature; + /** + * Top P. Use a lower value to ignore less probable options. Set to 0 or 1.0 to disable. + */ + @JsonProperty("p") Float topP; + /** + * Top K. Specify the number of token choices the model uses to generate the next token. 
+ */ + @JsonProperty("k") Integer topK; + /** + * (optional) Dictates how the prompt is constructed. + */ + @JsonProperty("prompt_truncation") PromptTruncation promptTruncation; + /** + * (optional) Used to reduce repetitiveness of generated tokens. + */ + @JsonProperty("frequency_penalty") Float frequencyPenalty; + /** + * (optional) Used to reduce repetitiveness of generated tokens. + */ + @JsonProperty("presence_penalty") Float presencePenalty; + /** + * (optional) Specify the best effort to sample tokens deterministically. + */ + @JsonProperty("seed") Integer seed; + /** + * (optional) Specify true to return the full prompt that was sent to the model. + */ + @JsonProperty("return_prompt") Boolean returnPrompt; + /** + * (optional) A list of stop sequences. + */ + @JsonProperty("stop_sequences") List stopSequences; + /** + * (optional) Specify true, to send the user’s message to the model without any preprocessing. + */ + @JsonProperty("raw_prompting") Boolean rawPrompting; + + /** + * Tool Function Callbacks to register with the ChatModel. For Prompt Options the + * functionCallbacks are automatically enabled for the duration of the prompt + * execution. For Default Options the functionCallbacks are registered but disabled by + * default. Use the enableFunctions to set the functions from the registry to be used + * by the ChatModel chat completion requests. + */ + @NestedConfigurationProperty + @JsonIgnore + private List functionCallbacks = new ArrayList<>(); + + /** + * List of functions, identified by their names, to configure for function calling in + * the chat completion requests. Functions with those names must exist in the + * functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions + * are automatically enabled for the duration of the prompt execution. + * + * Note that function enabled with the default options are enabled for all chat + * completion requests. This could impact the token count and the billing. 
If the + * functions is set in a prompt options, then the enabled functions are only active + * for the duration of this prompt execution. + */ + @NestedConfigurationProperty + @JsonIgnore + private Set functions = new HashSet<>(); + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final BedrockCohereCommandRChatOptions options = new BedrockCohereCommandRChatOptions(); + + public Builder withSearchQueriesOnly(Boolean searchQueriesOnly) { + options.setSearchQueriesOnly(searchQueriesOnly); + return this; + } + + public Builder withPreamble(String preamble) { + options.setPreamble(preamble); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + options.setMaxTokens(maxTokens); + return this; + } + + public Builder withTemperature(Float temperature) { + options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Float topP) { + options.setTopP(topP); + return this; + } + + public Builder withTopK(Integer topK) { + options.setTopK(topK); + return this; + } + + public Builder withPromptTruncation(PromptTruncation promptTruncation) { + options.setPromptTruncation(promptTruncation); + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + options.setFrequencyPenalty(frequencyPenalty); + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + options.setPresencePenalty(presencePenalty); + return this; + } + + public Builder withSeed(Integer seed) { + options.setSeed(seed); + return this; + } + + public Builder withReturnPrompt(Boolean returnPrompt) { + options.setReturnPrompt(returnPrompt); + return this; + } + + public Builder withStopSequences(List stopSequences) { + options.setStopSequences(stopSequences); + return this; + } + + public Builder withRawPrompting(Boolean rawPrompting) { + options.setRawPrompting(rawPrompting); + return this; + } + + public Builder withFunctionCallbacks(List 
functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public BedrockCohereCommandRChatOptions build() { + return this.options; + } + + } + + public Boolean getSearchQueriesOnly() { + return searchQueriesOnly; + } + + public void setSearchQueriesOnly(Boolean searchQueriesOnly) { + this.searchQueriesOnly = searchQueriesOnly; + } + + public String getPreamble() { + return preamble; + } + + public void setPreamble(String preamble) { + this.preamble = preamble; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Float getTemperature() { + return temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public PromptTruncation getPromptTruncation() { + return promptTruncation; + } + + public void setPromptTruncation(PromptTruncation promptTruncation) { + this.promptTruncation = promptTruncation; + } + + public Float getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Float getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + 
public Integer getSeed() { + return seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public Boolean getReturnPrompt() { + return returnPrompt; + } + + public void setReturnPrompt(Boolean returnPrompt) { + this.returnPrompt = returnPrompt; + } + + public List getStopSequences() { + return stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + public Boolean getRawPrompting() { + return rawPrompting; + } + + public void setRawPrompting(Boolean rawPrompting) { + this.rawPrompting = rawPrompting; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + @Override + public void setFunctions(Set functions) { + Assert.notNull(functions, "Function must not be null"); + this.functions = functions; + } + + /** + * Specifies how the prompt is constructed. + */ + public enum PromptTruncation { + + /** + * Some elements from chat_history and documents will be dropped to construct a + * prompt that fits within the model's context length limit. + */ + AUTO_PRESERVE_ORDER, + /** + * (Default) No elements will be dropped. 
+ */ + OFF + + } + + public static BedrockCohereCommandRChatOptions fromOptions(BedrockCohereCommandRChatOptions fromOptions) { + return builder().withSearchQueriesOnly(fromOptions.getSearchQueriesOnly()) + .withPreamble(fromOptions.getPreamble()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withPromptTruncation(fromOptions.getPromptTruncation()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withSeed(fromOptions.getSeed()) + .withReturnPrompt(fromOptions.getReturnPrompt()) + .withStopSequences(fromOptions.getStopSequences()) + .withRawPrompting(fromOptions.getRawPrompting()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .build(); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java index 25e3f35b43..b074cc9c4b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java @@ -28,6 +28,8 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** @@ -44,6 +46,11 @@ public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel { private final BedrockCohereEmbeddingOptions defaultOptions; + /** + * The retry template used to retry the Bedrock API calls. 
+ */ + private final RetryTemplate retryTemplate; + // private CohereEmbeddingRequest.InputType inputType = // CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT; @@ -60,10 +67,18 @@ public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedr public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi, BedrockCohereEmbeddingOptions options) { + this(cohereEmbeddingBedrockApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi, + BedrockCohereEmbeddingOptions options, RetryTemplate retryTemplate) { Assert.notNull(cohereEmbeddingBedrockApi, "CohereEmbeddingBedrockApi must not be null"); - Assert.notNull(options, "BedrockCohereEmbeddingOptions must not be null"); + Assert.notNull(options, "DefaultOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.embeddingApi = cohereEmbeddingBedrockApi; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } // /** @@ -104,13 +119,16 @@ public EmbeddingResponse call(EmbeddingRequest request) { var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), optionsToUse.getInputType(), optionsToUse.getTruncate()); - CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest); - var indexCounter = new AtomicInteger(0); - List embeddings = apiResponse.embeddings() - .stream() - .map(e -> new Embedding(e, indexCounter.getAndIncrement())) - .toList(); - return new EmbeddingResponse(embeddings); + + return this.retryTemplate.execute(ctx -> { + CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest); + var indexCounter = new AtomicInteger(0); + List embeddings = apiResponse.embeddings() + .stream() + .map(e -> new Embedding(e, indexCounter.getAndIncrement())) + .toList(); + return new EmbeddingResponse(embeddings); + }); } /** diff --git 
a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java deleted file mode 100644 index 766271b87c..0000000000 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java +++ /dev/null @@ -1,413 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// @formatter:off -package org.springframework.ai.bedrock.cohere.api; - -import java.time.Duration; -import java.util.List; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse; -import org.springframework.ai.model.ModelDescription; -import org.springframework.util.Assert; - -/** - * Java client for the Bedrock Cohere chat model. 
- * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere.html - * - * @author Christian Tzolov - * @author Wei Jiang - * @since 0.8.0 - */ -public class CohereChatBedrockApi extends - AbstractBedrockApi { - - /** - * Create a new CohereChatBedrockApi instance using the default credentials provider chain, the default object - * mapper, default temperature and topP values. - * - * @param modelId The model id to use. See the {@link CohereChatModel} for the supported models. - * @param region The AWS region to use. - */ - public CohereChatBedrockApi(String modelId, String region) { - super(modelId, region); - } - - /** - * Create a new CohereChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link CohereChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - */ - public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper) { - super(modelId, credentialsProvider, region, objectMapper); - } - - /** - * Create a new CohereChatBedrockApi instance using the default credentials provider chain, the default object - * mapper, default temperature and topP values. - * - * @param modelId The model id to use. See the {@link CohereChatModel} for the supported models. - * @param region The AWS region to use. - * @param timeout The timeout to use. - */ - public CohereChatBedrockApi(String modelId, String region, Duration timeout) { - super(modelId, region, timeout); - } - - /** - * Create a new CohereChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link CohereChatModel} for the supported models. 
- * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * Create a new CohereChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link CohereChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * CohereChatRequest encapsulates the request parameters for the Cohere command model. - * - * @param prompt The input prompt to generate the response from. - * @param temperature (optional) Use a lower value to decrease randomness in the response. - * @param topP (optional) Use a lower value to ignore less probable options. Set to 0 or 1.0 to disable. - * @param topK (optional) Specify the number of token choices the model uses to generate the next token. - * @param maxTokens (optional) Specify the maximum number of tokens to use in the generated response. - * @param stopSequences (optional) Configure up to four sequences that the model recognizes. After a stop sequence, - * the model stops generating further tokens. The returned text doesn't contain the stop sequence. 
- * @param returnLikelihoods (optional) Specify how and if the token likelihoods are returned with the response. - * @param stream (optional) Specify true to return the response piece-by-piece in real-time and false to return the - * complete response after the process finishes. - * @param numGenerations (optional) The maximum number of generations that the model should return. - * @param logitBias (optional) prevents the model from generating unwanted tokens or incentivize the model to - * include desired tokens. The format is {token_id: bias} where bias is a float between -10 and 10. Tokens can be - * obtained from text using any tokenization service, such as Cohere’s Tokenize endpoint. - * @param truncate (optional) Specifies how the API handles inputs longer than the maximum token length. - */ - @JsonInclude(Include.NON_NULL) - public record CohereChatRequest( - @JsonProperty("prompt") String prompt, - @JsonProperty("temperature") Float temperature, - @JsonProperty("p") Float topP, - @JsonProperty("k") Integer topK, - @JsonProperty("max_tokens") Integer maxTokens, - @JsonProperty("stop_sequences") List stopSequences, - @JsonProperty("return_likelihoods") ReturnLikelihoods returnLikelihoods, - @JsonProperty("stream") boolean stream, - @JsonProperty("num_generations") Integer numGenerations, - @JsonProperty("logit_bias") LogitBias logitBias, - @JsonProperty("truncate") Truncate truncate) { - - /** - * Prevents the model from generating unwanted tokens or incentivize the model to include desired tokens. - * - * @param token The token likelihoods. - * @param bias A float between -10 and 10. - */ - @JsonInclude(Include.NON_NULL) - public record LogitBias( - @JsonProperty("token") String token, - @JsonProperty("bias") Float bias) { - } - - /** - * (optional) Specify how and if the token likelihoods are returned with the response. - */ - public enum ReturnLikelihoods { - /** - * Only return likelihoods for generated tokens. 
- */ - GENERATION, - /** - * Return likelihoods for all tokens. - */ - ALL, - /** - * (Default) Don't return any likelihoods. - */ - NONE - } - - /** - * Specifies how the API handles inputs longer than the maximum token length. If you specify START or END, the - * model discards the input until the remaining input is exactly the maximum input token length for the model. - */ - public enum Truncate { - /** - * Returns an error when the input exceeds the maximum input token length. - */ - NONE, - /** - * Discard the start of the input. - */ - START, - /** - * (Default) Discards the end of the input. - */ - END - } - - /** - * Get CohereChatRequest builder. - * @param prompt compulsory request prompt parameter. - * @return CohereChatRequest builder. - */ - public static Builder builder(String prompt) { - return new Builder(prompt); - } - - /** - * Builder for the CohereChatRequest. - */ - public static class Builder { - private final String prompt; - private Float temperature; - private Float topP; - private Integer topK; - private Integer maxTokens; - private List stopSequences; - private ReturnLikelihoods returnLikelihoods; - private boolean stream; - private Integer numGenerations; - private LogitBias logitBias; - private Truncate truncate; - - public Builder(String prompt) { - this.prompt = prompt; - } - - public Builder withTemperature(Float temperature) { - this.temperature = temperature; - return this; - } - - public Builder withTopP(Float topP) { - this.topP = topP; - return this; - } - - public Builder withTopK(Integer topK) { - this.topK = topK; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.stopSequences = stopSequences; - return this; - } - - public Builder withReturnLikelihoods(ReturnLikelihoods returnLikelihoods) { - this.returnLikelihoods = returnLikelihoods; - return this; - } - - public Builder 
withStream(boolean stream) { - this.stream = stream; - return this; - } - - public Builder withNumGenerations(Integer numGenerations) { - this.numGenerations = numGenerations; - return this; - } - - public Builder withLogitBias(LogitBias logitBias) { - this.logitBias = logitBias; - return this; - } - - public Builder withTruncate(Truncate truncate) { - this.truncate = truncate; - return this; - } - - public CohereChatRequest build() { - return new CohereChatRequest( - prompt, - temperature, - topP, - topK, - maxTokens, - stopSequences, - returnLikelihoods, - stream, - numGenerations, - logitBias, - truncate - ); - } - } - } - - /** - * CohereChatResponse encapsulates the response parameters for the Cohere command model. - * - * @param id An identifier for the request (always returned). - * @param prompt The prompt from the input request. (Always returned). - * @param generations A list of generated results along with the likelihoods for tokens requested. (Always - * returned). - */ - @JsonInclude(Include.NON_NULL) - public record CohereChatResponse( - @JsonProperty("id") String id, - @JsonProperty("prompt") String prompt, - @JsonProperty("generations") List generations) { - - /** - * Generated result along with the likelihoods for tokens requested. - * - * @param id An identifier for the generation. (Always returned). - * @param likelihood The likelihood of the output. The value is the average of the token likelihoods in - * token_likelihoods. Returned if you specify the return_likelihoods input parameter. - * @param tokenLikelihoods An array of per token likelihoods. Returned if you specify the return_likelihoods - * input parameter. - * @param finishReason states the reason why the model finished generating tokens. - * @param isFinished A boolean field used only when stream is true, signifying whether or not there are - * additional tokens that will be generated as part of the streaming response. (Not always returned). - * @param text The generated text. 
- * @param index In a streaming response, use to determine which generation a given token belongs to. When only - * one response is streamed, all tokens belong to the same generation and index is not returned. index therefore - * is only returned in a streaming request with a value for num_generations that is larger than one. - * @param amazonBedrockInvocationMetrics Encapsulates the metrics about the model invocation. - */ - @JsonInclude(Include.NON_NULL) - public record Generation( - @JsonProperty("id") String id, - @JsonProperty("likelihood") Float likelihood, - @JsonProperty("token_likelihoods") List tokenLikelihoods, - @JsonProperty("finish_reason") FinishReason finishReason, - @JsonProperty("is_finished") Boolean isFinished, - @JsonProperty("text") String text, - @JsonProperty("index") Integer index, - @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { - - /** - * @param token The token. - * @param likelihood The likelihood of the token. - */ - @JsonInclude(Include.NON_NULL) - public record TokenLikelihood( - @JsonProperty("token") String token, - @JsonProperty("likelihood") Float likelihood) { - } - - /** - * The reason the response finished being generated. - */ - public enum FinishReason { - /** - * The model sent back a finished reply. - */ - COMPLETE, - /** - * The reply was cut off because the model reached the maximum number of tokens for its context length. - */ - MAX_TOKENS, - /** - * Something went wrong when generating the reply. - */ - ERROR, - /** - * the model generated a reply that was deemed toxic. finish_reason is returned only when - * is_finished=true. (Not always returned). - */ - ERROR_TOXIC - } - } - } - - /** - * Cohere models version. 
- */ - public enum CohereChatModel implements ModelDescription { - - /** - * cohere.command-light-text-v14 - */ - COHERE_COMMAND_LIGHT_V14("cohere.command-light-text-v14"), - - /** - * cohere.command-text-v14 - */ - COHERE_COMMAND_V14("cohere.command-text-v14"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - CohereChatModel(String value) { - this.id = value; - } - - @Override - public String getModelName() { - return this.id; - } - } - - @Override - public CohereChatResponse chatCompletion(CohereChatRequest request) { - Assert.isTrue(!request.stream(), "The request must be configured to return the complete response!"); - return this.internalInvocation(request, CohereChatResponse.class); - } - - @Override - public Flux chatCompletionStream(CohereChatRequest request) { - Assert.isTrue(request.stream(), "The request must be configured to stream the response!"); - return this.internalInvocationStream(request, CohereChatResponse.Generation.class); - } -} -// @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java index 1ae4242275..a4022f8761 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java @@ -25,6 +25,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import org.springframework.ai.bedrock.api.AbstractBedrockApi; 
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest; @@ -109,6 +111,19 @@ public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credenti super(modelId, credentialsProvider, region, objectMapper, timeout); } + /** + * Create a new CohereEmbeddingBedrockApi instance using the provided AWS Bedrock runtime clients and object mapper. + * + * @param model The model id to use. + * @param bedrockRuntimeClient The AWS BedrockRuntimeClient instance. + * @param bedrockRuntimeAsyncClient The AWS BedrockRuntimeAsyncClient instance. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + */ + public CohereEmbeddingBedrockApi(String model, BedrockRuntimeClient bedrockRuntimeClient, + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ObjectMapper objectMapper) { + super(model, bedrockRuntimeClient, bedrockRuntimeAsyncClient, objectMapper); + } + /** * The Cohere Embed model request. * diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java index b6f9d6cbc5..f18d9957c2 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java @@ -16,40 +16,34 @@ package org.springframework.ai.bedrock.jurassic2; -import org.springframework.ai.bedrock.MessageToPromptConverter; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.api.BedrockConverseApiUtils; import 
org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.ModelDescription; import org.springframework.util.Assert; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; + /** * Java {@link ChatModel} for the Bedrock Jurassic2 chat generative model. * * @author Ahmed Yousri + * @author Wei Jiang * @since 1.0.0 */ public class BedrockAi21Jurassic2ChatModel implements ChatModel { - private final Ai21Jurassic2ChatBedrockApi chatApi; - - private final BedrockAi21Jurassic2ChatOptions defaultOptions; + private final String modelId; - public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi, BedrockAi21Jurassic2ChatOptions options) { - Assert.notNull(chatApi, "Ai21Jurassic2ChatBedrockApi must not be null"); - Assert.notNull(options, "BedrockAi21Jurassic2ChatOptions must not be null"); + private final BedrockConverseApi converseApi; - this.chatApi = chatApi; - this.defaultOptions = options; - } + private final BedrockAi21Jurassic2ChatOptions defaultOptions; - public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi) { - this(chatApi, + public BedrockAi21Jurassic2ChatModel(BedrockConverseApi converseApi) { + this(converseApi, BedrockAi21Jurassic2ChatOptions.builder() .withTemperature(0.8f) .withTopP(0.9f) @@ -57,66 +51,70 @@ public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi) { .build()); } - @Override - public ChatResponse call(Prompt prompt) { - var request = createRequest(prompt); - var response = this.chatApi.chatCompletion(request); - - return new ChatResponse(response.completions() - .stream() - .map(completion -> new 
Generation(completion.data().text()) - .withGenerationMetadata(ChatGenerationMetadata.from(completion.finishReason().reason(), null))) - .toList()); + public BedrockAi21Jurassic2ChatModel(BedrockConverseApi converseApi, BedrockAi21Jurassic2ChatOptions options) { + this(Ai21Jurassic2ChatModel.AI21_J2_MID_V1.id(), converseApi, options); } - private Ai21Jurassic2ChatRequest createRequest(Prompt prompt) { + public BedrockAi21Jurassic2ChatModel(String modelId, BedrockConverseApi converseApi, + BedrockAi21Jurassic2ChatOptions options) { + Assert.notNull(modelId, "modelId must not be null."); + Assert.notNull(converseApi, "BedrockConverseApi must not be null."); + Assert.notNull(options, "BedrockAi21Jurassic2ChatOptions must not be null."); - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); + this.modelId = modelId; + this.converseApi = converseApi; + this.defaultOptions = options; + } - Ai21Jurassic2ChatRequest request = Ai21Jurassic2ChatRequest.builder(promptValue).build(); + @Override + public ChatResponse call(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); - if (prompt.getOptions() != null) { - BedrockAi21Jurassic2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), - ChatOptions.class, BedrockAi21Jurassic2ChatOptions.class); - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, Ai21Jurassic2ChatRequest.class); - } + var request = BedrockConverseApiUtils.createConverseRequest(modelId, prompt, defaultOptions); - if (this.defaultOptions != null) { - request = ModelOptionsUtils.merge(request, this.defaultOptions, Ai21Jurassic2ChatRequest.class); - } + ConverseResponse response = this.converseApi.converse(request); - return request; + return BedrockConverseApiUtils.convertConverseResponse(response); } - public static Builder builder(Ai21Jurassic2ChatBedrockApi chatApi) { - return new Builder(chatApi); + @Override + public ChatOptions getDefaultOptions() 
{ + return BedrockAi21Jurassic2ChatOptions.fromOptions(this.defaultOptions); } - public static class Builder { + /** + * Ai21 Jurassic2 models version. + */ + public enum Ai21Jurassic2ChatModel implements ModelDescription { - private final Ai21Jurassic2ChatBedrockApi chatApi; + /** + * ai21.j2-mid-v1 + */ + AI21_J2_MID_V1("ai21.j2-mid-v1"), - private BedrockAi21Jurassic2ChatOptions options; + /** + * ai21.j2-ultra-v1 + */ + AI21_J2_ULTRA_V1("ai21.j2-ultra-v1"); - public Builder(Ai21Jurassic2ChatBedrockApi chatApi) { - this.chatApi = chatApi; - } + private final String id; - public Builder withOptions(BedrockAi21Jurassic2ChatOptions options) { - this.options = options; - return this; + /** + * @return The model id. + */ + public String id() { + return id; } - public BedrockAi21Jurassic2ChatModel build() { - return new BedrockAi21Jurassic2ChatModel(chatApi, - options != null ? options : BedrockAi21Jurassic2ChatOptions.builder().build()); + Ai21Jurassic2ChatModel(String value) { + this.id = value; } - } + @Override + public String getModelName() { + return this.id; + } - @Override - public ChatOptions getDefaultOptions() { - return BedrockAi21Jurassic2ChatOptions.fromOptions(this.defaultOptions); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java index c165c61c1e..bab4d0b03b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatOptions.java @@ -18,23 +18,22 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + import org.springframework.ai.chat.prompt.ChatOptions; /** - * Request body for the /complete 
endpoint of the Jurassic-2 API. + * Java {@link ChatOptions} for the Bedrock Jurassic-2 chat generative model chat options. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html * * @author Ahmed Yousri + * @author Wei Jiang * @since 1.0.0 */ @JsonInclude(JsonInclude.Include.NON_NULL) public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { - /** - * The text which the model is requested to continue. - */ - @JsonProperty("prompt") - private String prompt; - /** * Number of completions to sample and return. */ @@ -75,7 +74,7 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { * Stops decoding if any of the strings is generated. */ @JsonProperty("stopSequences") - private String[] stopSequences; + private List stopSequences; /** * Penalty object for frequency. @@ -97,22 +96,6 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions { // Getters and setters - /** - * Gets the prompt text for the model to continue. - * @return The prompt text. - */ - public String getPrompt() { - return prompt; - } - - /** - * Sets the prompt text for the model to continue. - * @param prompt The prompt text. - */ - public void setPrompt(String prompt) { - this.prompt = prompt; - } - /** * Gets the number of completions to sample and return. * @return The number of results. @@ -216,7 +199,7 @@ public void setTopK(Integer topK) { * Gets the stop sequences for stopping decoding if any of the strings is generated. * @return The stop sequences. */ - public String[] getStopSequences() { + public List getStopSequences() { return stopSequences; } @@ -224,7 +207,7 @@ public String[] getStopSequences() { * Sets the stop sequences for stopping decoding if any of the strings is generated. * @param stopSequences The stop sequences. 
*/ - public void setStopSequences(String[] stopSequences) { + public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } @@ -284,11 +267,6 @@ public static class Builder { private final BedrockAi21Jurassic2ChatOptions request = new BedrockAi21Jurassic2ChatOptions(); - public Builder withPrompt(String prompt) { - request.setPrompt(prompt); - return this; - } - public Builder withNumResults(Integer numResults) { request.setNumResults(numResults); return this; @@ -314,7 +292,7 @@ public Builder withTopP(Float topP) { return this; } - public Builder withStopSequences(String[] stopSequences) { + public Builder withStopSequences(List stopSequences) { request.setStopSequences(stopSequences); return this; } @@ -414,8 +392,7 @@ public Penalty build() { } public static BedrockAi21Jurassic2ChatOptions fromOptions(BedrockAi21Jurassic2ChatOptions fromOptions) { - return builder().withPrompt(fromOptions.getPrompt()) - .withNumResults(fromOptions.getNumResults()) + return builder().withNumResults(fromOptions.getNumResults()) .withMaxTokens(fromOptions.getMaxTokens()) .withMinTokens(fromOptions.getMinTokens()) .withTemperature(fromOptions.getTemperature()) diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java deleted file mode 100644 index fecf70fa4e..0000000000 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java +++ /dev/null @@ -1,413 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// @formatter:off -package org.springframework.ai.bedrock.jurassic2.api; - -import java.time.Duration; -import java.util.List; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse; -import org.springframework.ai.model.ModelDescription; - -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -/** - * Java client for the Bedrock Jurassic2 chat model. - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html - * - * @author Christian Tzolov - * @author Wei Jiang - * @since 0.8.0 - */ -public class Ai21Jurassic2ChatBedrockApi extends - AbstractBedrockApi { - - /** - * Create a new Ai21Jurassic2ChatBedrockApi instance using the default credentials provider chain, the default - * object mapper, default temperature and topP values. - * - * @param modelId The model id to use. See the {@link Ai21Jurassic2ChatModel} for the supported models. - * @param region The AWS region to use. 
- */ - public Ai21Jurassic2ChatBedrockApi(String modelId, String region) { - super(modelId, region); - } - - - /** - * Create a new Ai21Jurassic2ChatBedrockApi instance. - * - * @param modelId The model id to use. See the {@link Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - */ - public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper) { - super(modelId, credentialsProvider, region, objectMapper); - } - - /** - * Create a new Ai21Jurassic2ChatBedrockApi instance using the default credentials provider chain, the default - * object mapper, default temperature and topP values. - * - * @param modelId The model id to use. See the {@link Ai21Jurassic2ChatModel} for the supported models. - * @param region The AWS region to use. - * @param timeout The timeout to use. - */ - public Ai21Jurassic2ChatBedrockApi(String modelId, String region, Duration timeout) { - super(modelId, region, timeout); - } - - - /** - * Create a new Ai21Jurassic2ChatBedrockApi instance. - * - * @param modelId The model id to use. See the {@link Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * Create a new Ai21Jurassic2ChatBedrockApi instance. 
- * - * @param modelId The model id to use. See the {@link Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * AI21 Jurassic2 chat request parameters. - * - * @param prompt The prompt to use for the chat. - * @param temperature The temperature value controls the randomness of the generated text. - * @param topP The topP value controls the diversity of the generated text. Use a lower value to ignore less - * probable options. - * @param maxTokens Specify the maximum number of tokens to use in the generated response. - * @param stopSequences Configure stop sequences that the model recognizes and after which it stops generating - * further tokens. Press the Enter key to insert a newline character in a stop sequence. Use the Tab key to finish - * inserting a stop sequence. - * @param countPenalty Control repetition in the generated response. Use a higher value to lower the probability of - * generating new tokens that already appear at least once in the prompt or in the completion. Proportional to the - * number of appearances. - * @param presencePenalty Control repetition in the generated response. Use a higher value to lower the probability - * of generating new tokens that already appear at least once in the prompt or in the completion. - * @param frequencyPenalty Control repetition in the generated response. Use a high value to lower the probability - * of generating new tokens that already appear at least once in the prompt or in the completion. 
The value is - * proportional to the frequency of the token appearances (normalized to text length). - */ - @JsonInclude(Include.NON_NULL) - public record Ai21Jurassic2ChatRequest( - @JsonProperty("prompt") String prompt, - @JsonProperty("temperature") Float temperature, - @JsonProperty("topP") Float topP, - @JsonProperty("maxTokens") Integer maxTokens, - @JsonProperty("stopSequences") List stopSequences, - @JsonProperty("countPenalty") IntegerScalePenalty countPenalty, - @JsonProperty("presencePenalty") FloatScalePenalty presencePenalty, - @JsonProperty("frequencyPenalty") IntegerScalePenalty frequencyPenalty) { - - /** - * Penalty with integer scale value. - * - * @param scale The scale value controls the strength of the penalty. Use a higher value to lower the - * probability of generating new tokens that already appear at least once in the prompt or in the completion. - * @param applyToWhitespaces Reduce the probability of repetition of special characters. A true value applies - * the penalty to whitespaces and new lines. - * @param applyToPunctuations Reduce the probability of repetition of special characters. A true value applies - * the penalty to punctuations. - * @param applyToNumbers Reduce the probability of repetition of special characters. A true value applies the - * penalty to numbers. - * @param applyToStopwords Reduce the probability of repetition of special characters. A true value applies the - * penalty to stopwords. - * @param applyToEmojis Reduce the probability of repetition of special characters. A true value applies the - * penalty to emojis. 
- */ - @JsonInclude(Include.NON_NULL) - public record IntegerScalePenalty( - @JsonProperty("scale") Integer scale, - @JsonProperty("applyToWhitespaces") boolean applyToWhitespaces, - @JsonProperty("applyToPunctuations") boolean applyToPunctuations, - @JsonProperty("applyToNumbers") boolean applyToNumbers, - @JsonProperty("applyToStopwords") boolean applyToStopwords, - @JsonProperty("applyToEmojis") boolean applyToEmojis) { - } - - /** - * Penalty with float scale value. - * - * @param scale The scale value controls the strength of the penalty. Use a higher value to lower the - * probability of generating new tokens that already appear at least once in the prompt or in the completion. - * @param applyToWhitespaces Reduce the probability of repetition of special characters. A true value applies - * the penalty to whitespaces and new lines. - * @param applyToPunctuations Reduce the probability of repetition of special characters. A true value applies - * the penalty to punctuations. - * @param applyToNumbers Reduce the probability of repetition of special characters. A true value applies the - * penalty to numbers. - * @param applyToStopwords Reduce the probability of repetition of special characters. A true value applies the - * penalty to stopwords. - * @param applyToEmojis Reduce the probability of repetition of special characters. A true value applies the - * penalty to emojis. 
- */ - @JsonInclude(Include.NON_NULL) - public record FloatScalePenalty(@JsonProperty("scale") Float scale, - @JsonProperty("applyToWhitespaces") boolean applyToWhitespaces, - @JsonProperty("applyToPunctuations") boolean applyToPunctuations, - @JsonProperty("applyToNumbers") boolean applyToNumbers, - @JsonProperty("applyToStopwords") boolean applyToStopwords, - @JsonProperty("applyToEmojis") boolean applyToEmojis) { - } - - - - public static Builder builder(String prompt) { - return new Builder(prompt); - } - public static class Builder { - private String prompt; - private Float temperature; - private Float topP; - private Integer maxTokens; - private List stopSequences; - private IntegerScalePenalty countPenalty; - private FloatScalePenalty presencePenalty; - private IntegerScalePenalty frequencyPenalty; - - public Builder(String prompt) { - this.prompt = prompt; - } - - public Builder withTemperature(Float temperature) { - this.temperature = temperature; - return this; - } - - public Builder withTopP(Float topP) { - this.topP = topP; - return this; - } - - public Builder withMaxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.stopSequences = stopSequences; - return this; - } - - public Builder withCountPenalty(IntegerScalePenalty countPenalty) { - this.countPenalty = countPenalty; - return this; - } - - public Builder withPresencePenalty(FloatScalePenalty presencePenalty) { - this.presencePenalty = presencePenalty; - return this; - } - - public Builder withFrequencyPenalty(IntegerScalePenalty frequencyPenalty) { - this.frequencyPenalty = frequencyPenalty; - return this; - } - - public Ai21Jurassic2ChatRequest build() { - return new Ai21Jurassic2ChatRequest( - prompt, - temperature, - topP, - maxTokens, - stopSequences, - countPenalty, - presencePenalty, - frequencyPenalty - ); - } - } - } - - /** - * Ai21 Jurassic2 chat response. 
- * https://docs.ai21.com/reference/j2-complete-api-ref#response - * - * @param id The unique identifier of the response. - * @param prompt The prompt used for the chat. - * @param amazonBedrockInvocationMetrics The metrics about the model invocation. - */ - @JsonInclude(Include.NON_NULL) - public record Ai21Jurassic2ChatResponse( - @JsonProperty("id") String id, - @JsonProperty("prompt") Prompt prompt, - @JsonProperty("completions") List completions, - @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { - - /** - */ - @JsonInclude(Include.NON_NULL) - public record Completion( - @JsonProperty("data") Prompt data, - @JsonProperty("finishReason") FinishReason finishReason) { - } - - /** - * Provides detailed information about each token in both the prompt and the completions. - * - * @param generatedToken The generatedToken fields. - * @param topTokens The topTokens field is a list of the top K alternative tokens for this position, sorted by - * probability, according to the topKReturn request parameter. If topKReturn is set to 0, this field will be - * null. - * @param textRange The textRange field indicates the start and end offsets of the token in the decoded text - * string. - */ - @JsonInclude(Include.NON_NULL) - public record Token( - @JsonProperty("generatedToken") GeneratedToken generatedToken, - @JsonProperty("topTokens") List topTokens, - @JsonProperty("textRange") TextRange textRange) { - } - - /** - * The generatedToken fields. - * - * @param token TThe string representation of the token. - * @param logprob The predicted log probability of the token after applying the sampling parameters as a float - * value. - * @param rawLogprob The raw predicted log probability of the token as a float value. For the indifferent values - * (namely, temperature=1, topP=1) we get raw_logprob=logprob. 
- */ - @JsonInclude(Include.NON_NULL) - public record GeneratedToken( - @JsonProperty("token") String token, - @JsonProperty("logprob") Float logprob, - @JsonProperty("raw_logprob") Float rawLogprob) { - - } - - /** - * The topTokens field is a list of the top K alternative tokens for this position, sorted by probability, - * according to the topKReturn request parameter. If topKReturn is set to 0, this field will be null. - * - * @param token The string representation of the alternative token. - * @param logprob The predicted log probability of the alternative token. - */ - @JsonInclude(Include.NON_NULL) - public record TopToken( - @JsonProperty("token") String token, - @JsonProperty("logprob") Float logprob) { - } - - /** - * The textRange field indicates the start and end offsets of the token in the decoded text string. - * - * @param start The starting index of the token in the decoded text string. - * @param end The ending index of the token in the decoded text string. - */ - @JsonInclude(Include.NON_NULL) - public record TextRange( - @JsonProperty("start") Integer start, - @JsonProperty("end") Integer end) { - } - - /** - * The prompt includes the raw text, the tokens with their log probabilities, and the top-K alternative tokens - * at each position, if requested. - * - * @param text The raw text of the prompt. - * @param tokens Provides detailed information about each token in both the prompt and the completions. - */ - @JsonInclude(Include.NON_NULL) - public record Prompt( - @JsonProperty("text") String text, - @JsonProperty("tokens") List tokens) { - } - - /** - * Explains why the generation process was halted for a specific completion. - * - * @param reason The reason field indicates the reason for the completion to stop. 
- * - */ - @JsonInclude(Include.NON_NULL) - public record FinishReason( - @JsonProperty("reason") String reason, - @JsonProperty("length") String length, - @JsonProperty("sequence") String sequence) { - } - } - - /** - * Ai21 Jurassic2 models version. - */ - public enum Ai21Jurassic2ChatModel implements ModelDescription { - - /** - * ai21.j2-mid-v1 - */ - AI21_J2_MID_V1("ai21.j2-mid-v1"), - - /** - * ai21.j2-ultra-v1 - */ - AI21_J2_ULTRA_V1("ai21.j2-ultra-v1"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - Ai21Jurassic2ChatModel(String value) { - this.id = value; - } - - @Override - public String getModelName() { - return this.id; - } - } - - @Override - public Ai21Jurassic2ChatResponse chatCompletion(Ai21Jurassic2ChatRequest request) { - return this.internalInvocation(request, Ai21Jurassic2ChatResponse.class); - } - - -} -// @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java index b391763227..71fd5b3496 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java @@ -15,28 +15,23 @@ */ package org.springframework.ai.bedrock.llama; -import java.util.List; - -import reactor.core.publisher.Flux; - -import org.springframework.ai.bedrock.MessageToPromptConverter; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.api.BedrockConverseApiUtils; 
import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.ModelDescription; import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; + /** * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Llama chat - * generative. + * generative model. * * @author Christian Tzolov * @author Wei Jiang @@ -44,89 +39,101 @@ */ public class BedrockLlamaChatModel implements ChatModel, StreamingChatModel { - private final LlamaChatBedrockApi chatApi; + private final String modelId; + + private final BedrockConverseApi converseApi; private final BedrockLlamaChatOptions defaultOptions; - public BedrockLlamaChatModel(LlamaChatBedrockApi chatApi) { - this(chatApi, + public BedrockLlamaChatModel(BedrockConverseApi converseApi) { + this(converseApi, BedrockLlamaChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build()); } - public BedrockLlamaChatModel(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) { - Assert.notNull(chatApi, "LlamaChatBedrockApi must not be null"); - Assert.notNull(options, "BedrockLlamaChatOptions must not be null"); + public BedrockLlamaChatModel(BedrockConverseApi converseApi, BedrockLlamaChatOptions options) { + this(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), converseApi, options); + } + + public 
BedrockLlamaChatModel(String modelId, BedrockConverseApi converseApi, BedrockLlamaChatOptions options) { + Assert.notNull(modelId, "modelId must not be null."); + Assert.notNull(converseApi, "BedrockConverseApi must not be null."); + Assert.notNull(options, "BedrockLlamaChatOptions must not be null."); - this.chatApi = chatApi; + this.modelId = modelId; + this.converseApi = converseApi; this.defaultOptions = options; } @Override public ChatResponse call(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); - var request = createRequest(prompt); + var request = BedrockConverseApiUtils.createConverseRequest(modelId, prompt, defaultOptions); - LlamaChatResponse response = this.chatApi.chatCompletion(request); + ConverseResponse response = this.converseApi.converse(request); - return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata( - ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response))))); + return BedrockConverseApiUtils.convertConverseResponse(response); } @Override public Flux stream(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); - var request = createRequest(prompt); + var request = BedrockConverseApiUtils.createConverseStreamRequest(modelId, prompt, defaultOptions); - Flux fluxResponse = this.chatApi.chatCompletionStream(request); + Flux fluxResponse = this.converseApi.converseStream(request); - return fluxResponse.map(response -> { - String stopReason = response.stopReason() != null ? 
response.stopReason().name() : null; - return new ChatResponse(List.of(new Generation(response.generation()) - .withGenerationMetadata(ChatGenerationMetadata.from(stopReason, extractUsage(response))))); - }); + return fluxResponse.map(output -> BedrockConverseApiUtils.convertConverseStreamOutput(output)); } - private Usage extractUsage(LlamaChatResponse response) { - return new Usage() { - - @Override - public Long getPromptTokens() { - return response.promptTokenCount().longValue(); - } - - @Override - public Long getGenerationTokens() { - return response.generationTokenCount().longValue(); - } - }; + @Override + public ChatOptions getDefaultOptions() { + return BedrockLlamaChatOptions.fromOptions(this.defaultOptions); } /** - * Accessible for testing. + * Llama models version. */ - LlamaChatRequest createRequest(Prompt prompt) { - - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); - - LlamaChatRequest request = LlamaChatRequest.builder(promptValue).build(); - - if (this.defaultOptions != null) { - request = ModelOptionsUtils.merge(request, this.defaultOptions, LlamaChatRequest.class); + public enum LlamaChatModel implements ModelDescription { + + /** + * meta.llama2-13b-chat-v1 + */ + LLAMA2_13B_CHAT_V1("meta.llama2-13b-chat-v1"), + + /** + * meta.llama2-70b-chat-v1 + */ + LLAMA2_70B_CHAT_V1("meta.llama2-70b-chat-v1"), + + /** + * meta.llama3-8b-instruct-v1:0 + */ + LLAMA3_8B_INSTRUCT_V1("meta.llama3-8b-instruct-v1:0"), + + /** + * meta.llama3-70b-instruct-v1:0 + */ + LLAMA3_70B_INSTRUCT_V1("meta.llama3-70b-instruct-v1:0"); + + private final String id; + + /** + * @return The model id. 
+ */ + public String id() { + return id; } - if (prompt.getOptions() != null) { - BedrockLlamaChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), - ChatOptions.class, BedrockLlamaChatOptions.class); - - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, LlamaChatRequest.class); + LlamaChatModel(String value) { + this.id = value; } - return request; - } + @Override + public String getModelName() { + return this.id; + } - @Override - public ChatOptions getDefaultOptions() { - return BedrockLlamaChatOptions.fromOptions(this.defaultOptions); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java index 4d6c0a6e04..5753cf4d2d 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatOptions.java @@ -23,7 +23,12 @@ import org.springframework.ai.chat.prompt.ChatOptions; /** + * Java {@link ChatOptions} for the Bedrock Llama chat generative model chat + * options. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + * * @author Christian Tzolov + * @author Wei Jiang */ @JsonInclude(Include.NON_NULL) public class BedrockLlamaChatOptions implements ChatOptions { diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java deleted file mode 100644 index 16af9735ed..0000000000 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors.
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.bedrock.llama.api; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; -import org.springframework.ai.model.ModelDescription; - -import java.time.Duration; - -// @formatter:off -/** - * Java client for the Bedrock Llama chat model. - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html - * - * @author Christian Tzolov - * @author Wei Jiang - * @since 0.8.0 - */ -public class LlamaChatBedrockApi extends - AbstractBedrockApi { - - /** - * Create a new LlamaChatBedrockApi instance using the default credentials provider chain, the default object - * mapper, default temperature and topP values. - * - * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. - * @param region The AWS region to use. 
- */ - public LlamaChatBedrockApi(String modelId, String region) { - super(modelId, region); - } - - /** - * Create a new LlamaChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - */ - public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper) { - super(modelId, credentialsProvider, region, objectMapper); - } - - /** - * Create a new LlamaChatBedrockApi instance using the default credentials provider chain, the default object - * mapper, default temperature and topP values. - * - * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. - * @param region The AWS region to use. - * @param timeout The timeout to use. - */ - public LlamaChatBedrockApi(String modelId, String region, Duration timeout) { - super(modelId, region, timeout); - } - - /** - * Create a new LlamaChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. 
- */ - public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * Create a new LlamaChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * LlamaChatRequest encapsulates the request parameters for the Meta Llama chat model. - * - * @param prompt The prompt to use for the chat. - * @param temperature The temperature value controls the randomness of the generated text. Use a lower value to - * decrease randomness in the response. - * @param topP The topP value controls the diversity of the generated text. Use a lower value to ignore less - * probable options. Set to 0 or 1.0 to disable. - * @param maxGenLen The maximum length of the generated text. - */ - @JsonInclude(Include.NON_NULL) - public record LlamaChatRequest( - @JsonProperty("prompt") String prompt, - @JsonProperty("temperature") Float temperature, - @JsonProperty("top_p") Float topP, - @JsonProperty("max_gen_len") Integer maxGenLen) { - - /** - * Create a new LlamaChatRequest builder. - * @param prompt compulsory prompt parameter. - * @return a new LlamaChatRequest builder. 
- */ - public static Builder builder(String prompt) { - return new Builder(prompt); - } - - public static class Builder { - private String prompt; - private Float temperature; - private Float topP; - private Integer maxGenLen; - - public Builder(String prompt) { - this.prompt = prompt; - } - - public Builder withTemperature(Float temperature) { - this.temperature = temperature; - return this; - } - - public Builder withTopP(Float topP) { - this.topP = topP; - return this; - } - - public Builder withMaxGenLen(Integer maxGenLen) { - this.maxGenLen = maxGenLen; - return this; - } - - public LlamaChatRequest build() { - return new LlamaChatRequest( - prompt, - temperature, - topP, - maxGenLen - ); - } - } - } - - /** - * LlamaChatResponse encapsulates the response parameters for the Meta Llama chat model. - * - * @param generation The generated text. - * @param promptTokenCount The number of tokens in the prompt. - * @param generationTokenCount The number of tokens in the response. - * @param stopReason The reason why the response stopped generating text. Possible values are: (1) stop – The model - * has finished generating text for the input prompt. (2) length – The length of the tokens for the generated text - * exceeds the value of max_gen_len in the call. The response is truncated to max_gen_len tokens. Consider - * increasing the value of max_gen_len and trying again. - */ - @JsonInclude(Include.NON_NULL) - public record LlamaChatResponse( - @JsonProperty("generation") String generation, - @JsonProperty("prompt_token_count") Integer promptTokenCount, - @JsonProperty("generation_token_count") Integer generationTokenCount, - @JsonProperty("stop_reason") StopReason stopReason, - @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { - - /** - * The reason the response finished being generated. - */ - public enum StopReason { - /** - * The model has finished generating text for the input prompt. 
- */ - @JsonProperty("stop") STOP, - /** - * The response was truncated because of the response length you set. - */ - @JsonProperty("length") LENGTH - } - } - - /** - * Llama models version. - */ - public enum LlamaChatModel implements ModelDescription { - - /** - * meta.llama2-13b-chat-v1 - */ - LLAMA2_13B_CHAT_V1("meta.llama2-13b-chat-v1"), - - /** - * meta.llama2-70b-chat-v1 - */ - LLAMA2_70B_CHAT_V1("meta.llama2-70b-chat-v1"), - - /** - * meta.llama3-8b-instruct-v1:0 - */ - LLAMA3_8B_INSTRUCT_V1("meta.llama3-8b-instruct-v1:0"), - - /** - * meta.llama3-70b-instruct-v1:0 - */ - LLAMA3_70B_INSTRUCT_V1("meta.llama3-70b-instruct-v1:0"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - LlamaChatModel(String value) { - this.id = value; - } - - @Override - public String getModelName() { - return this.id; - } - } - - @Override - public LlamaChatResponse chatCompletion(LlamaChatRequest request) { - return this.internalInvocation(request, LlamaChatResponse.class); - } - - @Override - public Flux chatCompletionStream(LlamaChatRequest request) { - return this.internalInvocationStream(request, LlamaChatResponse.class); - } -} -// @formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java new file mode 100644 index 0000000000..61460bbd38 --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java @@ -0,0 +1,297 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.mistral; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.springframework.ai.bedrock.BedrockConverseChatGenerationMetadata; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.api.BedrockConverseApiUtils; +import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +import reactor.core.publisher.Flux; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.StopReason; +import software.amazon.awssdk.services.bedrockruntime.model.Tool; +import 
software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultStatus; +import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification; +import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type; + +/** + * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Mistral chat + * generative model. + * + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockMistralChatModel extends AbstractFunctionCallSupport + implements ChatModel, StreamingChatModel { + + private final String modelId; + + private final BedrockConverseApi converseApi; + + private final BedrockMistralChatOptions defaultOptions; + + public BedrockMistralChatModel(BedrockConverseApi converseApi) { + this(converseApi, BedrockMistralChatOptions.builder().build()); + } + + public BedrockMistralChatModel(BedrockConverseApi converseApi, BedrockMistralChatOptions options) { + this(MistralChatModel.MISTRAL_LARGE.id(), converseApi, options); + } + + public BedrockMistralChatModel(String modelId, BedrockConverseApi converseApi, BedrockMistralChatOptions options) { + this(modelId, converseApi, options, null); + } + + public BedrockMistralChatModel(String modelId, BedrockConverseApi converseApi, BedrockMistralChatOptions options, + FunctionCallbackContext functionCallbackContext) { + super(functionCallbackContext); + + Assert.notNull(modelId, "modelId must not be null."); + Assert.notNull(converseApi, "BedrockConverseApi must not be null."); + Assert.notNull(options, "BedrockMistralChatOptions must not be null."); + + this.modelId = modelId; + this.converseApi = 
converseApi; + this.defaultOptions = options; + } + + @Override + public ChatResponse call(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); + + var request = createBedrockConverseRequest(prompt); + + return this.callWithFunctionSupport(request); + } + + @Override + public Flux stream(Prompt prompt) { + Assert.notNull(prompt, "Prompt must not be null."); + + var request = createBedrockConverseRequest(prompt); + + return converseApi.converseStream(request); + } + + private BedrockConverseRequest createBedrockConverseRequest(Prompt prompt) { + var request = BedrockConverseApiUtils.createBedrockConverseRequest(modelId, prompt, defaultOptions); + + ToolConfiguration toolConfiguration = createToolConfiguration(prompt); + + return BedrockConverseRequest.from(request).withToolConfiguration(toolConfiguration).build(); + } + + private ToolConfiguration createToolConfiguration(Prompt prompt) { + Set functionsForThisRequest = new HashSet<>(); + + if (this.defaultOptions != null) { + Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, + !IS_RUNTIME_CALL); + functionsForThisRequest.addAll(promptEnabledFunctions); + } + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { + BedrockMistralChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, + ChatOptions.class, BedrockMistralChatOptions.class); + + Set defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, + IS_RUNTIME_CALL); + functionsForThisRequest.addAll(defaultEnabledFunctions); + } + else { + throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + + prompt.getOptions().getClass().getSimpleName()); + } + } + + if (!CollectionUtils.isEmpty(functionsForThisRequest)) { + return ToolConfiguration.builder().tools(getFunctionTools(functionsForThisRequest)).build(); + } + + return null; + } + + private List getFunctionTools(Set 
functionNames) { + return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { + var description = functionCallback.getDescription(); + var name = functionCallback.getName(); + String inputSchema = functionCallback.getInputTypeSchema(); + + return Tool.builder() + .toolSpec(ToolSpecification.builder() + .name(name) + .description(description) + .inputSchema(ToolInputSchema.builder() + .json(BedrockConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema))) + .build()) + .build()) + .build(); + }).toList(); + } + + @Override + public ChatOptions getDefaultOptions() { + return defaultOptions; + } + + @Override + protected BedrockConverseRequest doCreateToolResponseRequest(BedrockConverseRequest previousRequest, + Message responseMessage, List conversationHistory) { + List toolToUseList = responseMessage.content() + .stream() + .filter(content -> content.type() == Type.TOOL_USE) + .map(content -> content.toolUse()) + .toList(); + + List toolResults = new ArrayList<>(); + + for (ToolUseBlock toolToUse : toolToUseList) { + var functionCallId = toolToUse.toolUseId(); + var functionName = toolToUse.name(); + var functionArguments = toolToUse.input().unwrap(); + + if (!this.functionCallbackRegister.containsKey(functionName)) { + throw new IllegalStateException("No function callback found for function name: " + functionName); + } + + String functionResponse = this.functionCallbackRegister.get(functionName) + .call(ModelOptionsUtils.toJsonString(functionArguments)); + + toolResults.add(ToolResultBlock.builder() + .toolUseId(functionCallId) + .status(ToolResultStatus.SUCCESS) + .content(ToolResultContentBlock.builder().text(functionResponse).build()) + .build()); + } + + // Add the function response to the conversation. 
+ Message toolResultMessage = Message.builder() + .content(toolResults.stream().map(toolResult -> ContentBlock.fromToolResult(toolResult)).toList()) + .role(ConversationRole.USER) + .build(); + conversationHistory.add(toolResultMessage); + + // Recursively call chatCompletionWithTools until the model doesn't call a + // functions anymore. + return BedrockConverseRequest.from(previousRequest).withMessages(conversationHistory).build(); + } + + @Override + protected List doGetUserMessages(BedrockConverseRequest request) { + return request.messages(); + } + + @Override + protected Message doGetToolResponseMessage(ChatResponse response) { + Generation result = response.getResult(); + + var metadata = (BedrockConverseChatGenerationMetadata) result.getMetadata(); + + return metadata.getMessage(); + } + + @Override + protected ChatResponse doChatCompletion(BedrockConverseRequest request) { + return converseApi.converse(request); + } + + @Override + protected Flux doChatCompletionStream(BedrockConverseRequest request) { + throw new UnsupportedOperationException("Streaming function calling is not supported."); + } + + @Override + protected boolean isToolFunctionCall(ChatResponse response) { + Generation result = response.getResult(); + if (result == null) { + return false; + } + + return StopReason.fromValue(result.getMetadata().getFinishReason()) == StopReason.TOOL_USE; + } + + /** + * Mistral models version. + */ + public enum MistralChatModel implements ModelDescription { + + /** + * mistral.mistral-7b-instruct-v0:2 + */ + MISTRAL_7B_INSTRUCT("mistral.mistral-7b-instruct-v0:2"), + + /** + * mistral.mixtral-8x7b-instruct-v0:1 + */ + MISTRAL_8X7B_INSTRUCT("mistral.mixtral-8x7b-instruct-v0:1"), + + /** + * mistral.mistral-large-2402-v1:0 + */ + MISTRAL_LARGE("mistral.mistral-large-2402-v1:0"), + + /** + * mistral.mistral-small-2402-v1:0 + */ + MISTRAL_SMALL("mistral.mistral-small-2402-v1:0"); + + private final String id; + + /** + * @return The model id. 
+ */ + public String id() { + return id; + } + + MistralChatModel(String value) { + this.id = value; + } + + @Override + public String getModelName() { + return this.id; + } + + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java new file mode 100644 index 0000000000..d2d7f9f78f --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java @@ -0,0 +1,256 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.mistral; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonInclude.Include; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; + +/** + * Java {@link ChatOptions} for the Bedrock Mistral chat generative model chat options. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-text-completion.html + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-chat-completion.html + * + * @author Wei Jiang + * @since 1.0.0 + */ +@JsonInclude(Include.NON_NULL) +public class BedrockMistralChatOptions implements ChatOptions, FunctionCallingOptions { + + /** + * The temperature value controls the randomness of the generated text. Use a lower + * value to decrease randomness in the response. + */ + private @JsonProperty("temperature") Float temperature; + + /** + * (optional) The maximum cumulative probability of tokens to consider when sampling. + * The generative uses combined Top-k and nucleus sampling. Nucleus sampling considers + * the smallest set of tokens whose probability sum is at least topP. + */ + private @JsonProperty("top_p") Float topP; + + /** + * (optional) Specify the number of token choices the generative uses to generate the + * next token. + */ + private @JsonProperty("top_k") Integer topK; + + /** + * (optional) Specify the maximum number of tokens to use in the generated response. 
+ */ + private @JsonProperty("max_tokens") Integer maxTokens; + + /** + * (optional) Configure up to four sequences that the generative recognizes. After a + * stop sequence, the generative stops generating further tokens. The returned text + * doesn't contain the stop sequence. + */ + private @JsonProperty("stop") List stopSequences; + + /** + * (optional) Specifies how functions are called. If set to none the model won't call + * a function and will generate a message instead. If set to auto the model can choose + * to either generate a message or call a function. If set to any the model is forced + * to call a function. + */ + private @JsonProperty("tool_choice") String toolChoice; + + /** + * Tool Function Callbacks to register with the ChatModel. For Prompt Options the + * functionCallbacks are automatically enabled for the duration of the prompt + * execution. For Default Options the functionCallbacks are registered but disabled by + * default. Use the enableFunctions to set the functions from the registry to be used + * by the ChatModel chat completion requests. + */ + @NestedConfigurationProperty + @JsonIgnore + private List functionCallbacks = new ArrayList<>(); + + /** + * List of functions, identified by their names, to configure for function calling in + * the chat completion requests. Functions with those names must exist in the + * functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions + * are automatically enabled for the duration of the prompt execution. + * + * Note that function enabled with the default options are enabled for all chat + * completion requests. This could impact the token count and the billing. If the + * functions is set in a prompt options, then the enabled functions are only active + * for the duration of this prompt execution. 
+ */ + @NestedConfigurationProperty + @JsonIgnore + private Set functions = new HashSet<>(); + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final BedrockMistralChatOptions options = new BedrockMistralChatOptions(); + + public Builder withTemperature(Float temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Float topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public Builder withToolChoice(String toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + + public BedrockMistralChatOptions build() { + return this.options; + } + + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + public Float getTopP() { + return this.topP; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + @Override + public Integer getTopK() { + return this.topK; + } + + public Integer getMaxTokens() { + return 
maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public List getStopSequences() { + return stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + public String getToolChoice() { + return toolChoice; + } + + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + @Override + public void setFunctions(Set functions) { + Assert.notNull(functions, "Function must not be null"); + this.functions = functions; + } + + public static BedrockMistralChatOptions fromOptions(BedrockMistralChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withStopSequences(fromOptions.getStopSequences()) + .withToolChoice(fromOptions.toolChoice) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) + .build(); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java index b144a2a10d..f6026d02df 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java @@ -15,134 +15,120 @@ */ package org.springframework.ai.bedrock.titan; -import java.util.List; - import 
reactor.core.publisher.Flux; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; -import org.springframework.ai.bedrock.MessageToPromptConverter; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.api.BedrockConverseApiUtils; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.ModelDescription; import org.springframework.util.Assert; /** + * Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Titan chat + * generative model. 
+ * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockTitanChatModel implements ChatModel, StreamingChatModel { - private final TitanChatBedrockApi chatApi; + private final String modelId; + + private final BedrockConverseApi converseApi; private final BedrockTitanChatOptions defaultOptions; - public BedrockTitanChatModel(TitanChatBedrockApi chatApi) { - this(chatApi, BedrockTitanChatOptions.builder().withTemperature(0.8f).build()); + public BedrockTitanChatModel(BedrockConverseApi converseApi) { + this(converseApi, BedrockTitanChatOptions.builder().withTemperature(0.8f).build()); + } + + public BedrockTitanChatModel(BedrockConverseApi converseApi, BedrockTitanChatOptions defaultOptions) { + this(TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), converseApi, defaultOptions); } - public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOptions defaultOptions) { - Assert.notNull(chatApi, "ChatApi must not be null"); - Assert.notNull(defaultOptions, "DefaultOptions must not be null"); - this.chatApi = chatApi; + public BedrockTitanChatModel(String modelId, BedrockConverseApi converseApi, + BedrockTitanChatOptions defaultOptions) { + Assert.notNull(modelId, "modelId must not be null."); + Assert.notNull(converseApi, "BedrockConverseApi must not be null."); + Assert.notNull(defaultOptions, "BedrockTitanChatOptions must not be null"); + + this.modelId = modelId; + this.converseApi = converseApi; this.defaultOptions = defaultOptions; } @Override public ChatResponse call(Prompt prompt) { - TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt)); - List generations = response.results().stream().map(result -> { - return new Generation(result.outputText()); - }).toList(); + Assert.notNull(prompt, "Prompt must not be null."); + + var request = BedrockConverseApiUtils.createConverseRequest(modelId, prompt, defaultOptions); + + ConverseResponse response = this.converseApi.converse(request); - return new 
ChatResponse(generations); + return BedrockConverseApiUtils.convertConverseResponse(response); } @Override public Flux stream(Prompt prompt) { - return this.chatApi.chatCompletionStream(this.createRequest(prompt)).map(chunk -> { - - Generation generation = new Generation(chunk.outputText()); - - if (chunk.amazonBedrockInvocationMetrics() != null) { - String completionReason = chunk.completionReason().name(); - generation = generation.withGenerationMetadata( - ChatGenerationMetadata.from(completionReason, chunk.amazonBedrockInvocationMetrics())); - } - else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount() != null) { - String completionReason = chunk.completionReason().name(); - generation = generation - .withGenerationMetadata(ChatGenerationMetadata.from(completionReason, extractUsage(chunk))); - - } - return new ChatResponse(List.of(generation)); - }); - } - - /** - * Test access. - */ - TitanChatRequest createRequest(Prompt prompt) { - final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); + Assert.notNull(prompt, "Prompt must not be null."); - var requestBuilder = TitanChatRequest.builder(promptValue); + var request = BedrockConverseApiUtils.createConverseStreamRequest(modelId, prompt, defaultOptions); - if (this.defaultOptions != null) { - requestBuilder = update(requestBuilder, this.defaultOptions); - } - - if (prompt.getOptions() != null) { - BedrockTitanChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), - ChatOptions.class, BedrockTitanChatOptions.class); - - requestBuilder = update(requestBuilder, updatedRuntimeOptions); - } + Flux fluxResponse = this.converseApi.converseStream(request); - return requestBuilder.build(); + return fluxResponse.map(output -> BedrockConverseApiUtils.convertConverseStreamOutput(output)); } - private TitanChatRequest.Builder update(TitanChatRequest.Builder builder, BedrockTitanChatOptions options) { - if 
(options.getTemperature() != null) { - builder.withTemperature(options.getTemperature()); - } - if (options.getTopP() != null) { - builder.withTopP(options.getTopP()); - } - if (options.getMaxTokenCount() != null) { - builder.withMaxTokenCount(options.getMaxTokenCount()); - } - if (options.getStopSequences() != null) { - builder.withStopSequences(options.getStopSequences()); - } - return builder; + @Override + public ChatOptions getDefaultOptions() { + return BedrockTitanChatOptions.fromOptions(this.defaultOptions); } - private Usage extractUsage(TitanChatResponseChunk response) { - return new Usage() { + /** + * Titan models version. + */ + public enum TitanChatModel implements ModelDescription { + + /** + * amazon.titan-text-lite-v1 + */ + TITAN_TEXT_LITE_V1("amazon.titan-text-lite-v1"), + + /** + * amazon.titan-text-express-v1 + */ + TITAN_TEXT_EXPRESS_V1("amazon.titan-text-express-v1"), + + /** + * amazon.titan-text-premier-v1:0 + */ + TITAN_TEXT_PREMIER_V1("amazon.titan-text-premier-v1:0"); + + private final String id; + + /** + * @return The model id. 
+ */ + public String id() { + return id; + } - @Override - public Long getPromptTokens() { - return response.inputTextTokenCount().longValue(); - } + TitanChatModel(String value) { + this.id = value; + } - @Override - public Long getGenerationTokens() { - return response.totalOutputTextTokenCount().longValue(); - } - }; - } + @Override + public String getModelName() { + return this.id; + } - @Override - public ChatOptions getDefaultOptions() { - return BedrockTitanChatOptions.fromOptions(this.defaultOptions); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java index d53126a0b7..6c07f26475 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatOptions.java @@ -17,6 +17,7 @@ import java.util.List; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -25,33 +26,47 @@ import com.fasterxml.jackson.annotation.JsonProperty; /** + * Java {@link ChatOptions} for the Bedrock Titan chat generative model chat options. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) public class BedrockTitanChatOptions implements ChatOptions { - // @formatter:off /** - * The temperature value controls the randomness of the generated text. + * The Titan chat model text generation config. 
*/ - private @JsonProperty("temperature") Float temperature; + private @JsonProperty("textGenerationConfig") TextGenerationConfig textGenerationConfig = new TextGenerationConfig(); - /** - * The topP value controls the diversity of the generated text. Use a lower value to ignore less probable options. - */ - private @JsonProperty("topP") Float topP; + @JsonInclude(Include.NON_NULL) + public static class TextGenerationConfig { - /** - * Maximum number of tokens to generate. - */ - private @JsonProperty("maxTokenCount") Integer maxTokenCount; + // @formatter:off + /** + * The temperature value controls the randomness of the generated text. + */ + private @JsonProperty(value = "temperature") Float temperature; - /** - * A list of tokens that the model should stop generating after. - */ - private @JsonProperty("stopSequences") List stopSequences; - // @formatter:on + /** + * The topP value controls the diversity of the generated text. Use a lower value to ignore less probable options. + */ + private @JsonProperty("topP") Float topP; + + /** + * Maximum number of tokens to generate. + */ + private @JsonProperty("maxTokenCount") Integer maxTokenCount; + + /** + * A list of tokens that the model should stop generating after. 
+ */ + private @JsonProperty("stopSequences") List stopSequences; + // @formatter:on + + } public static Builder builder() { return new Builder(); @@ -62,22 +77,22 @@ public static class Builder { private BedrockTitanChatOptions options = new BedrockTitanChatOptions(); public Builder withTemperature(Float temperature) { - this.options.temperature = temperature; + this.options.textGenerationConfig.temperature = temperature; return this; } public Builder withTopP(Float topP) { - this.options.topP = topP; + this.options.textGenerationConfig.topP = topP; return this; } public Builder withMaxTokenCount(Integer maxTokenCount) { - this.options.maxTokenCount = maxTokenCount; + this.options.textGenerationConfig.maxTokenCount = maxTokenCount; return this; } public Builder withStopSequences(List stopSequences) { - this.options.stopSequences = stopSequences; + this.options.textGenerationConfig.stopSequences = stopSequences; return this; } @@ -87,39 +102,44 @@ public BedrockTitanChatOptions build() { } + @JsonIgnore public Float getTemperature() { - return temperature; + return this.textGenerationConfig.temperature; } public void setTemperature(Float temperature) { - this.temperature = temperature; + this.textGenerationConfig.temperature = temperature; } + @JsonIgnore public Float getTopP() { - return topP; + return this.textGenerationConfig.topP; } public void setTopP(Float topP) { - this.topP = topP; + this.textGenerationConfig.topP = topP; } public Integer getMaxTokenCount() { - return maxTokenCount; + return this.textGenerationConfig.maxTokenCount; } + @JsonIgnore public void setMaxTokenCount(Integer maxTokenCount) { - this.maxTokenCount = maxTokenCount; + this.textGenerationConfig.maxTokenCount = maxTokenCount; } + @JsonIgnore public List getStopSequences() { - return stopSequences; + return this.textGenerationConfig.stopSequences; } public void setStopSequences(List stopSequences) { - this.stopSequences = stopSequences; + this.textGenerationConfig.stopSequences = 
stopSequences; } @Override + @JsonIgnore public Integer getTopK() { throw new UnsupportedOperationException("Bedrock Titan Chat does not support the 'TopK' option."); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java index e3089eec5d..979e4c82b6 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java @@ -31,6 +31,8 @@ import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; /** @@ -50,6 +52,11 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel { private final TitanEmbeddingBedrockApi embeddingApi; + /** + * The retry template used to retry the Bedrock API calls. 
+ */ + private final RetryTemplate retryTemplate; + public enum InputType { TEXT, IMAGE @@ -62,7 +69,15 @@ public enum InputType { private InputType inputType = InputType.TEXT; public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) { + this(titanEmbeddingBedrockApi, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi, RetryTemplate retryTemplate) { + Assert.notNull(titanEmbeddingBedrockApi, "TitanEmbeddingBedrockApi must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.embeddingApi = titanEmbeddingBedrockApi; + this.retryTemplate = retryTemplate; } /** @@ -87,17 +102,19 @@ public EmbeddingResponse call(EmbeddingRequest request) { "Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); } - List> embeddingList = new ArrayList<>(); - for (String inputContent : request.getInstructions()) { - var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions()); - TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest); - embeddingList.add(response.embedding()); - } - var indexCounter = new AtomicInteger(0); - List embeddings = embeddingList.stream() - .map(e -> new Embedding(e, indexCounter.getAndIncrement())) - .toList(); - return new EmbeddingResponse(embeddings); + return this.retryTemplate.execute(ctx -> { + List> embeddingList = new ArrayList<>(); + for (String inputContent : request.getInstructions()) { + var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions()); + TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest); + embeddingList.add(response.embedding()); + } + var indexCounter = new AtomicInteger(0); + List embeddings = embeddingList.stream() + .map(e -> new Embedding(e, indexCounter.getAndIncrement())) + .toList(); + return new EmbeddingResponse(embeddings); + }); } private TitanEmbeddingRequest 
createTitanEmbeddingRequest(String inputContent, EmbeddingOptions requestOptions) { diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java deleted file mode 100644 index ce1842adf3..0000000000 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java +++ /dev/null @@ -1,315 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.bedrock.titan.api; - -import java.time.Duration; -import java.util.List; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.api.AbstractBedrockApi; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse.CompletionReason; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk; -import org.springframework.ai.model.ModelDescription; - -/** - * Java client for the Bedrock Titan chat model. - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html - *

- * https://docs.aws.amazon.com/bedrock/latest/userguide/titan-text-models.html - * - * @author Christian Tzolov - * @author Wei Jiang - * @since 0.8.0 - */ -// @formatter:off -public class TitanChatBedrockApi extends - AbstractBedrockApi { - - /** - * Create a new TitanChatBedrockApi instance using the default credentials provider chain, the default object mapper. - * - * @param modelId The model id to use. See the {@link TitanChatModel} for the supported models. - * @param region The AWS region to use. - */ - public TitanChatBedrockApi(String modelId, String region) { - super(modelId, region); - } - - /** - * Create a new TitanChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link TitanChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - */ - public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper) { - super(modelId, credentialsProvider, region, objectMapper); - } - - /** - * Create a new TitanChatBedrockApi instance using the default credentials provider chain, the default object mapper. - * - * @param modelId The model id to use. See the {@link TitanChatModel} for the supported models. - * @param region The AWS region to use. - * @param timeout The timeout to use. - */ - public TitanChatBedrockApi(String modelId, String region, Duration timeout) { - super(modelId, region, timeout); - } - - /** - * Create a new TitanChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link TitanChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. 
- * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * Create a new TitanChatBedrockApi instance using the provided credentials provider, region and object mapper. - * - * @param modelId The model id to use. See the {@link TitanChatModel} for the supported models. - * @param credentialsProvider The credentials provider to connect to AWS. - * @param region The AWS region to use. - * @param objectMapper The object mapper to use for JSON serialization and deserialization. - * @param timeout The timeout to use. - */ - public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, - ObjectMapper objectMapper, Duration timeout) { - super(modelId, credentialsProvider, region, objectMapper, timeout); - } - - /** - * TitanChatRequest encapsulates the request parameters for the Titan chat model. - * - * @param inputText The prompt to use for the chat. - * @param textGenerationConfig The text generation configuration. - */ - @JsonInclude(Include.NON_NULL) - public record TitanChatRequest( - @JsonProperty("inputText") String inputText, - @JsonProperty("textGenerationConfig") TextGenerationConfig textGenerationConfig) { - - /** - * Titan request text generation configuration. - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html - * - * @param temperature The temperature value controls the randomness of the generated text. - * @param topP The topP value controls the diversity of the generated text. Use a lower value to ignore less - * probable options. - * @param maxTokenCount The maximum number of tokens to generate. 
- * @param stopSequences A list of sequences to stop the generation at. Specify character sequences to indicate - * where the model should stop. Use the | (pipe) character to separate different sequences (maximum 20 - * characters). - */ - @JsonInclude(Include.NON_NULL) - public record TextGenerationConfig( - @JsonProperty("temperature") Float temperature, - @JsonProperty("topP") Float topP, - @JsonProperty("maxTokenCount") Integer maxTokenCount, - @JsonProperty("stopSequences") List stopSequences) { - } - - /** - * Create a new TitanChatRequest builder. - * @param inputText The prompt to use for the chat. - * @return A new TitanChatRequest builder. - */ - public static Builder builder(String inputText) { - return new Builder(inputText); - } - - public static class Builder { - private final String inputText; - private Float temperature; - private Float topP; - private Integer maxTokenCount; - private List stopSequences; - - public Builder(String inputText) { - this.inputText = inputText; - } - - public Builder withTemperature(Float temperature) { - this.temperature = temperature; - return this; - } - - public Builder withTopP(Float topP) { - this.topP = topP; - return this; - } - - public Builder withMaxTokenCount(Integer maxTokenCount) { - this.maxTokenCount = maxTokenCount; - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.stopSequences = stopSequences; - return this; - } - - public TitanChatRequest build() { - - if (this.temperature == null && this.topP == null && this.maxTokenCount == null - && this.stopSequences == null) { - return new TitanChatRequest(this.inputText, null); - } else { - return new TitanChatRequest(this.inputText, - new TextGenerationConfig( - this.temperature, - this.topP, - this.maxTokenCount, - this.stopSequences - )); - } - } - } - } - - /** - * TitanChatResponse encapsulates the response parameters for the Titan chat model. - * - * @param inputTextTokenCount The number of tokens in the input text. 
- * @param results The list of generated responses. - */ - @JsonInclude(Include.NON_NULL) - public record TitanChatResponse( - @JsonProperty("inputTextTokenCount") Integer inputTextTokenCount, - @JsonProperty("results") List results) { - - /** - * Titan response result. - * - * @param tokenCount The number of tokens in the generated text. - * @param outputText The generated text. - * @param completionReason The reason the response finished being generated. - */ - @JsonInclude(Include.NON_NULL) - public record Result( - @JsonProperty("tokenCount") Integer tokenCount, - @JsonProperty("outputText") String outputText, - @JsonProperty("completionReason") CompletionReason completionReason) { - } - - /** - * The reason the response finished being generated. - */ - public enum CompletionReason { - /** - * The response was fully generated. - */ - FINISH, - - /** - * The response was truncated because of the response length you set. - */ - LENGTH, - - /** - * The response was truncated because of restrictions. - */ - CONTENT_FILTERED - } - } - - /** - * Titan chat model streaming response. - * - * @param outputText The generated text in this chunk. - * @param index The index of the chunk in the streaming response. - * @param inputTextTokenCount The number of tokens in the prompt. - * @param totalOutputTextTokenCount The number of tokens in the response. - * @param completionReason The reason the response finished being generated. - */ - @JsonInclude(Include.NON_NULL) - public record TitanChatResponseChunk( - @JsonProperty("outputText") String outputText, - @JsonProperty("index") Integer index, - @JsonProperty("inputTextTokenCount") Integer inputTextTokenCount, - @JsonProperty("totalOutputTextTokenCount") Integer totalOutputTextTokenCount, - @JsonProperty("completionReason") CompletionReason completionReason, - @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { - } - - /** - * Titan models version. 
- */ - public enum TitanChatModel implements ModelDescription { - - /** - * amazon.titan-text-lite-v1 - */ - TITAN_TEXT_LITE_V1("amazon.titan-text-lite-v1"), - - /** - * amazon.titan-text-express-v1 - */ - TITAN_TEXT_EXPRESS_V1("amazon.titan-text-express-v1"), - - /** - * amazon.titan-text-premier-v1:0 - */ - TITAN_TEXT_PREMIER_V1("amazon.titan-text-premier-v1:0"); - - private final String id; - - /** - * @return The model id. - */ - public String id() { - return id; - } - - TitanChatModel(String value) { - this.id = value; - } - - @Override - public String getModelName() { - return this.id; - } - } - - @Override - public TitanChatResponse chatCompletion(TitanChatRequest request) { - return this.internalInvocation(request, TitanChatResponse.class); - } - - @Override - public Flux chatCompletionStream(TitanChatRequest request) { - return this.internalInvocationStream(request, TitanChatResponseChunk.class); - } -} -// @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java index 016ec4306b..9a5d180957 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java @@ -24,6 +24,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import org.springframework.ai.bedrock.api.AbstractBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest; @@ -81,6 +83,20 @@ public 
TitanEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentia super(modelId, credentialsProvider, region, objectMapper, timeout); } + /** + * Create a new TitanEmbeddingBedrockApi instance. + * + * @param model The model id to use. + * @param bedrockRuntimeClient The AWS BedrockRuntimeClient instance. + * @param bedrockRuntimeAsyncClient The AWS BedrockRuntimeAsyncClient instance. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + */ + public TitanEmbeddingBedrockApi(String model, BedrockRuntimeClient bedrockRuntimeClient, + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ObjectMapper objectMapper) { + super(model, bedrockRuntimeClient, bedrockRuntimeAsyncClient, objectMapper); + } + + /** * Titan Embedding request parameters. * @@ -157,7 +173,7 @@ public enum TitanEmbeddingModel { /** * amazon.titan-embed-text-v2 */ - TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0");; + TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0"); private final String id; diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/MockWeatherService.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/MockWeatherService.java new file mode 100644 index 0000000000..78f41f210a --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/MockWeatherService.java @@ -0,0 +1,89 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * @author Christian Tzolov + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. 
+ */ + public record Response(double temp, Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, Unit.C); + } + +} \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java index b2b09fc8ec..5e18efc9b5 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java @@ -21,7 +21,6 @@ import java.util.Map; import java.util.stream.Collectors; -import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; @@ -30,12 +29,13 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; +import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import 
org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -198,19 +198,47 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + @Test + void chatResponseUsage() { + Prompt prompt = new Prompt("Who are you?"); + + ChatResponse response = chatModel.call(prompt); + + Usage usage = response.getMetadata().getUsage(); + assertThat(usage).isNotNull(); + assertThat(usage.getPromptTokens()).isGreaterThan(1); + assertThat(usage.getGenerationTokens()).isGreaterThan(1); + } + + @Test + void chatOptions() { + AnthropicChatOptions options = AnthropicChatOptions.builder() + .withTemperature(0.5F) + .withMaxTokens(100) + .withTopK(10) + .withTopP(0.5F) + .withStopSequences(List.of("stop sequences")) + .build(); + + Prompt prompt = new Prompt("Who are you?", options); + ChatResponse response = chatModel.call(prompt); + String content = response.getResult().getOutput().getContent(); + + assertThat(content).isNotNull(); + } + @SpringBootConfiguration public static class TestConfiguration { @Bean - public AnthropicChatBedrockApi anthropicApi() { - return new AnthropicChatBedrockApi(AnthropicChatBedrockApi.AnthropicChatModel.CLAUDE_V2.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + public BedrockConverseApi converseApi() { + return new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), Duration.ofMinutes(2)); } @Bean - public BedrockAnthropicChatModel anthropicChatModel(AnthropicChatBedrockApi anthropicApi) { - return new BedrockAnthropicChatModel(anthropicApi); + public BedrockAnthropicChatModel anthropicChatModel(BedrockConverseApi converseApi) { + return new BedrockAnthropicChatModel(converseApi); } } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java 
b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java deleted file mode 100644 index c8b5cbe859..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.bedrock.anthropic; - -import java.time.Duration; -import java.util.List; - -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; -import org.springframework.ai.chat.prompt.Prompt; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - */ -public class BedrockAnthropicCreateRequestTests { - - private AnthropicChatBedrockApi anthropicChatApi = new AnthropicChatBedrockApi(AnthropicChatModel.CLAUDE_V2.id(), - Region.US_EAST_1.id(), Duration.ofMillis(1000L)); - - @Test - public void createRequestWithChatOptions() { - - var client = new BedrockAnthropicChatModel(anthropicChatApi, - AnthropicChatOptions.builder() - .withTemperature(66.6f) - .withTopK(66) - .withTopP(0.66f) - .withMaxTokensToSample(666) - .withAnthropicVersion("X.Y.Z") - 
.withStopSequences(List.of("stop1", "stop2")) - .build()); - - var request = client.createRequest(new Prompt("Test message content")); - - assertThat(request.prompt()).isNotEmpty(); - assertThat(request.temperature()).isEqualTo(66.6f); - assertThat(request.topK()).isEqualTo(66); - assertThat(request.topP()).isEqualTo(0.66f); - assertThat(request.maxTokensToSample()).isEqualTo(666); - assertThat(request.anthropicVersion()).isEqualTo("X.Y.Z"); - assertThat(request.stopSequences()).containsExactly("stop1", "stop2"); - - request = client.createRequest(new Prompt("Test message content", - AnthropicChatOptions.builder() - .withTemperature(99.9f) - .withTopP(0.99f) - .withMaxTokensToSample(999) - .withAnthropicVersion("zzz") - .withStopSequences(List.of("stop3", "stop4")) - .build() - - )); - - assertThat(request.prompt()).isNotEmpty(); - assertThat(request.temperature()).isEqualTo(99.9f); - assertThat(request.topK()).as("unchanged from the default options").isEqualTo(66); - assertThat(request.topP()).isEqualTo(0.99f); - assertThat(request.maxTokensToSample()).isEqualTo(999); - assertThat(request.anthropicVersion()).isEqualTo("zzz"); - assertThat(request.stopSequences()).containsExactly("stop3", "stop4"); - } - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java deleted file mode 100644 index 334efa48ff..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.bedrock.anthropic.api; - -import java.time.Duration; -import java.util.List; -import java.util.stream.Collectors; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; - -import static org.assertj.core.api.Assertions.assertThat;; - -/** - * @author Christian Tzolov - */ -@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") -@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -public class AnthropicChatBedrockApiIT { - - private final Logger logger = LoggerFactory.getLogger(AnthropicChatBedrockApiIT.class); - - private AnthropicChatBedrockApi anthropicChatApi = new AnthropicChatBedrockApi(AnthropicChatModel.CLAUDE_V2.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), - Duration.ofMinutes(2)); - - @Test - public void chatCompletion() { - - AnthropicChatRequest request = AnthropicChatRequest - 
.builder(String.format(AnthropicChatBedrockApi.PROMPT_TEMPLATE, "Name 3 famous pirates")) - .withTemperature(0.8f) - .withMaxTokensToSample(300) - .withTopK(10) - .build(); - - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); - - System.out.println(response.completion()); - assertThat(response).isNotNull(); - assertThat(response.completion()).isNotEmpty(); - assertThat(response.completion()).contains("Blackbeard"); - assertThat(response.stopReason()).isEqualTo("stop_sequence"); - assertThat(response.stop()).isEqualTo("\n\nHuman:"); - assertThat(response.amazonBedrockInvocationMetrics()).isNull(); - - logger.info("" + response); - } - - @Test - public void chatCompletionStream() { - - AnthropicChatRequest request = AnthropicChatRequest - .builder(String.format(AnthropicChatBedrockApi.PROMPT_TEMPLATE, "Name 3 famous pirates")) - .withTemperature(0.8f) - .withMaxTokensToSample(300) - .withTopK(10) - .withStopSequences(List.of("\n\nHuman:")) - .build(); - - Flux responseStream = anthropicChatApi.chatCompletionStream(request); - - List responses = responseStream.collectList().block(); - assertThat(responses).isNotNull(); - assertThat(responses).hasSizeGreaterThan(10); - assertThat(responses.stream().map(AnthropicChatResponse::completion).collect(Collectors.joining())) - .contains("Blackbeard"); - } - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java index 7e31a8f017..3540d6a705 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java @@ -17,12 +17,12 @@ import java.io.IOException; import java.time.Duration; +import java.util.ArrayList; import java.util.Arrays; 
import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; @@ -31,19 +31,22 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; +import org.springframework.ai.bedrock.MockWeatherService; +import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Media; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -208,7 +211,7 @@ void multiModalityTest() throws IOException { var imageData = new ClassPathResource("/test.png"); - var userMessage = new UserMessage("Explain what do you see o this picture?", + var userMessage = new UserMessage("Explain what do you see on this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); var response = chatModel.call(new 
Prompt(List.of(userMessage))); @@ -217,19 +220,71 @@ void multiModalityTest() throws IOException { assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket"); } + @Test + void chatResponseUsage() { + Prompt prompt = new Prompt("Who are you?"); + + ChatResponse response = chatModel.call(prompt); + + Usage usage = response.getMetadata().getUsage(); + assertThat(usage).isNotNull(); + assertThat(usage.getPromptTokens()).isGreaterThan(1); + assertThat(usage.getGenerationTokens()).isGreaterThan(1); + } + + @Test + void chatOptions() { + Anthropic3ChatOptions options = Anthropic3ChatOptions.builder() + .withTemperature(0.5F) + .withMaxTokens(100) + .withTopK(10) + .withTopP(0.5F) + .withStopSequences(List.of("stop sequences")) + .build(); + + Prompt prompt = new Prompt("Who are you?", options); + ChatResponse response = chatModel.call(prompt); + String content = response.getResult().getOutput().getContent(); + + assertThat(content).isNotNull(); + } + + @Test + void functionCallTest() { + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = Anthropic3ChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location. 
Return temperature in 36°F or 36°C format.") + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + Generation generation = response.getResult(); + assertThat(generation.getOutput().getContent()).containsAnyOf("30.0", "30"); + assertThat(generation.getOutput().getContent()).containsAnyOf("10.0", "10"); + assertThat(generation.getOutput().getContent()).containsAnyOf("15.0", "15"); + } + @SpringBootConfiguration public static class TestConfiguration { @Bean - public Anthropic3ChatBedrockApi anthropicApi() { - return new Anthropic3ChatBedrockApi(Anthropic3ChatBedrockApi.AnthropicChatModel.CLAUDE_V3_SONNET.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), - Duration.ofMinutes(5)); + public BedrockConverseApi converseApi() { + return new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), + Duration.ofMinutes(2)); } @Bean - public BedrockAnthropic3ChatModel anthropicChatModel(Anthropic3ChatBedrockApi anthropicApi) { - return new BedrockAnthropic3ChatModel(anthropicApi); + public BedrockAnthropic3ChatModel anthropicChatModel(BedrockConverseApi converseApi) { + return new BedrockAnthropic3ChatModel(converseApi); } } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java deleted file mode 100644 index 31486f9e93..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3CreateRequestTests.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.bedrock.anthropic3; - -import org.junit.jupiter.api.Test; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; -import org.springframework.ai.chat.prompt.Prompt; -import software.amazon.awssdk.regions.Region; - -import java.time.Duration; -import java.util.List; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - */ -public class BedrockAnthropic3CreateRequestTests { - - private Anthropic3ChatBedrockApi anthropicChatApi = new Anthropic3ChatBedrockApi(AnthropicChatModel.CLAUDE_V2.id(), - Region.EU_CENTRAL_1.id(), Duration.ofMillis(1000L)); - - @Test - public void createRequestWithChatOptions() { - - var client = new BedrockAnthropic3ChatModel(anthropicChatApi, - Anthropic3ChatOptions.builder() - .withTemperature(66.6f) - .withTopK(66) - .withTopP(0.66f) - .withMaxTokens(666) - .withAnthropicVersion("X.Y.Z") - .withStopSequences(List.of("stop1", "stop2")) - .build()); - - var request = client.createRequest(new Prompt("Test message content")); - - assertThat(request.messages()).isNotEmpty(); - assertThat(request.temperature()).isEqualTo(66.6f); - assertThat(request.topK()).isEqualTo(66); - assertThat(request.topP()).isEqualTo(0.66f); - assertThat(request.maxTokens()).isEqualTo(666); - assertThat(request.anthropicVersion()).isEqualTo("X.Y.Z"); - assertThat(request.stopSequences()).containsExactly("stop1", "stop2"); - - request = 
client.createRequest(new Prompt("Test message content", - Anthropic3ChatOptions.builder() - .withTemperature(99.9f) - .withTopP(0.99f) - .withMaxTokens(999) - .withAnthropicVersion("zzz") - .withStopSequences(List.of("stop3", "stop4")) - .build() - - )); - - assertThat(request.messages()).isNotEmpty(); - assertThat(request.temperature()).isEqualTo(99.9f); - assertThat(request.topK()).as("unchanged from the default options").isEqualTo(66); - assertThat(request.topP()).isEqualTo(0.99f); - assertThat(request.maxTokens()).isEqualTo(999); - assertThat(request.anthropicVersion()).isEqualTo("zzz"); - assertThat(request.stopSequences()).containsExactly("stop3", "stop4"); - } - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java deleted file mode 100644 index 15ab3dd0f5..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.bedrock.anthropic3.api; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import java.time.Duration; -import java.util.List; -import java.util.stream.Collectors; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION; - -/** - * @author Ben Middleton - */ -@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") -@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -public class Anthropic3ChatBedrockApiIT { - - private final Logger logger = LoggerFactory.getLogger(Anthropic3ChatBedrockApiIT.class); - - private Anthropic3ChatBedrockApi anthropicChatApi = new Anthropic3ChatBedrockApi( - AnthropicChatModel.CLAUDE_INSTANT_V1.id(), EnvironmentVariableCredentialsProvider.create(), - Region.US_EAST_1.id(), new 
ObjectMapper(), Duration.ofMinutes(2)); - - @Test - public void chatCompletion() { - - MediaContent anthropicMessage = new MediaContent("Name 3 famous pirates"); - ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage(List.of(anthropicMessage), Role.USER); - AnthropicChatRequest request = AnthropicChatRequest.builder(List.of(chatCompletionMessage)) - .withTemperature(0.8f) - .withMaxTokens(300) - .withTopK(10) - .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) - .build(); - - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); - - logger.info("" + response.content()); - - assertThat(response).isNotNull(); - assertThat(response.content().get(0).text()).isNotEmpty(); - assertThat(response.content().get(0).text()).contains("Blackbeard"); - assertThat(response.stopReason()).isEqualTo("end_turn"); - assertThat(response.stopSequence()).isNull(); - assertThat(response.usage().inputTokens()).isGreaterThan(10); - assertThat(response.usage().outputTokens()).isGreaterThan(100); - - logger.info("" + response); - } - - @Test - public void chatMultiCompletion() { - - MediaContent anthropicInitialMessage = new MediaContent("Name 3 famous pirates"); - ChatCompletionMessage chatCompletionInitialMessage = new ChatCompletionMessage(List.of(anthropicInitialMessage), - Role.USER); - - MediaContent anthropicAssistantMessage = new MediaContent( - "Here are 3 famous pirates: Blackbeard, Calico Jack, Henry Morgan"); - ChatCompletionMessage chatCompletionAssistantMessage = new ChatCompletionMessage( - List.of(anthropicAssistantMessage), Role.ASSISTANT); - - MediaContent anthropicFollowupMessage = new MediaContent("Why are they famous?"); - ChatCompletionMessage chatCompletionFollowupMessage = new ChatCompletionMessage( - List.of(anthropicFollowupMessage), Role.USER); - - AnthropicChatRequest request = AnthropicChatRequest - .builder(List.of(chatCompletionInitialMessage, chatCompletionAssistantMessage, - chatCompletionFollowupMessage)) - 
.withTemperature(0.8f) - .withMaxTokens(400) - .withTopK(10) - .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) - .build(); - - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); - - logger.info("" + response.content()); - assertThat(response).isNotNull(); - assertThat(response.content().get(0).text()).isNotEmpty(); - assertThat(response.content().get(0).text()).contains("Blackbeard"); - assertThat(response.stopReason()).isEqualTo("end_turn"); - assertThat(response.stopSequence()).isNull(); - assertThat(response.usage().inputTokens()).isGreaterThan(30); - assertThat(response.usage().outputTokens()).isGreaterThan(200); - - logger.info("" + response); - } - - @Test - public void chatCompletionStream() { - MediaContent anthropicMessage = new MediaContent("Name 3 famous pirates"); - ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage(List.of(anthropicMessage), Role.USER); - - AnthropicChatRequest request = AnthropicChatRequest.builder(List.of(chatCompletionMessage)) - .withTemperature(0.8f) - .withMaxTokens(300) - .withTopK(10) - .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) - .build(); - - Flux responseStream = anthropicChatApi - .chatCompletionStream(request); - - List responses = responseStream.collectList().block(); - assertThat(responses).isNotNull(); - assertThat(responses).hasSizeGreaterThan(10); - assertThat(responses.stream() - .filter(message -> message.type() == StreamingType.CONTENT_BLOCK_DELTA) - .map(message -> message.delta().text()) - .collect(Collectors.joining())).contains("Blackbeard"); - } - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java index 92060ef5bf..60000b27f8 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java +++ 
b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java @@ -16,12 +16,7 @@ package org.springframework.ai.bedrock.aot; import org.junit.jupiter.api.Test; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.TypeReference; @@ -42,9 +37,7 @@ void registerHints() { BedrockRuntimeHints bedrockRuntimeHints = new BedrockRuntimeHints(); bedrockRuntimeHints.registerHints(runtimeHints, null); - List classList = Arrays.asList(Ai21Jurassic2ChatBedrockApi.class, CohereChatBedrockApi.class, - CohereEmbeddingBedrockApi.class, LlamaChatBedrockApi.class, TitanChatBedrockApi.class, - TitanEmbeddingBedrockApi.class, AnthropicChatBedrockApi.class); + List classList = Arrays.asList(CohereEmbeddingBedrockApi.class, TitanEmbeddingBedrockApi.class); for (Class aClass : classList) { Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(aClass); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/api/BedrockConverseApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/api/BedrockConverseApiIT.java new file mode 100644 index 0000000000..da8a6ef1d2 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/api/BedrockConverseApiIT.java @@ -0,0 +1,90 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.api; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.Message; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockConverseApiIT { + + private BedrockConverseApi converseApi = new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(), + Region.US_EAST_1.id()); + + @Test + public void testConverse() { + ContentBlock contentBlock = 
ContentBlock.builder().text("Give me the names of 3 famous pirates?").build(); + + Message message = Message.builder().content(contentBlock).role(ConversationRole.USER).build(); + + ConverseRequest request = ConverseRequest.builder() + .modelId("anthropic.claude-3-sonnet-20240229-v1:0") + .messages(List.of(message)) + .build(); + + ConverseResponse response = converseApi.converse(request); + + assertThat(response).isNotNull(); + assertThat(response.output()).isNotNull(); + assertThat(response.output().message()).isNotNull(); + assertThat(response.output().message().content()).isNotEmpty(); + assertThat(response.output().message().content().get(0).text()).contains("Blackbeard"); + assertThat(response.stopReason()).isNotNull(); + assertThat(response.usage()).isNotNull(); + assertThat(response.usage().inputTokens()).isGreaterThan(10); + assertThat(response.usage().outputTokens()).isGreaterThan(30); + } + + @Test + public void testConverseStream() { + ContentBlock contentBlock = ContentBlock.builder().text("Give me the names of 3 famous pirates?").build(); + + Message message = Message.builder().content(contentBlock).role(ConversationRole.USER).build(); + + ConverseStreamRequest request = ConverseStreamRequest.builder() + .modelId("anthropic.claude-3-sonnet-20240229-v1:0") + .messages(List.of(message)) + .build(); + + Flux responseStream = converseApi.converseStream(request); + + List responseOutputs = responseStream.collectList().block(); + + assertThat(responseOutputs).isNotNull(); + assertThat(responseOutputs).hasSizeGreaterThan(10); + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtilsIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtilsIT.java new file mode 100644 index 0000000000..7c79bd1314 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtilsIT.java @@ -0,0 +1,346 @@ +/* + * 
Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.api; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type; +import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockConverseApiUtilsIT { + + private static final String FAKE_MODEL_ID = "FAKE_MODEL_ID"; + + @Test + public void testCreateConverseRequestWithNoOptions() { + Prompt prompt = new Prompt("hello world"); + + ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(FAKE_MODEL_ID, prompt); + + 
assertThat(converseRequest).isNotNull(); + assertThat(converseRequest.system()).isEmpty(); + assertThat(converseRequest.inferenceConfig()).isNull(); + assertThat(converseRequest.toolConfig()).isNull(); + assertThat(converseRequest.additionalModelRequestFields()).isNull(); + assertThat(converseRequest.additionalModelResponseFieldPaths()).isEmpty(); + assertThat(converseRequest.modelId()).isEqualTo(FAKE_MODEL_ID); + assertThat(converseRequest.messages()).hasSize(1); + assertThat(converseRequest.messages().get(0).content()).hasSize(1); + assertThat(converseRequest.messages().get(0).role()).isEqualTo(ConversationRole.USER); + assertThat(converseRequest.messages().get(0).content().get(0).text()).isEqualTo("hello world"); + assertThat(converseRequest.messages().get(0).content().get(0).type()).isEqualTo(Type.TEXT); + } + + @Test + public void testCreateConverseRequestWithMultipleMessagesAndNoOptions() { + Prompt prompt = new Prompt(List.of(new UserMessage("hello world1"), new UserMessage("hello world2"))); + + ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(FAKE_MODEL_ID, prompt); + + assertThat(converseRequest).isNotNull(); + assertThat(converseRequest.system()).isEmpty(); + assertThat(converseRequest.inferenceConfig()).isNull(); + assertThat(converseRequest.toolConfig()).isNull(); + assertThat(converseRequest.additionalModelRequestFields()).isNull(); + assertThat(converseRequest.additionalModelResponseFieldPaths()).isEmpty(); + assertThat(converseRequest.modelId()).isEqualTo(FAKE_MODEL_ID); + assertThat(converseRequest.messages()).hasSize(2); + assertThat(converseRequest.messages().get(0).content()).hasSize(1); + assertThat(converseRequest.messages().get(0).role()).isEqualTo(ConversationRole.USER); + assertThat(converseRequest.messages().get(0).content().get(0).text()).isEqualTo("hello world1"); + assertThat(converseRequest.messages().get(0).content().get(0).type()).isEqualTo(Type.TEXT); + 
assertThat(converseRequest.messages().get(1).content()).hasSize(1); + assertThat(converseRequest.messages().get(1).role()).isEqualTo(ConversationRole.USER); + assertThat(converseRequest.messages().get(1).content().get(0).text()).isEqualTo("hello world2"); + assertThat(converseRequest.messages().get(1).content().get(0).type()).isEqualTo(Type.TEXT); + } + + @Test + public void testCreateConverseRequestWithMultipleMessageRolesAndNoOptions() { + Prompt prompt = new Prompt(List.of(new UserMessage("hello world1"), new AssistantMessage("hello world2"))); + + ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(FAKE_MODEL_ID, prompt); + + assertThat(converseRequest).isNotNull(); + assertThat(converseRequest.system()).isEmpty(); + assertThat(converseRequest.inferenceConfig()).isNull(); + assertThat(converseRequest.toolConfig()).isNull(); + assertThat(converseRequest.additionalModelRequestFields()).isNull(); + assertThat(converseRequest.additionalModelResponseFieldPaths()).isEmpty(); + assertThat(converseRequest.modelId()).isEqualTo(FAKE_MODEL_ID); + assertThat(converseRequest.messages()).hasSize(2); + assertThat(converseRequest.messages().get(0).content()).hasSize(1); + assertThat(converseRequest.messages().get(0).role()).isEqualTo(ConversationRole.USER); + assertThat(converseRequest.messages().get(0).content().get(0).text()).isEqualTo("hello world1"); + assertThat(converseRequest.messages().get(0).content().get(0).type()).isEqualTo(Type.TEXT); + assertThat(converseRequest.messages().get(1).content()).hasSize(1); + assertThat(converseRequest.messages().get(1).role()).isEqualTo(ConversationRole.ASSISTANT); + assertThat(converseRequest.messages().get(1).content().get(0).text()).isEqualTo("hello world2"); + assertThat(converseRequest.messages().get(1).content().get(0).type()).isEqualTo(Type.TEXT); + } + + @Test + public void testCreateConverseRequestWithSystemMessageAndNoOptions() { + Prompt prompt = new Prompt( + List.of(new UserMessage("hello 
world"), new SystemMessage("example system message"))); + + ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(FAKE_MODEL_ID, prompt); + + assertThat(converseRequest).isNotNull(); + assertThat(converseRequest.inferenceConfig()).isNull(); + assertThat(converseRequest.toolConfig()).isNull(); + assertThat(converseRequest.additionalModelRequestFields()).isNull(); + assertThat(converseRequest.additionalModelResponseFieldPaths()).isEmpty(); + assertThat(converseRequest.modelId()).isEqualTo(FAKE_MODEL_ID); + assertThat(converseRequest.messages()).hasSize(1); + assertThat(converseRequest.messages().get(0).content()).hasSize(1); + assertThat(converseRequest.messages().get(0).role()).isEqualTo(ConversationRole.USER); + assertThat(converseRequest.messages().get(0).content().get(0).text()).isEqualTo("hello world"); + assertThat(converseRequest.messages().get(0).content().get(0).type()).isEqualTo(Type.TEXT); + assertThat(converseRequest.system()).hasSize(1); + assertThat(converseRequest.system().get(0).text()).isEqualTo("example system message"); + } + + @Test + public void testOptionsToAdditionalModelRequestFields() { + Prompt prompt = new Prompt("hello world"); + + ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(FAKE_MODEL_ID, prompt, + new MockChatOptions()); + + Document requestFields = converseRequest.additionalModelRequestFields(); + + assertThat(converseRequest).isNotNull(); + assertThat(converseRequest.system()).isEmpty(); + assertThat(converseRequest.inferenceConfig()).isNull(); + assertThat(converseRequest.toolConfig()).isNull(); + assertThat(requestFields).isNotNull(); + assertThat(requestFields.asMap()).hasSize(12); + assertThat(requestFields.asMap().get("temperature").asNumber().floatValue()).isEqualTo(0.1F); + assertThat(requestFields.asMap().get("top_p").asNumber().floatValue()).isEqualTo(0.2F); + assertThat(requestFields.asMap().get("top_k").asNumber().intValue()).isEqualTo(3); + 
assertThat(requestFields.asMap().get("string_value").asString()).isEqualTo("stringValue"); + assertThat(requestFields.asMap().get("boolean_value").asBoolean()).isEqualTo(true); + assertThat(requestFields.asMap().get("long_value").asNumber().longValue()).isEqualTo(4); + assertThat(requestFields.asMap().get("float_value").asNumber().floatValue()).isEqualTo(0.5F); + assertThat(requestFields.asMap().get("double_value").asNumber().doubleValue()).isEqualTo(0.6); + assertThat(requestFields.asMap().get("big_decimal_value").asNumber().bigDecimalValue()) + .isEqualTo(BigDecimal.valueOf(7)); + assertThat(requestFields.asMap().get("big_intege_value").asNumber().bigDecimalValue().intValue()).isEqualTo(8); + assertThat(requestFields.asMap().get("list_value").asList()).hasSize(2); + assertThat(requestFields.asMap().get("list_value").asList().get(0).asString()).isEqualTo("hello"); + assertThat(requestFields.asMap().get("map_value").asMap()).hasSize(1); + assertThat(requestFields.asMap().get("map_value").asMap().get("hello").asString()).isEqualTo("world"); + } + + @Test + public void testCreateConverseRequestWithRuntimeOptions() { + MockChatOptions runtimeOptions = new MockChatOptions(); + runtimeOptions.setTemperature(50F); + + Prompt prompt = new Prompt("hello world", runtimeOptions); + + ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(FAKE_MODEL_ID, prompt, + new MockChatOptions()); + + Document requestFields = converseRequest.additionalModelRequestFields(); + + assertThat(converseRequest).isNotNull(); + assertThat(converseRequest.system()).isEmpty(); + assertThat(converseRequest.inferenceConfig()).isNull(); + assertThat(converseRequest.toolConfig()).isNull(); + assertThat(requestFields).isNotNull(); + assertThat(requestFields.asMap()).hasSize(12); + assertThat(requestFields.asMap().get("temperature").asNumber().floatValue()).isEqualTo(50F); + assertThat(requestFields.asMap().get("top_p").asNumber().floatValue()).isEqualTo(0.2F); + } + + @Test + 
public void testCreateConverseStreamRequestWithRuntimeOptions() { + MockChatOptions runtimeOptions = new MockChatOptions(); + runtimeOptions.setTemperature(50F); + + Prompt prompt = new Prompt("hello world", runtimeOptions); + + ConverseStreamRequest converseStreamRequest = BedrockConverseApiUtils.createConverseStreamRequest(FAKE_MODEL_ID, + prompt, new MockChatOptions()); + + Document requestFields = converseStreamRequest.additionalModelRequestFields(); + + assertThat(converseStreamRequest).isNotNull(); + assertThat(converseStreamRequest.system()).isEmpty(); + assertThat(converseStreamRequest.inferenceConfig()).isNull(); + assertThat(converseStreamRequest.additionalModelResponseFieldPaths()).isEmpty(); + assertThat(converseStreamRequest.modelId()).isEqualTo(FAKE_MODEL_ID); + assertThat(converseStreamRequest.messages()).hasSize(1); + assertThat(converseStreamRequest.messages().get(0).content()).hasSize(1); + assertThat(converseStreamRequest.messages().get(0).role()).isEqualTo(ConversationRole.USER); + assertThat(converseStreamRequest.messages().get(0).content().get(0).text()).isEqualTo("hello world"); + assertThat(converseStreamRequest.messages().get(0).content().get(0).type()).isEqualTo(Type.TEXT); + assertThat(requestFields.asMap()).hasSize(12); + assertThat(requestFields.asMap().get("temperature").asNumber().floatValue()).isEqualTo(50F); + assertThat(requestFields.asMap().get("top_p").asNumber().floatValue()).isEqualTo(0.2F); + } + + class MockChatOptions implements ChatOptions { + + private @JsonProperty("temperature") Float temperature = 0.1F; + + private @JsonProperty("top_p") Float topP = 0.2F; + + private @JsonProperty("top_k") Integer topK = 3; + + private @JsonProperty("string_value") String stringValue = "stringValue"; + + private @JsonProperty("boolean_value") Boolean booleanValue = true; + + private @JsonProperty("long_value") Long longValue = 4L; + + private @JsonProperty("float_value") Float floatValue = 0.5F; + + private 
@JsonProperty("double_value") Double doubleValue = 0.6; + + private @JsonProperty("big_decimal_value") BigDecimal bigDecimalValue = BigDecimal.valueOf(7); + + private @JsonProperty("big_intege_value") BigInteger bigIntegerValue = BigInteger.valueOf(8); + + private @JsonProperty("list_value") List listValue = List.of("hello", "world"); + + private @JsonProperty("map_value") Map mapValue = Map.of("hello", "world"); + + @Override + public Float getTemperature() { + return temperature; + } + + @Override + public Float getTopP() { + return topP; + } + + @Override + public Integer getTopK() { + return topK; + } + + public String getStringValue() { + return stringValue; + } + + public void setStringValue(String stringValue) { + this.stringValue = stringValue; + } + + public Boolean getBooleanValue() { + return booleanValue; + } + + public void setBooleanValue(Boolean booleanValue) { + this.booleanValue = booleanValue; + } + + public Long getLongValue() { + return longValue; + } + + public void setLongValue(Long longValue) { + this.longValue = longValue; + } + + public Float getFloatValue() { + return floatValue; + } + + public void setFloatValue(Float floatValue) { + this.floatValue = floatValue; + } + + public Double getDoubleValue() { + return doubleValue; + } + + public void setDoubleValue(Double doubleValue) { + this.doubleValue = doubleValue; + } + + public BigDecimal getBigDecimalValue() { + return bigDecimalValue; + } + + public void setBigDecimalValue(BigDecimal bigDecimalValue) { + this.bigDecimalValue = bigDecimalValue; + } + + public BigInteger getBigIntegerValue() { + return bigIntegerValue; + } + + public void setBigIntegerValue(BigInteger bigIntegerValue) { + this.bigIntegerValue = bigIntegerValue; + } + + public List getListValue() { + return listValue; + } + + public void setListValue(List listValue) { + this.listValue = listValue; + } + + public Map getMapValue() { + return mapValue; + } + + public void setMapValue(Map mapValue) { + this.mapValue = 
mapValue; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java deleted file mode 100644 index c757efe04a..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatCreateRequestTests.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.bedrock.cohere; - -import java.time.Duration; -import java.util.List; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.LogitBias; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.ReturnLikelihoods; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; -import org.springframework.ai.chat.prompt.Prompt; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - */ -public class BedrockCohereChatCreateRequestTests { - - private CohereChatBedrockApi chatApi = new CohereChatBedrockApi(CohereChatModel.COHERE_COMMAND_V14.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), - Duration.ofMinutes(2)); - - @Test - public void createRequestWithChatOptions() { - - var client = new BedrockCohereChatModel(chatApi, - BedrockCohereChatOptions.builder() - .withTemperature(66.6f) - .withTopK(66) - .withTopP(0.66f) - .withMaxTokens(678) - .withStopSequences(List.of("stop1", "stop2")) - .withReturnLikelihoods(ReturnLikelihoods.ALL) - .withNumGenerations(3) - .withLogitBias(new LogitBias("t", 6.6f)) - .withTruncate(Truncate.END) - .build()); - - CohereChatRequest request = client.createRequest(new Prompt("Test message content"), true); - - assertThat(request.prompt()).isNotEmpty(); - assertThat(request.stream()).isTrue(); - - 
assertThat(request.temperature()).isEqualTo(66.6f); - assertThat(request.topK()).isEqualTo(66); - assertThat(request.topP()).isEqualTo(0.66f); - assertThat(request.maxTokens()).isEqualTo(678); - assertThat(request.stopSequences()).containsExactly("stop1", "stop2"); - assertThat(request.returnLikelihoods()).isEqualTo(ReturnLikelihoods.ALL); - assertThat(request.numGenerations()).isEqualTo(3); - assertThat(request.logitBias()).isEqualTo(new LogitBias("t", 6.6f)); - assertThat(request.truncate()).isEqualTo(Truncate.END); - - request = client.createRequest(new Prompt("Test message content", - BedrockCohereChatOptions.builder() - .withTemperature(99.9f) - .withTopK(99) - .withTopP(0.99f) - .withMaxTokens(888) - .withStopSequences(List.of("stop3", "stop4")) - .withReturnLikelihoods(ReturnLikelihoods.GENERATION) - .withNumGenerations(13) - .withLogitBias(new LogitBias("t", 9.9f)) - .withTruncate(Truncate.START) - .build()), - false - - ); - - assertThat(request.prompt()).isNotEmpty(); - assertThat(request.stream()).isFalse(); - - assertThat(request.temperature()).isEqualTo(99.9f); - assertThat(request.topK()).isEqualTo(99); - assertThat(request.topP()).isEqualTo(0.99f); - assertThat(request.maxTokens()).isEqualTo(888); - assertThat(request.stopSequences()).containsExactly("stop3", "stop4"); - assertThat(request.returnLikelihoods()).isEqualTo(ReturnLikelihoods.GENERATION); - assertThat(request.numGenerations()).isEqualTo(13); - assertThat(request.logitBias()).isEqualTo(new LogitBias("t", 9.9f)); - assertThat(request.truncate()).isEqualTo(Truncate.START); - } - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java index 5da9f8670d..ec520357f2 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java +++ 
b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java @@ -21,23 +21,22 @@ import java.util.Map; import java.util.stream.Collectors; -import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.cohere.BedrockCohereChatOptions.ReturnLikelihoods; +import org.springframework.ai.bedrock.cohere.BedrockCohereChatOptions.Truncate; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; @@ -92,12 +91,8 @@ void multipleStreamAttempts() { @Test void roleTest() { String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; - String name = "Bob"; - String voice = "pirate"; UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); - Message systemMessage = 
systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); - Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + Prompt prompt = new Prompt(List.of(userMessage)); ChatResponse response = chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -194,19 +189,50 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + @Test + void chatResponseUsage() { + Prompt prompt = new Prompt("Who are you?"); + + ChatResponse response = chatModel.call(prompt); + + Usage usage = response.getMetadata().getUsage(); + assertThat(usage).isNotNull(); + assertThat(usage.getPromptTokens()).isGreaterThan(1); + assertThat(usage.getGenerationTokens()).isGreaterThan(1); + } + + @Test + void chatOptions() { + BedrockCohereChatOptions options = BedrockCohereChatOptions.builder() + .withTemperature(0.5F) + .withTopP(0.5F) + .withTopK(100) + .withMaxTokens(100) + .withStopSequences(List.of("stop sequences")) + .withReturnLikelihoods(ReturnLikelihoods.ALL) + .withNumGenerations(1) + .withTruncate(Truncate.START) + .build(); + + Prompt prompt = new Prompt("Who are you?", options); + ChatResponse response = chatModel.call(prompt); + String content = response.getResult().getOutput().getContent(); + + assertThat(content).isNotNull(); + } + @SpringBootConfiguration public static class TestConfiguration { @Bean - public CohereChatBedrockApi cohereApi() { - return new CohereChatBedrockApi(CohereChatModel.COHERE_COMMAND_V14.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + public BedrockConverseApi converseApi() { + return new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), Duration.ofMinutes(2)); } @Bean - public BedrockCohereChatModel cohereChatModel(CohereChatBedrockApi cohereApi) { - return new BedrockCohereChatModel(cohereApi); + public BedrockCohereChatModel 
cohereChatModel(BedrockConverseApi converseApi) { + return new BedrockCohereChatModel(converseApi); } } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModelIT.java new file mode 100644 index 0000000000..03627a49d3 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModelIT.java @@ -0,0 +1,246 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.cohere; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.messages.AssistantMessage; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = 
"AWS_SECRET_ACCESS_KEY", matches = ".*") +class BedrockCohereCommandRChatModelIT { + + private static final Logger logger = LoggerFactory.getLogger(BedrockCohereCommandRChatModelIT.class); + + @Autowired + private BedrockCohereCommandRChatModel chatModel; + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Test + void multipleStreamAttempts() { + + Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + + String joke1 = joke1Stream.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + String joke2 = joke2Stream.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + assertThat(joke1).isNotBlank(); + assertThat(joke2).isNotBlank(); + } + + @Test + void roleTest() { + String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; + String name = "Bob"; + String voice = "pirate"; + UserMessage userMessage = new UserMessage(request); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = chatModel.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + 
List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors.", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getContent()); + assertThat(list).hasSize(5); + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Remove Markdown code blocks from the output. + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getContent()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. 
+ """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = chatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + System.out.println(actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void functionCallTest() { + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = BedrockCohereCommandRChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location. 
Return temperature in 36°F or 36°C format.") + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + Generation generation = response.getResult(); + assertThat(generation.getOutput().getContent()).containsAnyOf("30.0", "30"); + assertThat(generation.getOutput().getContent()).containsAnyOf("10.0", "10"); + assertThat(generation.getOutput().getContent()).containsAnyOf("15.0", "15"); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public BedrockConverseApi converseApi() { + return new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), + Duration.ofMinutes(2)); + } + + @Bean + public BedrockCohereCommandRChatModel cohereCommandRChatModel(BedrockConverseApi converseApi) { + return new BedrockCohereCommandRChatModel(converseApi); + } + + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/MockWeatherService.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/MockWeatherService.java new file mode 100644 index 0000000000..a9a4382059 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/MockWeatherService.java @@ -0,0 +1,90 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.cohere; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * @author Wei Jiang + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request( + @JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city, example: San Francisco") Object location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. 
+ */ + public record Response(double temp, Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().toString().contains("Paris")) { + temperature = 15; + } + else if (request.location().toString().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().toString().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, Unit.C); + } + +} \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java deleted file mode 100644 index 540a6bd2bf..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.bedrock.cohere.api; - -import java.time.Duration; -import java.util.List; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse.Generation.FinishReason; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy;; - -/** - * @author Christian Tzolov - */ -@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") -@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -public class CohereChatBedrockApiIT { - - private CohereChatBedrockApi cohereChatApi = new CohereChatBedrockApi(CohereChatModel.COHERE_COMMAND_V14.id(), - Region.US_EAST_1.id(), Duration.ofMinutes(2)); - - @Test - public void requestBuilder() { - - CohereChatRequest request1 = new CohereChatRequest( - "What is the capital of Bulgaria and what is the size? What it the national anthem?", 0.5f, 0.9f, 15, - 40, List.of("END"), CohereChatRequest.ReturnLikelihoods.ALL, false, 1, null, Truncate.NONE); - - var request2 = CohereChatRequest - .builder("What is the capital of Bulgaria and what is the size? 
What it the national anthem?") - .withTemperature(0.5f) - .withTopP(0.9f) - .withTopK(15) - .withMaxTokens(40) - .withStopSequences(List.of("END")) - .withReturnLikelihoods(CohereChatRequest.ReturnLikelihoods.ALL) - .withStream(false) - .withNumGenerations(1) - .withLogitBias(null) - .withTruncate(Truncate.NONE) - .build(); - - assertThat(request1).isEqualTo(request2); - } - - @Test - public void chatCompletion() { - - var request = CohereChatRequest - .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?") - .withStream(false) - .withTemperature(0.5f) - .withTopP(0.8f) - .withTopK(15) - .withMaxTokens(100) - .withStopSequences(List.of("END")) - .withReturnLikelihoods(CohereChatRequest.ReturnLikelihoods.ALL) - .withNumGenerations(3) - .withLogitBias(null) - .withTruncate(Truncate.NONE) - .build(); - - CohereChatResponse response = cohereChatApi.chatCompletion(request); - - assertThat(response).isNotNull(); - assertThat(response.prompt()).isEqualTo(request.prompt()); - assertThat(response.generations()).hasSize(request.numGenerations()); - assertThat(response.generations().get(0).text()).isNotEmpty(); - } - - @Test - public void chatCompletionStream() { - - var request = CohereChatRequest - .builder("What is the capital of Bulgaria and what is the size? 
What it the national anthem?") - .withStream(true) - .withTemperature(0.5f) - .withTopP(0.8f) - .withTopK(15) - .withMaxTokens(100) - .withStopSequences(List.of("END")) - .withReturnLikelihoods(CohereChatRequest.ReturnLikelihoods.ALL) - .withNumGenerations(3) - .withLogitBias(null) - .withTruncate(Truncate.NONE) - .build(); - - Flux responseStream = cohereChatApi.chatCompletionStream(request); - List responses = responseStream.collectList().block(); - - assertThat(responses).isNotNull(); - assertThat(responses).hasSizeGreaterThan(10); - assertThat(responses.get(0).text()).isNotEmpty(); - - CohereChatResponse.Generation lastResponse = responses.get(responses.size() - 1); - assertThat(lastResponse.text()).isNull(); - assertThat(lastResponse.isFinished()).isTrue(); - assertThat(lastResponse.finishReason()).isEqualTo(FinishReason.MAX_TOKENS); - assertThat(lastResponse.amazonBedrockInvocationMetrics()).isNotNull(); - } - - @Test - public void testStreamConfigurations() { - var streamRequest = CohereChatRequest - .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?") - .withStream(true) - .build(); - - assertThatThrownBy(() -> cohereChatApi.chatCompletion(streamRequest)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("The request must be configured to return the complete response!"); - - var notStreamRequest = CohereChatRequest - .builder("What is the capital of Bulgaria and what is the size? 
What it the national anthem?") - .withStream(false) - .build(); - - assertThatThrownBy(() -> cohereChatApi.chatCompletionStream(notStreamRequest)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("The request must be configured to stream the response!"); - - } - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java index 5ab957e799..a3370c6b3c 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.ai.bedrock.jurassic2; import java.time.Duration; @@ -21,18 +20,20 @@ import java.util.List; import java.util.Map; -import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatModel.Ai21Jurassic2ChatModel; +import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatOptions.Penalty; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import 
org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -61,10 +62,7 @@ class BedrockAi21Jurassic2ChatModelIT { void roleTest() { UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); - Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); - - Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + Prompt prompt = new Prompt(List.of(userMessage)); ChatResponse response = chatModel.call(prompt); @@ -130,34 +128,63 @@ void mapOutputConverter() { @Test void simpleChatResponse() { UserMessage userMessage = new UserMessage("Tell me a joke about AI."); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); - Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); - Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + Prompt prompt = new Prompt(List.of(userMessage)); ChatResponse response = chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("AI"); } + @Test + void chatResponseUsage() { + Prompt prompt = new Prompt("Who are you?"); + + ChatResponse response = chatModel.call(prompt); + + Usage usage = response.getMetadata().getUsage(); + assertThat(usage).isNotNull(); + assertThat(usage.getPromptTokens()).isGreaterThan(1); + assertThat(usage.getGenerationTokens()).isGreaterThan(1); + } + + @Test + void chatOptions() { + BedrockAi21Jurassic2ChatOptions options = BedrockAi21Jurassic2ChatOptions.builder() + .withNumResults(1) + .withMaxTokens(100) + .withMinTokens(1) + .withTemperature(0.5F) + .withTopP(0.5F) + .withTopK(20) + .withStopSequences(List.of("stop sequences")) 
+ .withFrequencyPenalty(Penalty.builder().scale(1F).build()) + .withPresencePenalty(Penalty.builder().scale(1F).build()) + .withCountPenalty(Penalty.builder().scale(1F).build()) + .build(); + + Prompt prompt = new Prompt("Who are you?", options); + ChatResponse response = chatModel.call(prompt); + String content = response.getResult().getOutput().getContent(); + + assertThat(content).isNotNull(); + } + @SpringBootConfiguration public static class TestConfiguration { @Bean - public Ai21Jurassic2ChatBedrockApi jurassic2ChatBedrockApi() { - return new Ai21Jurassic2ChatBedrockApi( - Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel.AI21_J2_MID_V1.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + public BedrockConverseApi converseApi() { + return new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), Duration.ofMinutes(2)); } @Bean - public BedrockAi21Jurassic2ChatModel bedrockAi21Jurassic2ChatModel( - Ai21Jurassic2ChatBedrockApi jurassic2ChatBedrockApi) { - return new BedrockAi21Jurassic2ChatModel(jurassic2ChatBedrockApi, + public BedrockAi21Jurassic2ChatModel bedrockAi21Jurassic2ChatModel(BedrockConverseApi converseApi) { + return new BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatModel.AI21_J2_MID_V1.id(), converseApi, BedrockAi21Jurassic2ChatOptions.builder() .withTemperature(0.5f) - .withMaxTokens(100) + .withMaxTokens(500) .withTopP(0.9f) .build()); } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java deleted file mode 100644 index 8525471d14..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.bedrock.jurassic2.api; - -import java.time.Duration; -import java.util.stream.Collectors; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - */ -@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") -@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -public class Ai21Jurassic2ChatBedrockApiIT { - - Ai21Jurassic2ChatBedrockApi api = new Ai21Jurassic2ChatBedrockApi(Ai21Jurassic2ChatModel.AI21_J2_ULTRA_V1.id(), - Region.US_EAST_1.id(), Duration.ofMinutes(2)); - - @Test - public void chatCompletion() { - Ai21Jurassic2ChatRequest request = new Ai21Jurassic2ChatRequest("Give me the names of 3 famous pirates?", 0.9f, - 0.9f, 100, null, // List.of("END"), - new Ai21Jurassic2ChatRequest.IntegerScalePenalty(1, true, true, true, true, true), - new Ai21Jurassic2ChatRequest.FloatScalePenalty(0.5f, true, true, true, true, 
true), - new Ai21Jurassic2ChatRequest.IntegerScalePenalty(1, true, true, true, true, true)); - - Ai21Jurassic2ChatResponse response = api.chatCompletion(request); - - assertThat(response).isNotNull(); - assertThat(response.completions()).isNotEmpty(); - assertThat(response.amazonBedrockInvocationMetrics()).isNull(); - - String responseContent = response.completions() - .stream() - .map(c -> c.data().text()) - .collect(Collectors.joining(System.lineSeparator())); - assertThat(responseContent).contains("Blackbeard"); - } - - // Note: Ai21Jurassic2 doesn't support streaming yet! - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java index 416b397783..3058d36095 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java @@ -21,20 +21,19 @@ import java.util.Map; import java.util.stream.Collectors; -import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; +import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import 
org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -195,19 +194,45 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + @Test + void chatResponseUsage() { + Prompt prompt = new Prompt("Who are you?"); + + ChatResponse response = chatModel.call(prompt); + + Usage usage = response.getMetadata().getUsage(); + assertThat(usage).isNotNull(); + assertThat(usage.getPromptTokens()).isGreaterThan(1); + assertThat(usage.getGenerationTokens()).isGreaterThan(1); + } + + @Test + void chatOptions() { + BedrockLlamaChatOptions options = BedrockLlamaChatOptions.builder() + .withTemperature(0.5F) + .withTopP(0.5F) + .withMaxGenLen(100) + .build(); + + Prompt prompt = new Prompt("Who are you?", options); + ChatResponse response = chatModel.call(prompt); + String content = response.getResult().getOutput().getContent(); + + assertThat(content).isNotNull(); + } + @SpringBootConfiguration public static class TestConfiguration { @Bean - public LlamaChatBedrockApi llamaApi() { - return new LlamaChatBedrockApi(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + public BedrockConverseApi converseApi() { + return new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), Duration.ofMinutes(2)); } @Bean - public BedrockLlamaChatModel llamaChatModel(LlamaChatBedrockApi llamaApi) { - return new BedrockLlamaChatModel(llamaApi, + public BedrockLlamaChatModel llamaChatModel(BedrockConverseApi converseApi) { + return new BedrockLlamaChatModel(converseApi, BedrockLlamaChatOptions.builder().withTemperature(0.5f).withMaxGenLen(100).withTopP(0.9f).build()); } diff --git 
a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java deleted file mode 100644 index 4bd48680d2..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaCreateRequestTests.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.bedrock.llama; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - -import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; -import org.springframework.ai.chat.prompt.Prompt; - -import java.time.Duration; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - * @author Wei Jiang - */ -@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") -@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -public class BedrockLlamaCreateRequestTests { - - private LlamaChatBedrockApi api = new LlamaChatBedrockApi(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), - Duration.ofMinutes(2)); - - @Test - public void createRequestWithChatOptions() { - - var client = new BedrockLlamaChatModel(api, - BedrockLlamaChatOptions.builder().withTemperature(66.6f).withMaxGenLen(666).withTopP(0.66f).build()); - - var request = client.createRequest(new Prompt("Test message content")); - - assertThat(request.prompt()).isNotEmpty(); - assertThat(request.temperature()).isEqualTo(66.6f); - assertThat(request.topP()).isEqualTo(0.66f); - assertThat(request.maxGenLen()).isEqualTo(666); - - request = client.createRequest(new Prompt("Test message content", - BedrockLlamaChatOptions.builder().withTemperature(99.9f).withMaxGenLen(999).withTopP(0.99f).build())); - - assertThat(request.prompt()).isNotEmpty(); - assertThat(request.temperature()).isEqualTo(99.9f); - assertThat(request.topP()).isEqualTo(0.99f); - assertThat(request.maxGenLen()).isEqualTo(999); - 
} - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java deleted file mode 100644 index 5b4587358f..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApiIT.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.bedrock.llama.api; - -import java.time.Duration; -import java.util.List; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - * @author Wei Jiang - */ -@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") -@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -public class LlamaChatBedrockApiIT { - - private LlamaChatBedrockApi llamaChatApi = new LlamaChatBedrockApi(LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), - Duration.ofMinutes(2)); - - @Test - public void chatCompletion() { - - LlamaChatRequest request = LlamaChatRequest.builder("Hello, my name is") - .withTemperature(0.9f) - .withTopP(0.9f) - .withMaxGenLen(20) - .build(); - - LlamaChatResponse response = llamaChatApi.chatCompletion(request); - - System.out.println(response.generation()); - assertThat(response).isNotNull(); - assertThat(response.generation()).isNotEmpty(); - assertThat(response.generationTokenCount()).isGreaterThan(10); - assertThat(response.generationTokenCount()).isLessThanOrEqualTo(20); - assertThat(response.stopReason()).isNotNull(); - assertThat(response.amazonBedrockInvocationMetrics()).isNull(); - } - - @Test - public void chatCompletionStream() { - - LlamaChatRequest request = new 
LlamaChatRequest("Hello, my name is", 0.9f, 0.9f, 20); - Flux responseStream = llamaChatApi.chatCompletionStream(request); - List responses = responseStream.collectList().block(); - - assertThat(responses).isNotNull(); - assertThat(responses).hasSizeGreaterThan(10); - assertThat(responses.get(0).generation()).isNotEmpty(); - - LlamaChatResponse lastResponse = responses.get(responses.size() - 1); - assertThat(lastResponse.stopReason()).isNotNull(); - assertThat(lastResponse.amazonBedrockInvocationMetrics()).isNotNull(); - } - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModelIT.java new file mode 100644 index 0000000000..c1a6e3a834 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModelIT.java @@ -0,0 +1,246 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.mistral; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.bedrock.MockWeatherService; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; + +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = 
"AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockMistralChatModelIT { + + private static final Logger logger = LoggerFactory.getLogger(BedrockMistralChatModelIT.class); + + @Autowired + private BedrockMistralChatModel chatModel; + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Test + void multipleStreamAttempts() { + + Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + + String joke1 = joke1Stream.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + String joke2 = joke2Stream.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + assertThat(joke1).isNotBlank(); + assertThat(joke2).isNotBlank(); + } + + @Test + void roleTest() { + String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; + String name = "Bob"; + String voice = "pirate"; + UserMessage userMessage = new UserMessage(request); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = chatModel.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String 
format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors.", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getContent()); + assertThat(list).hasSize(5); + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Remove Markdown code blocks from the output. + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getContent()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. 
+ """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = chatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + System.out.println(actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void functionCallTest() { + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = BedrockMistralChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location. 
Return temperature in 36°F or 36°C format.") + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + Generation generation = response.getResult(); + assertThat(generation.getOutput().getContent()).containsAnyOf("30.0", "30"); + assertThat(generation.getOutput().getContent()).containsAnyOf("10.0", "10"); + assertThat(generation.getOutput().getContent()).containsAnyOf("15.0", "15"); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public BedrockConverseApi converseApi() { + return new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), + Duration.ofMinutes(2)); + } + + @Bean + public BedrockMistralChatModel mistralChatModel(BedrockConverseApi converseApi) { + return new BedrockMistralChatModel(converseApi); + } + + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java deleted file mode 100644 index af0522de63..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelCreateRequestTests.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.ai.bedrock.titan; - -import java.time.Duration; -import java.util.List; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; -import org.springframework.ai.chat.prompt.Prompt; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - */ -public class BedrockTitanChatModelCreateRequestTests { - - private TitanChatBedrockApi api = new TitanChatBedrockApi(TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), - Duration.ofMinutes(2)); - - @Test - public void createRequestWithChatOptions() { - - var model = new BedrockTitanChatModel(api, - BedrockTitanChatOptions.builder() - .withTemperature(66.6f) - .withTopP(0.66f) - .withMaxTokenCount(666) - .withStopSequences(List.of("stop1", "stop2")) - .build()); - - var request = model.createRequest(new Prompt("Test message content")); - - assertThat(request.inputText()).isNotEmpty(); - assertThat(request.textGenerationConfig().temperature()).isEqualTo(66.6f); - assertThat(request.textGenerationConfig().topP()).isEqualTo(0.66f); - assertThat(request.textGenerationConfig().maxTokenCount()).isEqualTo(666); - assertThat(request.textGenerationConfig().stopSequences()).containsExactly("stop1", "stop2"); - - request = model.createRequest(new Prompt("Test message content", - BedrockTitanChatOptions.builder() - .withTemperature(99.9f) - .withTopP(0.99f) - .withMaxTokenCount(999) - .withStopSequences(List.of("stop3", "stop4")) - .build() - - )); - - 
assertThat(request.inputText()).isNotEmpty(); - assertThat(request.textGenerationConfig().temperature()).isEqualTo(99.9f); - assertThat(request.textGenerationConfig().topP()).isEqualTo(0.99f); - assertThat(request.textGenerationConfig().maxTokenCount()).isEqualTo(999); - assertThat(request.textGenerationConfig().stopSequences()).containsExactly("stop3", "stop4"); - } - -} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java index 97085c2c2b..f89edc9799 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java @@ -21,7 +21,6 @@ import java.util.Map; import java.util.stream.Collectors; -import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -29,16 +28,14 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; +import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; -import 
org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; @@ -93,12 +90,8 @@ void multipleStreamAttempts() { @Test void roleTest() { String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; - String name = "Bob"; - String voice = "pirate"; UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); - Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); - Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + Prompt prompt = new Prompt(List.of(userMessage)); ChatResponse response = chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -200,19 +193,46 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + @Test + void chatResponseUsage() { + Prompt prompt = new Prompt("Who are you?"); + + ChatResponse response = chatModel.call(prompt); + + Usage usage = response.getMetadata().getUsage(); + assertThat(usage).isNotNull(); + assertThat(usage.getPromptTokens()).isGreaterThan(1); + assertThat(usage.getGenerationTokens()).isGreaterThan(1); + } + + @Test + void chatOptions() { + BedrockTitanChatOptions options = BedrockTitanChatOptions.builder() + .withTemperature(0.5F) + .withTopP(0.5F) + .withMaxTokenCount(100) + .withStopSequences(List.of("stop sequences")) + .build(); + + Prompt prompt = new Prompt("Who are you?", options); + ChatResponse response = chatModel.call(prompt); + String content = response.getResult().getOutput().getContent(); + + assertThat(content).isNotNull(); + } + @SpringBootConfiguration public static class TestConfiguration { @Bean - public TitanChatBedrockApi titanApi() { - return new 
TitanChatBedrockApi(TitanChatModel.TITAN_TEXT_PREMIER_V1.id(), - EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + public BedrockConverseApi converseApi() { + return new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), Duration.ofMinutes(2)); } @Bean - public BedrockTitanChatModel titanChatModel(TitanChatBedrockApi titanApi) { - return new BedrockTitanChatModel(titanApi); + public BedrockTitanChatModel titanChatModel(BedrockConverseApi converseApi) { + return new BedrockTitanChatModel(converseApi); } } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java deleted file mode 100644 index e7bb1f8bff..0000000000 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2023 - 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.ai.bedrock.titan.api; - -import java.time.Duration; -import java.util.List; -import java.util.stream.Collectors; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.regions.Region; - -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - */ -@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") -@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") -public class TitanChatBedrockApiIT { - - TitanChatBedrockApi titanBedrockApi = new TitanChatBedrockApi(TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), - Region.EU_CENTRAL_1.id(), Duration.ofMinutes(2)); - - TitanChatRequest titanChatRequest = TitanChatRequest.builder("Give me the names of 3 famous pirates?") - .withTemperature(0.5f) - .withTopP(0.9f) - .withMaxTokenCount(100) - .withStopSequences(List.of("|")) - .build(); - - @Test - public void chatCompletion() { - TitanChatResponse response = titanBedrockApi.chatCompletion(titanChatRequest); - assertThat(response.results()).hasSize(1); - assertThat(response.results().get(0).outputText()).contains("Blackbeard"); - } - - @Test - public void chatCompletionStream() { - Flux response = titanBedrockApi.chatCompletionStream(titanChatRequest); - List results = response.collectList().block(); - - assertThat(results.stream() - .map(TitanChatResponseChunk::outputText) - .collect(Collectors.joining(System.lineSeparator()))).contains("Blackbeard"); - } - -} diff --git a/pom.xml b/pom.xml index 
633d3f69cb..d38ed7043c 100644 --- a/pom.xml +++ b/pom.xml @@ -143,7 +143,7 @@ 1.0.0-beta.8 1.0.0 4.31.1 - 2.25.3 + 2.25.64 2.16.1 0.28.0 1.17.0 diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-anthropic-chat-api.png b/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-anthropic-chat-api.png deleted file mode 100644 index c1f83541ec..0000000000 Binary files a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-anthropic-chat-api.png and /dev/null differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-cohere-chat-low-level-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-cohere-chat-low-level-api.jpg deleted file mode 100644 index 0e182f28ad..0000000000 Binary files a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-cohere-chat-low-level-api.jpg and /dev/null differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama-chat-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama-chat-api.jpg deleted file mode 100644 index b836546eb4..0000000000 Binary files a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-llama-chat-api.jpg and /dev/null differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-titan-chat-low-level-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-titan-chat-low-level-api.jpg deleted file mode 100644 index 20a111086e..0000000000 Binary files a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-titan-chat-low-level-api.jpg and /dev/null differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index b554b8959c..631cb44113 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -11,11 +11,16 @@ **** 
xref:api/chat/functions/azure-open-ai-chat-functions.adoc[Function Calling] *** xref:api/bedrock-chat.adoc[Amazon Bedrock] **** xref:api/chat/bedrock/bedrock-anthropic3.adoc[Anthropic3] +***** xref:api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc[Function Calling] **** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic2] **** xref:api/chat/bedrock/bedrock-llama.adoc[Llama] **** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere] +**** xref:api/chat/bedrock/bedrock-coherecommandr.adoc[CohereCommandR] +***** xref:api/chat/functions/bedrock/bedrock-coherecommandr-chat-functions.adoc[Function Calling] **** xref:api/chat/bedrock/bedrock-titan.adoc[Titan] **** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2] +**** xref:api/chat/bedrock/bedrock-mistral.adoc[Mistral] +***** xref:api/chat/functions/bedrock/bedrock-mistral-chat-functions.adoc[Function Calling] *** xref:api/chat/huggingface.adoc[Hugging Face] *** xref:api/chat/google-vertexai.adoc[Google VertexAI] **** xref:api/chat/vertexai-palm2-chat.adoc[VertexAI PaLM2 ] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc index f8b2b2062a..3fd71cc74a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc @@ -21,8 +21,8 @@ Then add the Spring Boot Starter dependency to your project's Maven `pom.xml` bu [source,xml] ---- - spring-ai-bedrock-ai-spring-boot-starter - org.springframework.ai + org.springframework.ai + spring-ai-bedrock-ai-spring-boot-starter ---- @@ -83,15 +83,13 @@ Here are the supported `` and `` combinations: [cols="|,|,|,|"] |==== -| Model | Chat | Chat Streaming | Embedding - -| llama | Yes | Yes | No -| jurassic2 | Yes | No | No -| cohere | Yes | Yes | Yes +| Model | Chat | Chat Streaming | Embedding | anthropic 2 | Yes | Yes | No -| anthropic 3 | Yes | Yes | No -| jurassic2 (WIP) | Yes | 
No | No -| titan | Yes | Yes | Yes (however, no batch support) +| anthropic 3 | Yes | Yes | No +| cohere | Yes | Yes | Yes +| jurassic2 | Yes | No | No +| llama | Yes | Yes | No +| titan | Yes | Yes | Yes (however, no batch support) |==== For example, to enable the Bedrock Llama chat model, you need to set `spring.ai.bedrock.llama.chat.enabled=true`. @@ -104,7 +102,9 @@ For more information, refer to the documentation below for each supported model. * xref:api/chat/bedrock/bedrock-anthropic3.adoc[Spring AI Bedrock Anthropic 3 Chat]: `spring.ai.bedrock.anthropic.chat.enabled=true` * xref:api/chat/bedrock/bedrock-llama.adoc[Spring AI Bedrock Llama Chat]: `spring.ai.bedrock.llama.chat.enabled=true` * xref:api/chat/bedrock/bedrock-cohere.adoc[Spring AI Bedrock Cohere Chat]: `spring.ai.bedrock.cohere.chat.enabled=true` +* xref:api/chat/bedrock/bedrock-coherecommandr.adoc[Spring AI Bedrock Cohere Command R Chat]: `spring.ai.bedrock.coherecommandr.chat.enabled=true` * xref:api/embeddings/bedrock-cohere-embedding.adoc[Spring AI Bedrock Cohere Embeddings]: `spring.ai.bedrock.cohere.embedding.enabled=true` * xref:api/chat/bedrock/bedrock-titan.adoc[Spring AI Bedrock Titan Chat]: `spring.ai.bedrock.titan.chat.enabled=true` * xref:api/embeddings/bedrock-titan-embedding.adoc[Spring AI Bedrock Titan Embeddings]: `spring.ai.bedrock.titan.embedding.enabled=true` * xref:api/chat/bedrock/bedrock-jurassic2.adoc[Spring AI Bedrock Ai21 Jurassic2 Chat]: `spring.ai.bedrock.jurassic2.chat.enabled=true` +* xref:api/chat/bedrock/bedrock-mistral.adoc[Spring AI Bedrock Mistral Chat]: `spring.ai.bedrock.mistral.chat.enabled=true` diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc index 63e2e3e895..3403e2ed30 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc +++ 
b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic.adoc @@ -79,18 +79,16 @@ The prefix `spring.ai.bedrock.anthropic.chat` is the property prefix that config [cols="2,5,1"] |==== | Property | Description | Default - | spring.ai.bedrock.anthropic.chat.enabled | Enable Bedrock Anthropic chat model. Disabled by default | false -| spring.ai.bedrock.anthropic.chat.model | The model id to use. See the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java[AnthropicChatModel] for the supported models. | anthropic.claude-v2 -| spring.ai.bedrock.anthropic.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | 0.8 -| spring.ai.bedrock.anthropic.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. | AWS Bedrock default +| spring.ai.bedrock.anthropic.chat.model | The model id to use. See the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java[AnthropicChatModel] for the supported models. | `anthropic.claude-v2` +| spring.ai.bedrock.anthropic.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | AWS Bedrock default +| spring.ai.bedrock.anthropic.chat.options.maxTokens | Specify the maximum number of tokens to use in the generated response. Note that the models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. We recommend a limit of 4,000 tokens for optimal performance. | AWS Bedrock default | spring.ai.bedrock.anthropic.chat.options.topK | Specify the number of token choices the generative uses to generate the next token. 
| AWS Bedrock default -| spring.ai.bedrock.anthropic.chat.options.stopSequences | Configure up to four sequences that the generative recognizes. After a stop sequence, the generative stops generating further tokens. The returned text doesn't contain the stop sequence. | 10 -| spring.ai.bedrock.anthropic.chat.options.anthropicVersion | The version of the generative to use. | bedrock-2023-05-31 -| spring.ai.bedrock.anthropic.chat.options.maxTokensToSample | Specify the maximum number of tokens to use in the generated response. Note that the models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. We recommend a limit of 4,000 tokens for optimal performance. | 500 +| spring.ai.bedrock.anthropic.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. | AWS Bedrock default +| spring.ai.bedrock.anthropic.chat.options.stopSequences | Configure up to four sequences that the generative recognizes. After a stop sequence, the generative stops generating further tokens. The returned text doesn't contain the stop sequence. | AWS Bedrock default |==== -Look at the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java[AnthropicChatModel] for other model IDs. +Look at the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModel.java[AnthropicChatModel] for other model IDs. Supported values are: `anthropic.claude-instant-v1`, `anthropic.claude-v2` and `anthropic.claude-v2:1`. Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs]. 
@@ -195,20 +193,18 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp [source,java] ---- -AnthropicChatBedrockApi anthropicApi = new AnthropicChatBedrockApi( - AnthropicChatBedrockApi.AnthropicModel.CLAUDE_V2.id(), +BedrockConverseApi converseApi = new BedrockConverseApi( EnvironmentVariableCredentialsProvider.create(), Region.EU_CENTRAL_1.id(), - new ObjectMapper(), Duration.ofMillis(1000L)); -BedrockAnthropicChatModel chatModel = new BedrockAnthropicChatModel(anthropicApi, +BedrockAnthropicChatModel chatModel = new BedrockAnthropicChatModel(converseApi, AnthropicChatOptions.builder() .withTemperature(0.6f) + .withMaxTokens(100) .withTopK(10) .withTopP(0.8f) - .withMaxTokensToSample(100) - .withAnthropicVersion(AnthropicChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) + .withStopSequences(List.of("stop sequences")) .build()); ChatResponse response = chatModel.call( @@ -218,37 +214,3 @@ ChatResponse response = chatModel.call( Flux response = chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- - -=== Low-level AnthropicChatBedrockApi Client [[low-level-api]] - -The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java[AnthropicChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock link:https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html[Anthropic Claude models]. - -Following class diagram illustrates the AnthropicChatBedrockApi interface and building blocks: - -image::bedrock/bedrock-anthropic-chat-api.png[AnthropicChatBedrockApi Class Diagram] - -Client supports the `anthropic.claude-instant-v1`, `anthropic.claude-v2` and `anthropic.claude-v2:1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses. 
- -Here is a simple snippet how to use the api programmatically: - -[source,java] ----- -AnthropicChatBedrockApi anthropicChatApi = new AnthropicChatBedrockApi( - AnthropicModel.CLAUDE_V2.id(), Region.EU_CENTRAL_1.id(), Duration.ofMillis(1000L)); - -AnthropicChatRequest request = AnthropicChatRequest - .builder(String.format(AnthropicChatBedrockApi.PROMPT_TEMPLATE, "Name 3 famous pirates")) - .withTemperature(0.8f) - .withMaxTokensToSample(300) - .withTopK(10) - .build(); - -// Sync request -AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); - -// Streaming request -Flux responseStream = anthropicChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); ----- - -Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java[AnthropicChatBedrockApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc index 2f547788ea..98af263183 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-anthropic3.adoc @@ -79,16 +79,15 @@ The prefix `spring.ai.bedrock.anthropic3.chat` is the property prefix that confi | spring.ai.bedrock.anthropic3.chat.enabled | Enable Bedrock Anthropic chat model. Disabled by default | false | spring.ai.bedrock.anthropic3.chat.model | The model id to use. Supports the `anthropic.claude-3-sonnet-20240229-v1:0`,`anthropic.claude-3-haiku-20240307-v1:0` and the legacy `anthropic.claude-v2`, `anthropic.claude-v2:1` and `anthropic.claude-instant-v1` models for both synchronous and streaming responses. 
| `anthropic.claude-3-sonnet-20240229-v1:0` -| spring.ai.bedrock.anthropic3.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | 0.8 -| spring.ai.bedrock.anthropic3.chat.options.top-p | The maximum cumulative probability of tokens to consider when sampling. | AWS Bedrock default -| spring.ai.bedrock.anthropic3.chat.options.top-k | Specify the number of token choices the generative uses to generate the next token. | AWS Bedrock default -| spring.ai.bedrock.anthropic3.chat.options.stop-sequences | Configure up to four sequences that the generative recognizes. After a stop sequence, the generative stops generating further tokens. The returned text doesn't contain the stop sequence. | 10 -| spring.ai.bedrock.anthropic3.chat.options.anthropic-version | The version of the generative to use. | bedrock-2023-05-31 -| spring.ai.bedrock.anthropic3.chat.options.max-tokens | Specify the maximum number of tokens to use in the generated response. Note that the models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. We recommend a limit of 4,000 tokens for optimal performance. | 500 +| spring.ai.bedrock.anthropic3.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | AWS Bedrock default +| spring.ai.bedrock.anthropic3.chat.options.maxTokens | Specify the maximum number of tokens to use in the generated response. Note that the models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. We recommend a limit of 4,000 tokens for optimal performance. | AWS Bedrock default +| spring.ai.bedrock.anthropic3.chat.options.topK | Specify the number of token choices the generative uses to generate the next token. | AWS Bedrock default +| spring.ai.bedrock.anthropic3.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. 
| AWS Bedrock default +| spring.ai.bedrock.anthropic3.chat.options.stopSequences | Configure up to four sequences that the generative recognizes. After a stop sequence, the generative stops generating further tokens. The returned text doesn't contain the stop sequence. | AWS Bedrock default |==== -Look at the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java[AnthropicChatModel] for other model IDs. -Supported values are: `anthropic.claude-instant-v1`, `anthropic.claude-v2` and `anthropic.claude-v2:1`. +Look at the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java[Anthropic3ChatModel] for other model IDs. +Supported values are: `anthropic.claude-3-5-sonnet-20240620-v1:0`, `anthropic.claude-3-opus-20240229-v1:0`, `anthropic.claude-3-sonnet-20240229-v1:0` and `anthropic.claude-3-haiku-20240307-v1:0`. Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs]. TIP: All properties prefixed with `spring.ai.bedrock.anthropic3.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. 
@@ -236,20 +235,18 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp [source,java] ---- -Anthropic3ChatBedrockApi anthropicApi = new Anthropic3ChatBedrockApi( - AnthropicChatBedrockApi.AnthropicModel.CLAUDE_V3_SONNET.id(), +BedrockConverseApi converseApi = new BedrockConverseApi( EnvironmentVariableCredentialsProvider.create(), - Region.US_EAST_1.id(), - new ObjectMapper(), + Region.EU_CENTRAL_1.id(), Duration.ofMillis(1000L)); -BedrockAnthropic3ChatModel chatModel = new BedrockAnthropic3ChatModel(anthropicApi, +BedrockAnthropic3ChatModel chatModel = new BedrockAnthropic3ChatModel(converseApi, AnthropicChatOptions.builder() .withTemperature(0.6f) + .withMaxTokens(100) .withTopK(10) .withTopP(0.8f) - .withMaxTokensToSample(100) - .withAnthropicVersion(AnthropicChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) + .withStopSequences(List.of("stop sequences")) .build()); ChatResponse response = chatModel.call( @@ -259,33 +256,3 @@ ChatResponse response = chatModel.call( Flux response = chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- - -=== Low-level Anthropic3ChatBedrockApi Client [[low-level-api]] - -The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java[Anthropic3ChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock link:https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html[Anthropic Claude models]. - -Client supports the `anthropic.claude-3-opus-20240229-v1:0`,`anthropic.claude-3-sonnet-20240229-v1:0`,`anthropic.claude-3-haiku-20240307-v1:0` and the legacy `anthropic.claude-v2`, `anthropic.claude-v2:1` and `anthropic.claude-instant-v1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses. 
- -Here is a simple snippet how to use the api programmatically: - -[source,java] ----- -Anthropic3ChatBedrockApi anthropicChatApi = new Anthropic3ChatBedrockApi( - AnthropicModel.CLAUDE_V2.id(), Region.EU_CENTRAL_1.id(), Duration.ofMillis(1000L)); - -AnthropicChatRequest request = AnthropicChatRequest - .builder(String.format(Anthropic3ChatBedrockApi.PROMPT_TEMPLATE, "Name 3 famous pirates")) - .withTemperature(0.8f) - .withMaxTokensToSample(300) - .withTopK(10) - .build(); - -// Sync request -AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); - -// Streaming request -Flux responseStream = anthropicChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); ----- - -Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java[Anthropic3ChatBedrockApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc index c4345a46c4..e62751c113 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-cohere.adoc @@ -71,8 +71,8 @@ The prefix `spring.ai.bedrock.cohere.chat` is the property prefix that configure | Property | Description | Default | spring.ai.bedrock.cohere.chat.enabled | Enable or disable support for Cohere | false -| spring.ai.bedrock.cohere.chat.model | The model id to use. See the https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java#L326C14-L326C29[CohereChatModel] for the supported models. 
| cohere.command-text-v14 -| spring.ai.bedrock.cohere.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | 0.7 +| spring.ai.bedrock.cohere.chat.model | The model id to use. See the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java[CohereChatModel] for the supported models. | `cohere.command-text-v14` +| spring.ai.bedrock.cohere.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | AWS Bedrock default | spring.ai.bedrock.cohere.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. | AWS Bedrock default | spring.ai.bedrock.cohere.chat.options.topK | Specify the number of token choices the model uses to generate the next token | AWS Bedrock default | spring.ai.bedrock.cohere.chat.options.maxTokens | Specify the maximum number of tokens to use in the generated response. | AWS Bedrock default @@ -83,7 +83,7 @@ The prefix `spring.ai.bedrock.cohere.chat` is the property prefix that configure | spring.ai.bedrock.cohere.chat.options.truncate | Specifies how the API handles inputs longer than the maximum token length | AWS Bedrock default |==== -Look at the https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java#L326C14-L326C29[CohereChatModel] for other model IDs. +Look at the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java[CohereChatModel] for other model IDs. Supported values are: `cohere.command-light-text-v14` and `cohere.command-text-v14`. 
Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs]. @@ -187,13 +187,12 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp [source,java] ---- -CohereChatBedrockApi api = new CohereChatBedrockApi(CohereChatModel.COHERE_COMMAND_V14.id(), - EnvironmentVariableCredentialsProvider.create(), - Region.US_EAST_1.id(), - new ObjectMapper(), - Duration.ofMillis(1000L)); +BedrockConverseApi converseApi = new BedrockConverseApi( + EnvironmentVariableCredentialsProvider.create(), + Region.EU_CENTRAL_1.id(), + Duration.ofMillis(1000L)); -BedrockCohereChatModel chatModel = new BedrockCohereChatModel(api, +BedrockCohereChatModel chatModel = new BedrockCohereChatModel(converseApi, BedrockCohereChatOptions.builder() .withTemperature(0.6f) .withTopK(10) @@ -208,58 +207,3 @@ ChatResponse response = chatModel.call( Flux response = chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- - -== Low-level CohereChatBedrockApi Client [[low-level-api]] - -The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java[CohereChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html[Cohere Command models]. - -Following class diagram illustrates the CohereChatBedrockApi interface and building blocks: - -image::bedrock/bedrock-cohere-chat-low-level-api.jpg[align="center", width="800px"] - -The CohereChatBedrockApi supports the `cohere.command-light-text-v14` and `cohere.command-text-v14` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) requests. 
- -Here is a simple snippet how to use the api programmatically: - -[source,java] ----- -CohereChatBedrockApi cohereChatApi = new CohereChatBedrockApi( - CohereChatModel.COHERE_COMMAND_V14.id(), - Region.US_EAST_1.id(), - Duration.ofMillis(1000L)); - -var request = CohereChatRequest - .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?") - .withStream(false) - .withTemperature(0.5f) - .withTopP(0.8f) - .withTopK(15) - .withMaxTokens(100) - .withStopSequences(List.of("END")) - .withReturnLikelihoods(CohereChatRequest.ReturnLikelihoods.ALL) - .withNumGenerations(3) - .withLogitBias(null) - .withTruncate(Truncate.NONE) - .build(); - -CohereChatResponse response = cohereChatApi.chatCompletion(request); - -var request = CohereChatRequest - .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?") - .withStream(true) - .withTemperature(0.5f) - .withTopP(0.8f) - .withTopK(15) - .withMaxTokens(100) - .withStopSequences(List.of("END")) - .withReturnLikelihoods(CohereChatRequest.ReturnLikelihoods.ALL) - .withNumGenerations(3) - .withLogitBias(null) - .withTruncate(Truncate.NONE) - .build(); - -Flux responseStream = cohereChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); ----- - - diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-coherecommandr.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-coherecommandr.adoc new file mode 100644 index 0000000000..a14c8cd1b6 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-coherecommandr.adoc @@ -0,0 +1,213 @@ += Cohere Command R Chat + +Provides Bedrock Cohere Command R Chat model. +Integrate generative AI capabilities into essential apps and workflows that improve business outcomes. 
+ +The https://aws.amazon.com/bedrock/cohere-command-embed/[AWS Bedrock Cohere Model Page] and https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html[Amazon Bedrock User Guide] contain detailed information on how to use the AWS hosted model. + +== Prerequisites + +Refer to the xref:api/bedrock.adoc[Spring AI documentation on Amazon Bedrock] for setting up API access. + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + +== Auto-configuration + +Add the `spring-ai-bedrock-ai-spring-boot-starter` dependency to your project's Maven `pom.xml` file: + +[source,xml] +---- + + org.springframework.ai + spring-ai-bedrock-ai-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,gradle] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-bedrock-ai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Enable Cohere Command R Chat Support + +By default the Cohere Command R model is disabled. +To enable it, set the `spring.ai.bedrock.coherecommandr.chat.enabled` property to `true`. +Exporting an environment variable is one way to set this configuration property: + +[source,shell] +---- +export SPRING_AI_BEDROCK_COHERECOMMANDR_CHAT_ENABLED=true +---- + +=== Chat Properties + +The prefix `spring.ai.bedrock.aws` is the property prefix to configure the connection to AWS Bedrock. 
+ +[cols="3,3,3"] +|==== +| Property | Description | Default + +| spring.ai.bedrock.aws.region | AWS region to use. | us-east-1 +| spring.ai.bedrock.aws.timeout | AWS timeout to use. | 5m +| spring.ai.bedrock.aws.access-key | AWS access key. | - +| spring.ai.bedrock.aws.secret-key | AWS secret key. | - +|==== + +The prefix `spring.ai.bedrock.coherecommandr.chat` is the property prefix that configures the chat model implementation for Cohere Command R. + +[cols="2,5,1"] +|==== +| Property | Description | Default + +| spring.ai.bedrock.coherecommandr.chat.enabled | Enable or disable support for Cohere Command R | false +| spring.ai.bedrock.coherecommandr.chat.model | The model id to use. See the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java[CohereCommandRChatModel] for the supported models. | cohere.command-r-plus-v1:0 +| spring.ai.bedrock.coherecommandr.chat.options.searchQueriesOnly | When enabled, it will only generate potential search queries without performing searches or providing a response. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.preamble | Overrides the default preamble for search query generation. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.maxToken | Specify the maximum number of tokens to use in the generated response. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. 
| AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.topK | Specify the number of token choices the model uses to generate the next token | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.promptTruncation | Dictates how the prompt is constructed. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.frequencyPenalty | Used to reduce repetitiveness of generated tokens. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.presencePenalty | Used to reduce repetitiveness of generated tokens. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.seed | Specify the best effort to sample tokens deterministically. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.returnPrompt | Specify true to return the full prompt that was sent to the model. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.stopSequences | Configure up to four sequences that the model recognizes. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.rawPrompting | Specify true, to send the user’s message to the model without any preprocessing. | AWS Bedrock default +|==== + +Look at the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java[CohereCommandRChatModel] for other model IDs. +Supported values are: `cohere.command-r-plus-v1:0` and `cohere.command-r-v1:0`. +Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html[AWS Bedrock documentation for base model IDs]. + +TIP: All properties prefixed with `spring.ai.bedrock.coherecommandr.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. 
+ +== Runtime Options [[chat-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java[BedrockCohereCommandRChatOptions.java] provides model configurations, such as temperature, topK, topP, etc. + +On start-up, the default options can be configured with the `BedrockCohereCommandRChatModel(api, options)` constructor or the `spring.ai.bedrock.coherecommandr.chat.options.*` properties. + +At run-time you can override the default options by adding new, request specific, options to the `Prompt` call. +For example to override the default temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "Generate the names of 5 famous pirates.", + BedrockCohereCommandRChatOptions.builder() + .withTemperature(0.4) + .build() + )); +---- + +TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java[BedrockCohereCommandRChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. + +== Sample Controller + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-bedrock-ai-spring-boot-starter` to your pom (or gradle) dependencies. 
+Add an `application.properties` file, under the `src/main/resources` directory, to enable and configure the Cohere Command R chat model:
+
+[source]
+----
+spring.ai.bedrock.aws.region=eu-central-1
+spring.ai.bedrock.aws.timeout=1000ms
+spring.ai.bedrock.aws.access-key=${AWS_ACCESS_KEY_ID}
+spring.ai.bedrock.aws.secret-key=${AWS_SECRET_ACCESS_KEY}
+
+spring.ai.bedrock.coherecommandr.chat.enabled=true
+spring.ai.bedrock.coherecommandr.chat.options.temperature=0.8
+----
+
+TIP: Replace the `region`, `access-key` and `secret-key` with your AWS credentials.
+ +Add the `spring-ai-bedrock` dependency to your project's Maven `pom.xml` file: + +[source,xml] +---- + + org.springframework.ai + spring-ai-bedrock + +---- + +or to your Gradle `build.gradle` build file. + +[source,gradle] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-bedrock' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java[BedrockCohereCommandRChatModel] and use it for text generations: + +[source,java] +---- +BedrockConverseApi converseApi = new BedrockConverseApi( + EnvironmentVariableCredentialsProvider.create(), + Region.EU_CENTRAL_1.id(), + Duration.ofMillis(1000L)); + +BedrockCohereCommandRChatModel chatModel = new BedrockCohereCommandRChatModel(converseApi, + BedrockCohereCommandRChatOptions.builder() + .withTemperature(0.6f) + .withTopK(10) + .withTopP(0.5f) + .withMaxTokens(678) + .build() + +ChatResponse response = chatModel.call( + new Prompt("Generate the names of 5 famous pirates.")); + +// Or with streaming responses +Flux response = chatModel.stream( + new Prompt("Generate the names of 5 famous pirates.")); +---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc index d1d8956ce5..ca0c112e40 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-jurassic2.adoc @@ -71,13 +71,20 @@ The prefix `spring.ai.bedrock.jurassic2.chat` is the property prefix that config | Property | Description | Default | spring.ai.bedrock.jurassic2.chat.enabled | Enable or disable support 
for Jurassic-2 | false -| spring.ai.bedrock.jurassic2.chat.model | The model id to use (See Below) | ai21.j2-mid-v1 -| spring.ai.bedrock.jurassic2.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the model. This value specifies default to be used by the backend while making the call to the model. | 0.7 -| spring.ai.bedrock.jurassic2.chat.options.top-p | The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest set of tokens whose probability sum is at least topP. | AWS Bedrock default -| spring.ai.bedrock.jurassic2.chat.options.max-tokens | Specify the maximum number of tokens to use in the generated response. The model truncates the response once the generated text exceeds maxTokens. 
| 500 +| spring.ai.bedrock.jurassic2.chat.model | The model id to use | `ai21.j2-mid-v1` +| spring.ai.bedrock.jurassic2.chat.options.numResults | Number of completions to sample and return | AWS Bedrock default +| spring.ai.bedrock.jurassic2.chat.options.maxTokens | The maximum number of tokens to generate per result | AWS Bedrock default +| spring.ai.bedrock.jurassic2.chat.options.minTokens | The minimum number of tokens to generate per result | AWS Bedrock default +| spring.ai.bedrock.jurassic2.chat.options.temperature | Modifies the distribution from which tokens are sampled | AWS Bedrock default +| spring.ai.bedrock.jurassic2.chat.options.topP | Sample tokens from the corresponding top percentile of probability mass | AWS Bedrock default +| spring.ai.bedrock.jurassic2.chat.options.topKReturn | Return the top-K (topKReturn) alternative tokens | AWS Bedrock default +| spring.ai.bedrock.jurassic2.chat.options.stopSequences | Stops decoding if any of the strings is generated | AWS Bedrock default +| spring.ai.bedrock.jurassic2.chat.options.frequencyPenalty | Penalty object for frequency | AWS Bedrock default +| spring.ai.bedrock.jurassic2.chat.options.presencePenalty | Penalty object for presence | AWS Bedrock default +| spring.ai.bedrock.jurassic2.chat.options.countPenalty | Penalty object for count | AWS Bedrock default |==== -Look at https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java#L164[Ai21Jurassic2ChatBedrockApi#Ai21Jurassic2ChatModel] for other model IDs. The other value supported is `ai21.j2-ultra-v1`. +Look at https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModel.java[Ai21Jurassic2ChatModel] for other model IDs. 
The other value supported is `ai21.j2-ultra-v1`. Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs]. TIP: All properties prefixed with `spring.ai.bedrock.jurassic2.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. @@ -175,48 +182,25 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp [source,java] ---- -Ai21Jurassic2ChatBedrockApi api = new Ai21Jurassic2ChatBedrockApi(Ai21Jurassic2ChatModel.AI21_J2_MID_V1.id(), +BedrockConverseApi converseApi = new BedrockConverseApi( EnvironmentVariableCredentialsProvider.create(), - Region.US_EAST_1.id(), - new ObjectMapper(), + Region.EU_CENTRAL_1.id(), Duration.ofMillis(1000L)); -BedrockAi21Jurassic2ChatModel chatModel = new BedrockAi21Jurassic2ChatModel(api, +BedrockAi21Jurassic2ChatModel chatModel = new BedrockAi21Jurassic2ChatModel(converseApi, BedrockAi21Jurassic2ChatOptions.builder() - .withTemperature(0.5f) + .withNumResults(1) .withMaxTokens(100) - .withTopP(0.9f).build()); + .withMinTokens(1) + .withTemperature(0.5F) + .withTopP(0.5F) + .withTopK(20) + .withStopSequences(List.of("stop sequences")) + .withFrequencyPenalty(Penalty.builder().scale(1F).build()) + .withPresencePenalty(Penalty.builder().scale(1F).build()) + .withCountPenalty(Penalty.builder().scale(1F).build()) + .build(); ChatResponse response = chatModel.call( new Prompt("Generate the names of 5 famous pirates.")); - ---- - -== Low-level Client [[low-level-api]] - -https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java[Ai21Jurassic2ChatBedrockApi] provides a lightweight Java client on top of AWS Bedrock https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html[Jurassic-2 and Jurassic-2 Chat models]. 
- -The `Ai21Jurassic2ChatBedrockApi` supports the `ai21.j2-mid-v1` and `ai21.j2-ultra-v1` models and only support synchronous ( `chatCompletion()`). - -Here is a simple snippet on how to use the API programmatically: - - -[source,java] ----- -Ai21Jurassic2ChatBedrockApi jurassic2ChatApi = new Ai21Jurassic2ChatBedrockApi( - Ai21Jurassic2ChatModel.AI21_J2_MID_V1.id(), - Region.US_EAST_1.id(), - Duration.ofMillis(1000L)); - -Ai21Jurassic2ChatRequest request = Ai21Jurassic2ChatRequest.builder("Hello, my name is") - .withTemperature(0.9f) - .withTopP(0.9f) - .withMaxTokens(20) - .build(); - -Ai21Jurassic2ChatResponse response = jurassic2ChatApi.chatCompletion(request); - - ----- - -Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java[Ai21Jurassic2ChatBedrockApi.java]'s JavaDoc for further information. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc index d8a0c63476..257dabd3e5 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-llama.adoc @@ -76,13 +76,13 @@ The prefix `spring.ai.bedrock.llama.chat` is the property prefix that configures | Property | Description | Default | spring.ai.bedrock.llama.chat.enabled | Enable or disable support for Llama | false -| spring.ai.bedrock.llama.chat.model | The model id to use (See Below) | meta.llama3-70b-instruct-v1:0 -| spring.ai.bedrock.llama.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the model. 
This value specifies default to be used by the backend while making the call to the model. | 0.7 +| spring.ai.bedrock.llama.chat.model | The model id to use (See Below) | `meta.llama3-70b-instruct-v1:0` +| spring.ai.bedrock.llama.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the model. This value specifies default to be used by the backend while making the call to the model. | AWS Bedrock default | spring.ai.bedrock.llama.chat.options.top-p | The maximum cumulative probability of tokens to consider when sampling. The model uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest set of tokens whose probability sum is at least topP. | AWS Bedrock default -| spring.ai.bedrock.llama.chat.options.max-gen-len | Specify the maximum number of tokens to use in the generated response. The model truncates the response once the generated text exceeds maxGenLen. | 300 +| spring.ai.bedrock.llama.chat.options.max-gen-len | Specify the maximum number of tokens to use in the generated response. The model truncates the response once the generated text exceeds maxGenLen. | AWS Bedrock default |==== -Look at https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java#L164[LlamaChatBedrockApi#LlamaChatModel] for other model IDs. The other value supported is `meta.llama2-13b-chat-v1`. +Look at https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModel.java[LlamaChatModel] for other model IDs. The other value supported is `meta.llama2-13b-chat-v1`. 
Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs]. TIP: All properties prefixed with `spring.ai.bedrock.llama.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. @@ -185,13 +185,12 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp [source,java] ---- -LlamaChatBedrockApi api = new LlamaChatBedrockApi(LlamaChatModel.LLAMA2_70B_CHAT_V1.id(), - EnvironmentVariableCredentialsProvider.create(), - Region.US_EAST_1.id(), - new ObjectMapper(), - Duration.ofMillis(1000L)); +BedrockConverseApi converseApi = new BedrockConverseApi( + EnvironmentVariableCredentialsProvider.create(), + Region.EU_CENTRAL_1.id(), + Duration.ofMillis(1000L)); -BedrockLlamaChatModel chatModel = new BedrockLlamaChatModel(api, +BedrockLlamaChatModel chatModel = new BedrockLlamaChatModel(converseApi, BedrockLlamaChatOptions.builder() .withTemperature(0.5f) .withMaxGenLen(100) @@ -204,39 +203,3 @@ ChatResponse response = chatModel.call( Flux response = chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- - -== Low-level LlamaChatBedrockApi Client [[low-level-api]] - -https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java[LlamaChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html[Meta Llama 2 and Llama 2 Chat models]. - -Following class diagram illustrates the LlamaChatBedrockApi interface and building blocks: - -image::bedrock/bedrock-llama-chat-api.jpg[LlamaChatBedrockApi Class Diagram] - -The LlamaChatBedrockApi supports the `meta.llama3-8b-instruct-v1:0`,`meta.llama3-70b-instruct-v1:0`,`meta.llama2-13b-chat-v1` and `meta.llama2-70b-chat-v1` models for both synchronous (e.g. 
`chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses. - -Here is a simple snippet how to use the api programmatically: - -[source,java] ----- -LlamaChatBedrockApi llamaChatApi = new LlamaChatBedrockApi( - LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), - Region.US_EAST_1.id(), - Duration.ofMillis(1000L)); - -LlamaChatRequest request = LlamaChatRequest.builder("Hello, my name is") - .withTemperature(0.9f) - .withTopP(0.9f) - .withMaxGenLen(20) - .build(); - -LlamaChatResponse response = llamaChatApi.chatCompletion(request); - -// Streaming response -Flux responseStream = llamaChatApi.chatCompletionStream(request); -List responses = responseStream.collectList().block(); ----- - -Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java[LlamaChatBedrockApi.java]'s JavaDoc for further information. - - diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-mistral.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-mistral.adoc new file mode 100644 index 0000000000..95fcc60539 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-mistral.adoc @@ -0,0 +1,206 @@ += Mistral Chat + +Provides Bedrock Mistral chat model. +Integrate generative AI capabilities into essential apps and workflows that improve business outcomes. + +The https://aws.amazon.com/bedrock/mistral/[AWS Bedrock Mistral Model Page] and https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html[Amazon Bedrock User Guide] contains detailed information on how to use the AWS hosted model. + +== Prerequisites + +Refer to the xref:api/bedrock.adoc[Spring AI documentation on Amazon Bedrock] for setting up API access. + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. 
Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + +== Auto-configuration + +Add the `spring-ai-bedrock-ai-spring-boot-starter` dependency to your project's Maven `pom.xml` file: + +[source,xml] +---- + + org.springframework.ai + spring-ai-bedrock-ai-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,gradle] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-bedrock-ai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Enable Mistral Chat Support + +By default the Mistral model is disabled. +To enable it set the `spring.ai.bedrock.mistral.chat.enabled` property to `true`. +Exporting environment variable is one way to set this configuration property: + +[source,shell] +---- +export SPRING_AI_BEDROCK_MISTRAL_CHAT_ENABLED=true +---- + +=== Chat Properties + +The prefix `spring.ai.bedrock.aws` is the property prefix to configure the connection to AWS Bedrock. + +[cols="3,3,3"] +|==== +| Property | Description | Default + +| spring.ai.bedrock.aws.region | AWS region to use. | us-east-1 +| spring.ai.bedrock.aws.timeout | AWS timeout to use. | 5m +| spring.ai.bedrock.aws.access-key | AWS access key. | - +| spring.ai.bedrock.aws.secret-key | AWS secret key. | - +|==== + +The prefix `spring.ai.bedrock.mistral.chat` is the property prefix that configures the chat model implementation for Mistral. 
+ +[cols="2,5,1"] +|==== +| Property | Description | Default + +| spring.ai.bedrock.mistral.chat.enabled | Enable or disable support for Mistral | false +| spring.ai.bedrock.mistral.chat.model | The model id to use. See the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java[MistralChatModel] for the supported models. | mistral.mistral-large-2402-v1:0 +| spring.ai.bedrock.mistral.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | AWS Bedrock default +| spring.ai.bedrock.mistral.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. | AWS Bedrock default +| spring.ai.bedrock.mistral.chat.options.topK | Specify the number of token choices the model uses to generate the next token | AWS Bedrock default +| spring.ai.bedrock.mistral.chat.options.maxTokens | Specify the maximum number of tokens to use in the generated response. | AWS Bedrock default +| spring.ai.bedrock.mistral.chat.options.stopSequences | Configure up to four sequences that the model recognizes. | AWS Bedrock default +| spring.ai.bedrock.mistral.chat.options.stopSequences | Specifies how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. | AWS Bedrock default +|==== + +Look at the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java[MistralChatModel] for other model IDs. +Supported values are: `mistral.mistral-7b-instruct-v0:2`, `mistral.mixtral-8x7b-instruct-v0:1`, `mistral.mistral-large-2402-v1:0` and `mistral.mistral-small-2402-v1:0`. 
+Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs]. + +TIP: All properties prefixed with `spring.ai.bedrock.mistral.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. + +== Runtime Options [[chat-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java[BedrockMistralChatOptions.java] provides model configurations, such as temperature, topK, topP, etc. + +On start-up, the default options can be configured with the `BedrockMistralChatModel(api, options)` constructor or the `spring.ai.bedrock.mistral.chat.options.*` properties. + +At run-time you can override the default options by adding new, request specific, options to the `Prompt` call. +For example to override the default temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "Generate the names of 5 famous pirates.", + BedrockMistralChatOptions.builder() + .withTemperature(0.4) + .build() + )); +---- + +TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java[BedrockMistralChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. + +== Sample Controller + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-bedrock-ai-spring-boot-starter` to your pom (or gradle) dependencies. 
+Add an `application.properties` file, under the `src/main/resources` directory, to enable and configure the Mistral chat model:
+
+[source]
+----
+spring.ai.bedrock.aws.region=eu-central-1
+spring.ai.bedrock.aws.timeout=1000ms
+spring.ai.bedrock.aws.access-key=${AWS_ACCESS_KEY_ID}
+spring.ai.bedrock.aws.secret-key=${AWS_SECRET_ACCESS_KEY}
+
+spring.ai.bedrock.mistral.chat.enabled=true
+spring.ai.bedrock.mistral.chat.options.temperature=0.8
+----
+
+TIP: Replace the `region`, `access-key` and `secret-key` with your AWS credentials.
+ +[source,gradle] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-bedrock' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java[BedrockMistralChatModel] and use it for text generations: + +[source,java] +---- +BedrockConverseApi converseApi = new BedrockConverseApi( + EnvironmentVariableCredentialsProvider.create(), + Region.EU_CENTRAL_1.id(), + Duration.ofMillis(1000L)); + +BedrockMistralChatModel chatModel = new BedrockMistralChatModel(converseApi, + BedrockMistralChatOptions.builder() + .withTemperature(0.6f) + .withTopK(10) + .withTopP(0.5f) + .withMaxTokens(678) + .build() + +ChatResponse response = chatModel.call( + new Prompt("Generate the names of 5 famous pirates.")); + +// Or with streaming responses +Flux response = chatModel.stream( + new Prompt("Generate the names of 5 famous pirates.")); +---- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc index 45f836b4c4..3afa94afa4 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-titan.adoc @@ -70,16 +70,15 @@ The prefix `spring.ai.bedrock.titan.chat` is the property prefix that configures [cols="3,4,1"] |==== | Property | Description | Default - | spring.ai.bedrock.titan.chat.enabled | Enable Bedrock Titan chat model. Disabled by default | false -| spring.ai.bedrock.titan.chat.model | The model id to use. 
See the link:https://github.com/spring-projects/spring-ai/blob/4839a6175cd1ec89498b97d3efb6647022c3c7cb/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java#L220[TitanChatBedrockApi#TitanChatModel] for the supported models. | amazon.titan-text-lite-v1 -| spring.ai.bedrock.titan.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | 0.7 +| spring.ai.bedrock.titan.chat.model | The model id to use. See the link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java[TitanChatModel] for the supported models. | `amazon.titan-text-express-v1` +| spring.ai.bedrock.titan.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | AWS Bedrock default | spring.ai.bedrock.titan.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. | AWS Bedrock default | spring.ai.bedrock.titan.chat.options.stopSequences | Configure up to four sequences that the generative recognizes. After a stop sequence, the generative stops generating further tokens. The returned text doesn't contain the stop sequence. | AWS Bedrock default | spring.ai.bedrock.titan.chat.options.maxTokenCount | Specify the maximum number of tokens to use in the generated response. Note that the models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. We recommend a limit of 4,000 tokens for optimal performance. | AWS Bedrock default |==== -Look at the https://github.com/spring-projects/spring-ai/blob/4839a6175cd1ec89498b97d3efb6647022c3c7cb/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java#L220[TitanChatBedrockApi#TitanChatModel] for other model IDs. 
+Look at the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java[TitanChatModel] for other model IDs. Supported values are: `amazon.titan-text-lite-v1`, `amazon.titan-text-express-v1` and `amazon.titan-text-premier-v1:0`. Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs]. @@ -183,14 +182,12 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp [source,java] ---- -TitanChatBedrockApi titanApi = new TitanChatBedrockApi( - TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), - EnvironmentVariableCredentialsProvider.create(), - Region.US_EAST_1.id(), - new ObjectMapper(), +BedrockConverseApi converseApi = new BedrockConverseApi( + EnvironmentVariableCredentialsProvider.create(), + Region.EU_CENTRAL_1.id(), Duration.ofMillis(1000L)); -BedrockTitanChatModel chatModel = new BedrockTitanChatModel(titanApi, +BedrockTitanChatModel chatModel = new BedrockTitanChatModel(converseApi, BedrockTitanChatOptions.builder() .withTemperature(0.6f) .withTopP(0.8f) @@ -204,36 +201,3 @@ ChatResponse response = chatModel.call( Flux response = chatModel.stream( new Prompt("Generate the names of 5 famous pirates.")); ---- - -== Low-level TitanChatBedrockApi Client [[low-level-api]] - -The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java[TitanChatBedrockApi] provides is lightweight Java client on top of AWS Bedrock link:https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html[Bedrock Titan models]. 
- -Following class diagram illustrates the TitanChatBedrockApi interface and building blocks: - -image::bedrock/bedrock-titan-chat-low-level-api.jpg[width=800,align="center"] - -Client supports the `amazon.titan-text-lite-v1` and `amazon.titan-text-express-v1` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses. - -Here is a simple snippet how to use the api programmatically: - -[source,java] ----- -TitanChatBedrockApi titanBedrockApi = new TitanChatBedrockApi(TitanChatCompletionModel.TITAN_TEXT_EXPRESS_V1.id(), - Region.EU_CENTRAL_1.id(), Duration.ofMillis(1000L)); - -TitanChatRequest titanChatRequest = TitanChatRequest.builder("Give me the names of 3 famous pirates?") - .withTemperature(0.5f) - .withTopP(0.9f) - .withMaxTokenCount(100) - .withStopSequences(List.of("|")) - .build(); - -TitanChatResponse response = titanBedrockApi.chatCompletion(titanChatRequest); - -Flux response = titanBedrockApi.chatCompletionStream(titanChatRequest); - -List results = response.collectList().block(); ----- - -Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java[TitanChatBedrockApi]'s JavaDoc for further information. 
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc index 8906aee914..5fe288aea6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc @@ -6,7 +6,7 @@ The `claude-3-opus`, `claude-3-sonnet` and `claude-3-haiku` link:https://docs.an The Anthropic API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. -NOTE: As of April 4th, 2024, streaming is not yet supported for function calling and Tool use is not yet available on third-party platforms like Vertex AI or AWS Bedrock, but is coming soon. +NOTE: As of April 4th, 2024, streaming is not yet supported for function calling and Tool use is not yet available on third-party platforms like Vertex AI, but is coming soon. Spring AI provides flexible and user-friendly ways to register and call custom functions. In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. 
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc new file mode 100644 index 0000000000..ca6d7201ac --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc @@ -0,0 +1,187 @@ += Bedrock Anthropic 3 Function Calling + +You can register custom Java functions with the `BedrockAnthropic3ChatModel` and have the Bedrock Anthropic 3 models intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This allows you to connect the LLM capabilities with external tools and APIs. + +The Bedrock Anthropic 3 API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. + +Spring AI provides flexible and user-friendly ways to register and call custom functions. +In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. +The `description` helps the model to understand when to call the function. + +As a developer, you need to implement a function that takes the function call arguments sent from the AI model, and respond with the result back to the model. +Your function can in turn invoke other 3rd party services to provide the results. + +Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatModel`. + +Under the hood, Spring wraps your POJO (the function) with the appropriate adapter code that enables interaction with the AI Model, saving you from writing tedious boilerplate code. 
+The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java[FunctionCallback.java] interface and the companion link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java[FunctionCallbackWrapper.java] utility class to simplify the implementation and registration of Java callback functions. + +== How it works + +Suppose we want the AI model to respond with information that it does not have, for example the current temperature at a given location. + +We can provide the AI model with metadata about our own functions that it can use to retrieve that information as it processes your prompt. + +For example, if during the processing of a prompt, the AI Model determines that it needs additional information about the temperature in a given location, it will start a server side generated request/response interaction. The AI Model invokes a client side function. +The AI Model provides method invocation details as JSON and it is the responsibility of the client to execute that function and return the response. + +Spring AI greatly simplifies the code you need to write to support function invocation. +It brokers the function invocation conversation for you. +You can simply provide your function definition as a `@Bean` and then provide the bean name of the function in your prompt options. +You can also reference multiple function bean names in your prompt. + +== Quick Start + +Let's create a chatbot that answer questions by calling our own function. +To support the response of the chatbot, we will register our own function that takes a location and returns the current weather in that location. 
+ +When the response to the prompt to the model needs to answer a question such as `"What’s the weather like in Boston?"` the AI model will invoke the client providing the location value as an argument to be passed to the function. This RPC-like data is passed as JSON. + +Our function can some SaaS based weather service API and returns the weather response back to the model to complete the conversation. +In this example we will use a simple implementation named `MockWeatherService` that hard codes the temperature for various locations. + +The following `MockWeatherService.java` represents the weather service API: + +[source,java] +---- +public class MockWeatherService implements Function { + + public enum Unit { C, F } + public record Request(String location, Unit unit) {} + public record Response(double temp, Unit unit) {} + + public Response apply(Request request) { + return new Response(30.0, Unit.C); + } +} +---- + +=== Registering Functions as Beans + +With the link:../bedrock/bedrock-anthropic3.html#_auto_configuration[BedrockAnthropic3ChatModel Auto-Configuration] you have multiple ways to register custom functions as beans in the Spring context. + +We start with describing the most POJO friendly options. + +==== Plain Java Functions + +In this approach you define `@Beans` in your application context as you would any other Spring managed object. + +Internally, Spring AI `ChatModel` will create an instance of a `FunctionCallbackWrapper` wrapper that adds the logic for it being invoked via the AI model. +The name of the `@Bean` is passed as a `ChatOption`. + + +[source,java] +---- +@Configuration +static class Config { + + @Bean + @Description("Get the weather in location") // function description + public Function weatherFunction1() { + return new MockWeatherService(); + } + ... +} +---- + +The `@Description` annotation is optional and provides a function description (2) that helps the model understand when to call the function. 
+It is an important property to set to help the AI model determine what client side function to invoke. + +Another option to provide the description of the function is to use the `@JsonClassDescription` annotation on the `MockWeatherService.Request` to provide the function description: + +[source,java] +---- + +@Configuration +static class Config { + + @Bean + public Function currentWeather3() { // (1) bean name as function name. + return new MockWeatherService(); + } + ... +} + +@JsonClassDescription("Get the weather in location") // (2) function description +public record Request(String location, Unit unit) {} +---- + +It is a best practice to annotate the request object with information such that the generated JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke. + +The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java.java[FunctionCallWithFunctionBeanIT.java] demonstrates this approach. + + +==== FunctionCallback Wrapper + +Another way to register a function is to create a `FunctionCallbackWrapper` wrapper like this: + +[source,java] +---- +@Configuration +static class Config { + + @Bean + public FunctionCallback weatherFunctionInfo() { + + return new FunctionCallbackWrapper<>("CurrentWeather", // (1) function name + "Get the weather in location", // (2) function description + (response) -> "" + response.temp() + response.unit(), // (3) Response Converter + new MockWeatherService()); // function code + } + ... +} +---- + +It wraps the 3rd party `MockWeatherService` function and registers it as a `CurrentWeather` function with the `BedrockAnthropic3ChatModel`. +It also provides a description (2) and an optional response converter (3) to convert the response into a text as expected by the model. 
+ +NOTE: By default, the response converter does a JSON serialization of the Response object. + +NOTE: The `FunctionCallbackWrapper` internally resolves the function call signature based on the `MockWeatherService.Request` class. + +=== Specifying functions in Chat Options + +To let the model know and call your `CurrentWeather` function you need to enable it in your prompt requests: + +[source,java] +---- +BedrockAnthropic3ChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + Anthropic3ChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function + +logger.info("Response: {}", response); +---- + +// NOTE: You can can have multiple functions registered in your `ChatModel` but only those enabled in the prompt request will be considered for the function calling. + +Above user question will trigger 3 calls to `CurrentWeather` function (one for each city) and produce the final response. + +=== Register/Call Functions with Prompt Options + +In addition to the auto-configuration you can register callback functions, dynamically, with your Prompt requests: + +[source,java] +---- +BedrockAnthropic3ChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); + +var promptOptions = Anthropic3ChatOptions.builder() + .withFunctionCallbacks(List.of(new FunctionCallbackWrapper<>( + "CurrentWeather", // name + "Get the weather in location", // function description + new MockWeatherService()))) // function code + .build(); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +---- + +NOTE: The in-prompt registered functions are enabled by default for the duration of this request. + +This approach allows to dynamically chose different functions to be called based on the user input. 
+ +The https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java[FunctionCallWithPromptFunctionIT.java] integration test provides a complete example of how to register a function with the `BedrockAnthropic3ChatModel` and use it in a prompt request. \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-coherecommandr-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-coherecommandr-chat-functions.adoc new file mode 100644 index 0000000000..abf1410794 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-coherecommandr-chat-functions.adoc @@ -0,0 +1,187 @@ += Bedrock Cohere Command R Function Calling + +You can register custom Java functions with the `BedrockCohereCommandRChatModel` and have the Bedrock Cohere Command R models intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This allows you to connect the LLM capabilities with external tools and APIs. + +The Bedrock Cohere Command R API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. + +Spring AI provides flexible and user-friendly ways to register and call custom functions. +In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. +The `description` helps the model to understand when to call the function. + +As a developer, you need to implement a function that takes the function call arguments sent from the AI model, and respond with the result back to the model. 
+Your function can in turn invoke other 3rd party services to provide the results. + +Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatModel`. + +Under the hood, Spring wraps your POJO (the function) with the appropriate adapter code that enables interaction with the AI Model, saving you from writing tedious boilerplate code. +The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java[FunctionCallback.java] interface and the companion link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java[FunctionCallbackWrapper.java] utility class to simplify the implementation and registration of Java callback functions. + +== How it works + +Suppose we want the AI model to respond with information that it does not have, for example the current temperature at a given location. + +We can provide the AI model with metadata about our own functions that it can use to retrieve that information as it processes your prompt. + +For example, if during the processing of a prompt, the AI Model determines that it needs additional information about the temperature in a given location, it will start a server side generated request/response interaction. The AI Model invokes a client side function. +The AI Model provides method invocation details as JSON and it is the responsibility of the client to execute that function and return the response. + +Spring AI greatly simplifies the code you need to write to support function invocation. +It brokers the function invocation conversation for you. +You can simply provide your function definition as a `@Bean` and then provide the bean name of the function in your prompt options. 
+You can also reference multiple function bean names in your prompt. + +== Quick Start + +Let's create a chatbot that answer questions by calling our own function. +To support the response of the chatbot, we will register our own function that takes a location and returns the current weather in that location. + +When the response to the prompt to the model needs to answer a question such as `"What’s the weather like in Boston?"` the AI model will invoke the client providing the location value as an argument to be passed to the function. This RPC-like data is passed as JSON. + +Our function can some SaaS based weather service API and returns the weather response back to the model to complete the conversation. +In this example we will use a simple implementation named `MockWeatherService` that hard codes the temperature for various locations. + +The following `MockWeatherService.java` represents the weather service API: + +[source,java] +---- +public class MockWeatherService implements Function { + + public enum Unit { C, F } + public record Request(String location, Unit unit) {} + public record Response(double temp, Unit unit) {} + + public Response apply(Request request) { + return new Response(30.0, Unit.C); + } +} +---- + +=== Registering Functions as Beans + +With the link:../bedrock/bedrock-coherecommandr.html#_auto_configuration[BedrockCohereCommandRChatModel Auto-Configuration] you have multiple ways to register custom functions as beans in the Spring context. + +We start with describing the most POJO friendly options. + +==== Plain Java Functions + +In this approach you define `@Beans` in your application context as you would any other Spring managed object. + +Internally, Spring AI `ChatModel` will create an instance of a `FunctionCallbackWrapper` wrapper that adds the logic for it being invoked via the AI model. +The name of the `@Bean` is passed as a `ChatOption`. 
+ + +[source,java] +---- +@Configuration +static class Config { + + @Bean + @Description("Get the weather in location") // function description + public Function weatherFunction1() { + return new MockWeatherService(); + } + ... +} +---- + +The `@Description` annotation is optional and provides a function description (2) that helps the model understand when to call the function. +It is an important property to set to help the AI model determine what client side function to invoke. + +Another option to provide the description of the function is to use the `@JsonClassDescription` annotation on the `MockWeatherService.Request` to provide the function description: + +[source,java] +---- + +@Configuration +static class Config { + + @Bean + public Function currentWeather3() { // (1) bean name as function name. + return new MockWeatherService(); + } + ... +} + +@JsonClassDescription("Get the weather in location") // (2) function description +public record Request(String location, Unit unit) {} +---- + +It is a best practice to annotate the request object with information such that the generated JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke. + +The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/coherecommandr/tool/FunctionCallWithFunctionBeanIT.java.java[FunctionCallWithFunctionBeanIT.java] demonstrates this approach. 
+ + +==== FunctionCallback Wrapper + +Another way to register a function is to create a `FunctionCallbackWrapper` wrapper like this: + +[source,java] +---- +@Configuration +static class Config { + + @Bean + public FunctionCallback weatherFunctionInfo() { + + return new FunctionCallbackWrapper<>("CurrentWeather", // (1) function name + "Get the weather in location", // (2) function description + (response) -> "" + response.temp() + response.unit(), // (3) Response Converter + new MockWeatherService()); // function code + } + ... +} +---- + +It wraps the 3rd party `MockWeatherService` function and registers it as a `CurrentWeather` function with the `BedrockCohereCommandRChatModel`. +It also provides a description (2) and an optional response converter (3) to convert the response into a text as expected by the model. + +NOTE: By default, the response converter does a JSON serialization of the Response object. + +NOTE: The `FunctionCallbackWrapper` internally resolves the function call signature based on the `MockWeatherService.Request` class. + +=== Specifying functions in Chat Options + +To let the model know and call your `CurrentWeather` function you need to enable it in your prompt requests: + +[source,java] +---- +BedrockCohereCommandRChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + BedrockCohereCommandRChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function + +logger.info("Response: {}", response); +---- + +// NOTE: You can can have multiple functions registered in your `ChatModel` but only those enabled in the prompt request will be considered for the function calling. + +Above user question will trigger 3 calls to `CurrentWeather` function (one for each city) and produce the final response. 
+ +=== Register/Call Functions with Prompt Options + +In addition to the auto-configuration you can register callback functions, dynamically, with your Prompt requests: + +[source,java] +---- +BedrockCohereCommandRChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); + +var promptOptions = BedrockCohereCommandRChatOptions.builder() + .withFunctionCallbacks(List.of(new FunctionCallbackWrapper<>( + "CurrentWeather", // name + "Get the weather in location", // function description + new MockWeatherService()))) // function code + .build(); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +---- + +NOTE: The in-prompt registered functions are enabled by default for the duration of this request. + +This approach allows to dynamically chose different functions to be called based on the user input. + +The https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/FunctionCallWithPromptFunctionIT.java[FunctionCallWithPromptFunctionIT.java] integration test provides a complete example of how to register a function with the `BedrockAnthropic3ChatModel` and use it in a prompt request. 
\ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-mistral-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-mistral-chat-functions.adoc new file mode 100644 index 0000000000..3e76c632c7 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-mistral-chat-functions.adoc @@ -0,0 +1,187 @@ += Bedrock Mistral Function Calling + +You can register custom Java functions with the `BedrockMistralChatModel` and have the Bedrock Mistral models intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This allows you to connect the LLM capabilities with external tools and APIs. + +The Bedrock Mistral API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. + +Spring AI provides flexible and user-friendly ways to register and call custom functions. +In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. +The `description` helps the model to understand when to call the function. + +As a developer, you need to implement a function that takes the function call arguments sent from the AI model, and respond with the result back to the model. +Your function can in turn invoke other 3rd party services to provide the results. + +Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatModel`. + +Under the hood, Spring wraps your POJO (the function) with the appropriate adapter code that enables interaction with the AI Model, saving you from writing tedious boilerplate code. 
+The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java[FunctionCallback.java] interface and the companion link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java[FunctionCallbackWrapper.java] utility class to simplify the implementation and registration of Java callback functions. + +== How it works + +Suppose we want the AI model to respond with information that it does not have, for example the current temperature at a given location. + +We can provide the AI model with metadata about our own functions that it can use to retrieve that information as it processes your prompt. + +For example, if during the processing of a prompt, the AI Model determines that it needs additional information about the temperature in a given location, it will start a server side generated request/response interaction. The AI Model invokes a client side function. +The AI Model provides method invocation details as JSON and it is the responsibility of the client to execute that function and return the response. + +Spring AI greatly simplifies the code you need to write to support function invocation. +It brokers the function invocation conversation for you. +You can simply provide your function definition as a `@Bean` and then provide the bean name of the function in your prompt options. +You can also reference multiple function bean names in your prompt. + +== Quick Start + +Let's create a chatbot that answer questions by calling our own function. +To support the response of the chatbot, we will register our own function that takes a location and returns the current weather in that location. 
+ +When the response to the prompt to the model needs to answer a question such as `"What’s the weather like in Boston?"` the AI model will invoke the client providing the location value as an argument to be passed to the function. This RPC-like data is passed as JSON. + +Our function can some SaaS based weather service API and returns the weather response back to the model to complete the conversation. +In this example we will use a simple implementation named `MockWeatherService` that hard codes the temperature for various locations. + +The following `MockWeatherService.java` represents the weather service API: + +[source,java] +---- +public class MockWeatherService implements Function { + + public enum Unit { C, F } + public record Request(String location, Unit unit) {} + public record Response(double temp, Unit unit) {} + + public Response apply(Request request) { + return new Response(30.0, Unit.C); + } +} +---- + +=== Registering Functions as Beans + +With the link:../bedrock/bedrock-mistral.html#_auto_configuration[BedrockMistralChatModel Auto-Configuration] you have multiple ways to register custom functions as beans in the Spring context. + +We start with describing the most POJO friendly options. + +==== Plain Java Functions + +In this approach you define `@Beans` in your application context as you would any other Spring managed object. + +Internally, Spring AI `ChatModel` will create an instance of a `FunctionCallbackWrapper` wrapper that adds the logic for it being invoked via the AI model. +The name of the `@Bean` is passed as a `ChatOption`. + + +[source,java] +---- +@Configuration +static class Config { + + @Bean + @Description("Get the weather in location") // function description + public Function weatherFunction1() { + return new MockWeatherService(); + } + ... +} +---- + +The `@Description` annotation is optional and provides a function description (2) that helps the model understand when to call the function. 
+It is an important property to set to help the AI model determine what client side function to invoke. + +Another option to provide the description of the function is to use the `@JsonClassDescription` annotation on the `MockWeatherService.Request` to provide the function description: + +[source,java] +---- + +@Configuration +static class Config { + + @Bean + public Function currentWeather3() { // (1) bean name as function name. + return new MockWeatherService(); + } + ... +} + +@JsonClassDescription("Get the weather in location") // (2) function description +public record Request(String location, Unit unit) {} +---- + +It is a best practice to annotate the request object with information such that the generated JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke. + +The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistral/tool/FunctionCallWithFunctionBeanIT.java.java[FunctionCallWithFunctionBeanIT.java] demonstrates this approach. + + +==== FunctionCallback Wrapper + +Another way to register a function is to create a `FunctionCallbackWrapper` wrapper like this: + +[source,java] +---- +@Configuration +static class Config { + + @Bean + public FunctionCallback weatherFunctionInfo() { + + return new FunctionCallbackWrapper<>("CurrentWeather", // (1) function name + "Get the weather in location", // (2) function description + (response) -> "" + response.temp() + response.unit(), // (3) Response Converter + new MockWeatherService()); // function code + } + ... +} +---- + +It wraps the 3rd party `MockWeatherService` function and registers it as a `CurrentWeather` function with the `BedrockMistralChatModel`. +It also provides a description (2) and an optional response converter (3) to convert the response into a text as expected by the model. 
+ +NOTE: By default, the response converter does a JSON serialization of the Response object. + +NOTE: The `FunctionCallbackWrapper` internally resolves the function call signature based on the `MockWeatherService.Request` class. + +=== Specifying functions in Chat Options + +To let the model know and call your `CurrentWeather` function you need to enable it in your prompt requests: + +[source,java] +---- +BedrockMistralChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + BedrockMistralChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function + +logger.info("Response: {}", response); +---- + +// NOTE: You can can have multiple functions registered in your `ChatModel` but only those enabled in the prompt request will be considered for the function calling. + +Above user question will trigger 3 calls to `CurrentWeather` function (one for each city) and produce the final response. + +=== Register/Call Functions with Prompt Options + +In addition to the auto-configuration you can register callback functions, dynamically, with your Prompt requests: + +[source,java] +---- +BedrockMistralChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); + +var promptOptions = BedrockMistralChatOptions.builder() + .withFunctionCallbacks(List.of(new FunctionCallbackWrapper<>( + "CurrentWeather", // name + "Get the weather in location", // function description + new MockWeatherService()))) // function code + .build(); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +---- + +NOTE: The in-prompt registered functions are enabled by default for the duration of this request. + +This approach allows to dynamically chose different functions to be called based on the user input. 
The https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistral/tool/FunctionCallWithPromptFunctionIT.java[FunctionCallWithPromptFunctionIT.java] integration test provides a complete example of how to register a function with the `BedrockMistralChatModel` and use it in a prompt request.
* MiniMax : Refer to the xref:api/chat/functions/minimax-chat-functions.adoc[MiniMax function invocation docs]. -* ZhiPu AI : Refer to the xref:api/chat/functions/zhipuai-chat-functions.adoc[ZhiPu AI function invocation docs]. \ No newline at end of file +* ZhiPu AI : Refer to the xref:api/chat/functions/zhipuai-chat-functions.adoc[ZhiPu AI function invocation docs]. +* Amazon Bedrock Anthropic 3 : Refer to the xref:api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc[Amazon Bedrock Anthropic3 function invocation docs]. +* Amazon Bedrock Mistral : Refer to the xref:api/chat/functions/bedrock/bedrock-mistral-chat-functions.adoc[Amazon Bedrock Mistral function invocation docs]. +* Amazon Bedrock Cohere Command R : Refer to the xref:api/chat/functions/bedrock/bedrock-coherecommandr-chat-functions.adoc[Amazon Bedrock Cohere Command R function invocation docs]. \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc index 331143968a..7fd9544f00 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc @@ -255,7 +255,8 @@ The following AI Models have been tested to support List, Map and Bean structure | xref:api/chat/bedrock/bedrock-anthropic.adoc[Bedrock Anthropic 2] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java[BedrockAnthropicChatModelIT.java] | xref:api/chat/bedrock/bedrock-anthropic3.adoc[Bedrock Anthropic 3] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java[BedrockAnthropic3ChatModelIT.java] | 
xref:api/chat/bedrock/bedrock-cohere.adoc[Bedrock Cohere] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java[BedrockCohereChatModelIT.java] -| xref:api/chat/bedrock/bedrock-llama.adoc[Bedrock Llama] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java[BedrockLlamaChatModelIT.java.java] +| xref:api/chat/bedrock/bedrock-llama.adoc[Bedrock Llama] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java[BedrockLlamaChatModelIT.java] +| xref:api/chat/bedrock/bedrock-mistral.adoc[Bedrock Mistral] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModelIT.java[BedrockMistralChatModelIT.java] |==== == Build-in JSON mode diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc index fe1bd9b065..55584bf189 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc @@ -149,10 +149,12 @@ Each of the following sections in the documentation shows which dependencies you ** xref:api/chat/vertexai-gemini-chat.adoc[Google Vertex AI Gemini Chat Completion] (streaming, multi-modality & function-calling support) ** xref:api/bedrock.adoc[Amazon Bedrock] *** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere Chat Completion] +*** xref:api/chat/bedrock/bedrock-coherecommandr.adoc[Cohere Command R Chat Completion] *** xref:api/chat/bedrock/bedrock-llama.adoc[Llama Chat Completion] *** xref:api/chat/bedrock/bedrock-titan.adoc[Titan Chat Completion] *** 
xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic Chat Completion] *** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2 Chat Completion] +*** xref:api/chat/bedrock/bedrock-mistral.adoc[Mistral Chat Completion] ** xref:api/chat/mistralai-chat.adoc[MistralAI Chat Completion] (streaming and function-calling support) === Image Generation Models diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java index 46ee804056..33dd645776 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java @@ -22,7 +22,10 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.regions.providers.AwsRegionProvider; import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; @@ -60,6 +63,32 @@ public AwsRegionProvider regionProvider(BedrockAwsConnectionProperties propertie return DefaultAwsRegionProviderChain.builder().build(); } + @Bean + @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) + public BedrockRuntimeClient bedrockRuntimeClient(AwsCredentialsProvider credentialsProvider, + 
AwsRegionProvider regionProvider, BedrockAwsConnectionProperties properties) { + + return BedrockRuntimeClient.builder() + .region(regionProvider.getRegion()) + .credentialsProvider(credentialsProvider) + .overrideConfiguration(c -> c.apiCallTimeout(properties.getTimeout())) + .build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) + public BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient(AwsCredentialsProvider credentialsProvider, + AwsRegionProvider regionProvider, BedrockAwsConnectionProperties properties) { + + return BedrockRuntimeAsyncClient.builder() + .region(regionProvider.getRegion()) + .credentialsProvider(credentialsProvider) + .overrideConfiguration(c -> c.apiCallTimeout(properties.getTimeout())) + .build(); + } + /** * @author Wei Jiang */ diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java index 3e30324546..2dfb4e3c80 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfiguration.java @@ -15,21 +15,19 @@ */ package org.springframework.ai.autoconfigure.bedrock.anthropic; -import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatModel; -import 
org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; +import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; -import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client. @@ -40,28 +38,18 @@ * @author Wei Jiang * @since 0.8.0 */ -@AutoConfiguration -@ConditionalOnClass(AnthropicChatBedrockApi.class) +@AutoConfiguration(after = BedrockConverseApiAutoConfiguration.class) +@ConditionalOnClass(BedrockConverseApi.class) @EnableConfigurationProperties({ BedrockAnthropicChatProperties.class, BedrockAwsConnectionProperties.class }) @ConditionalOnProperty(prefix = BedrockAnthropicChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") @Import(BedrockAwsConnectionConfiguration.class) public class BedrockAnthropicChatAutoConfiguration { @Bean - @ConditionalOnMissingBean - @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) - public AnthropicChatBedrockApi anthropicApi(AwsCredentialsProvider credentialsProvider, - AwsRegionProvider regionProvider, BedrockAnthropicChatProperties properties, - BedrockAwsConnectionProperties awsProperties) { - return new AnthropicChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), - new ObjectMapper(), awsProperties.getTimeout()); - } 
- - @Bean - @ConditionalOnBean(AnthropicChatBedrockApi.class) - public BedrockAnthropicChatModel anthropicChatModel(AnthropicChatBedrockApi anthropicApi, + @ConditionalOnBean(BedrockConverseApi.class) + public BedrockAnthropicChatModel anthropicChatModel(BedrockConverseApi converseApi, BedrockAnthropicChatProperties properties) { - return new BedrockAnthropicChatModel(anthropicApi, properties.getOptions()); + return new BedrockAnthropicChatModel(properties.getModel(), converseApi, properties.getOptions()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java index e9b2636773..d1850b2708 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatProperties.java @@ -15,10 +15,8 @@ */ package org.springframework.ai.autoconfigure.bedrock.anthropic; -import java.util.List; - import org.springframework.ai.bedrock.anthropic.AnthropicChatOptions; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatModel.AnthropicChatModel; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; @@ -27,6 +25,7 @@ * Configuration properties for Bedrock Anthropic. 
* * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @ConfigurationProperties(BedrockAnthropicChatProperties.CONFIG_PREFIX) @@ -46,12 +45,7 @@ public class BedrockAnthropicChatProperties { private String model = AnthropicChatModel.CLAUDE_V2.id(); @NestedConfigurationProperty - private AnthropicChatOptions options = AnthropicChatOptions.builder() - .withTemperature(0.7f) - .withMaxTokensToSample(300) - .withTopK(10) - .withStopSequences(List.of("\n\nHuman:")) - .build(); + private AnthropicChatOptions options = AnthropicChatOptions.builder().build(); public boolean isEnabled() { return this.enabled; @@ -75,7 +69,6 @@ public AnthropicChatOptions getOptions() { public void setOptions(AnthropicChatOptions options) { Assert.notNull(options, "AnthropicChatOptions must not be null"); - Assert.notNull(options.getTemperature(), "AnthropicChatOptions.temperature must not be null"); this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java index 3e53f026b2..f25dee674f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfiguration.java @@ -15,24 +15,30 @@ */ package org.springframework.ai.autoconfigure.bedrock.anthropic3; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.List; + import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import 
org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.util.CollectionUtils; + import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** - * {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic Chat Client. + * {@link AutoConfiguration Auto-configuration} for Bedrock Anthropic3 Chat Client. * * Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}. 
* @@ -40,28 +46,32 @@ * @author Wei Jiang * @since 0.8.0 */ -@AutoConfiguration -@ConditionalOnClass(Anthropic3ChatBedrockApi.class) +@AutoConfiguration(after = BedrockConverseApiAutoConfiguration.class) +@ConditionalOnClass(BedrockConverseApi.class) @EnableConfigurationProperties({ BedrockAnthropic3ChatProperties.class, BedrockAwsConnectionProperties.class }) @ConditionalOnProperty(prefix = BedrockAnthropic3ChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") @Import(BedrockAwsConnectionConfiguration.class) public class BedrockAnthropic3ChatAutoConfiguration { @Bean - @ConditionalOnMissingBean - @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) - public Anthropic3ChatBedrockApi anthropic3Api(AwsCredentialsProvider credentialsProvider, - AwsRegionProvider regionProvider, BedrockAnthropic3ChatProperties properties, - BedrockAwsConnectionProperties awsProperties) { - return new Anthropic3ChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), - new ObjectMapper(), awsProperties.getTimeout()); + @ConditionalOnBean(BedrockConverseApi.class) + public BedrockAnthropic3ChatModel anthropic3ChatModel(BedrockConverseApi converseApi, + BedrockAnthropic3ChatProperties properties, FunctionCallbackContext functionCallbackContext, + List toolFunctionCallbacks) { + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { + properties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + } + + return new BedrockAnthropic3ChatModel(properties.getModel(), converseApi, properties.getOptions(), + functionCallbackContext); } @Bean - @ConditionalOnBean(Anthropic3ChatBedrockApi.class) - public BedrockAnthropic3ChatModel anthropic3ChatModel(Anthropic3ChatBedrockApi anthropicApi, - BedrockAnthropic3ChatProperties properties) { - return new BedrockAnthropic3ChatModel(anthropicApi, properties.getOptions()); + @ConditionalOnMissingBean + public FunctionCallbackContext 
springAiFunctionManager(ApplicationContext context) { + FunctionCallbackContext manager = new FunctionCallbackContext(); + manager.setApplicationContext(context); + return manager; } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java index 71086b0d66..413a209a81 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatProperties.java @@ -15,9 +15,8 @@ */ package org.springframework.ai.autoconfigure.bedrock.anthropic3; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; import org.springframework.ai.bedrock.anthropic3.Anthropic3ChatOptions; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel.Anthropic3ChatModel; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; @@ -26,6 +25,7 @@ * Configuration properties for Bedrock Anthropic Claude 3. * * @author Christian Tzolov + * @author Wei Jiang * @since 1.0.0 */ @ConfigurationProperties(BedrockAnthropic3ChatProperties.CONFIG_PREFIX) @@ -39,19 +39,13 @@ public class BedrockAnthropic3ChatProperties { private boolean enabled = false; /** - * The generative id to use. See the {@link AnthropicChatModel} for the supported + * The generative id to use. See the {@link Anthropic3ChatModel} for the supported * models. 
*/ - private String model = AnthropicChatModel.CLAUDE_V3_SONNET.id(); + private String model = Anthropic3ChatModel.CLAUDE_V3_SONNET.id(); @NestedConfigurationProperty - private Anthropic3ChatOptions options = Anthropic3ChatOptions.builder() - .withTemperature(0.7f) - .withMaxTokens(300) - .withTopK(10) - .withAnthropicVersion(Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) - // .withStopSequences(List.of("\n\nHuman:")) - .build(); + private Anthropic3ChatOptions options = Anthropic3ChatOptions.builder().build(); public boolean isEnabled() { return this.enabled; @@ -74,8 +68,7 @@ public Anthropic3ChatOptions getOptions() { } public void setOptions(Anthropic3ChatOptions options) { - Assert.notNull(options, "AnthropicChatOptions must not be null"); - Assert.notNull(options.getTemperature(), "AnthropicChatOptions.temperature must not be null"); + Assert.notNull(options, "Anthropic3ChatOptions must not be null"); this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/api/BedrockConverseApiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/api/BedrockConverseApiAutoConfiguration.java new file mode 100644 index 0000000000..b287c670d8 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/api/BedrockConverseApiAutoConfiguration.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.api; + +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; + +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; + +/** + * {@link AutoConfiguration Auto-configuration} for Bedrock Converse API. 
+ * + * @author Wei Jiang + * @since 1.0.0 + */ +@AutoConfiguration(after = SpringAiRetryAutoConfiguration.class) +@EnableConfigurationProperties({ BedrockAwsConnectionProperties.class }) +@ConditionalOnClass({ BedrockConverseApi.class, BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class }) +@Import(BedrockAwsConnectionConfiguration.class) +public class BedrockConverseApiAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + @ConditionalOnBean({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class }) + public BedrockConverseApi bedrockConverseApi(BedrockRuntimeClient bedrockRuntimeClient, + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, RetryTemplate retryTemplate) { + return new BedrockConverseApi(bedrockRuntimeClient, bedrockRuntimeAsyncClient, retryTemplate); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java index 896078e5bc..aa609227d4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfiguration.java @@ -15,52 +15,42 @@ */ package org.springframework.ai.autoconfigure.bedrock.cohere; -import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel; -import 
org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; -import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Chat Client. * + * Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}. + * * @author Christian Tzolov * @author Wei Jiang * @since 0.8.0 */ -@AutoConfiguration -@ConditionalOnClass(CohereChatBedrockApi.class) +@AutoConfiguration(after = BedrockConverseApiAutoConfiguration.class) +@ConditionalOnClass(BedrockConverseApi.class) @EnableConfigurationProperties({ BedrockCohereChatProperties.class, BedrockAwsConnectionProperties.class }) @ConditionalOnProperty(prefix = BedrockCohereChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") @Import(BedrockAwsConnectionConfiguration.class) public class BedrockCohereChatAutoConfiguration { @Bean - @ConditionalOnMissingBean - @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) - public CohereChatBedrockApi cohereChatApi(AwsCredentialsProvider credentialsProvider, - AwsRegionProvider regionProvider, BedrockCohereChatProperties properties, - BedrockAwsConnectionProperties awsProperties) { - return new CohereChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), - new ObjectMapper(), awsProperties.getTimeout()); 
- } - - @Bean - @ConditionalOnBean(CohereChatBedrockApi.class) - public BedrockCohereChatModel cohereChatModel(CohereChatBedrockApi cohereChatApi, + @ConditionalOnBean(BedrockConverseApi.class) + public BedrockCohereChatModel cohereChatModel(BedrockConverseApi converseApi, BedrockCohereChatProperties properties) { - return new BedrockCohereChatModel(cohereChatApi, properties.getOptions()); + return new BedrockCohereChatModel(properties.getModel(), converseApi, properties.getOptions()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java index 0381d591b9..ea1766d9a5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatProperties.java @@ -16,14 +16,16 @@ package org.springframework.ai.autoconfigure.bedrock.cohere; import org.springframework.ai.bedrock.cohere.BedrockCohereChatOptions; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; +import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel.CohereChatModel; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; /** * Bedrock Cohere Chat autoconfiguration properties. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @ConfigurationProperties(BedrockCohereChatProperties.CONFIG_PREFIX) @@ -39,7 +41,7 @@ public class BedrockCohereChatProperties { /** * Bedrock Cohere Chat generative name. Defaults to 'cohere-command-v14'. 
*/ - private String model = CohereChatBedrockApi.CohereChatModel.COHERE_COMMAND_V14.id(); + private String model = CohereChatModel.COHERE_COMMAND_V14.id(); @NestedConfigurationProperty private BedrockCohereChatOptions options = BedrockCohereChatOptions.builder().build(); @@ -65,6 +67,8 @@ public BedrockCohereChatOptions getOptions() { } public void setOptions(BedrockCohereChatOptions options) { + Assert.notNull(options, "BedrockCohereChatOptions must not be null"); + this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfiguration.java new file mode 100644 index 0000000000..48e4de13f2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfiguration.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.bedrock.cohere; + +import java.util.List; + +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; +import org.springframework.util.CollectionUtils; + +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; + +/** + * {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Command R Chat Client. + * + * Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}. 
+ * + * @author Wei Jiang + * @since 1.0.0 + */ +@AutoConfiguration(after = BedrockConverseApiAutoConfiguration.class) +@ConditionalOnClass(BedrockConverseApi.class) +@EnableConfigurationProperties({ BedrockCohereCommandRChatProperties.class, BedrockAwsConnectionProperties.class }) +@ConditionalOnProperty(prefix = BedrockCohereCommandRChatProperties.CONFIG_PREFIX, name = "enabled", + havingValue = "true") +@Import(BedrockAwsConnectionConfiguration.class) +public class BedrockCohereCommandRChatAutoConfiguration { + + @Bean + @ConditionalOnBean(BedrockConverseApi.class) + public BedrockCohereCommandRChatModel cohereCommandRChatModel(BedrockConverseApi converseApi, + BedrockCohereCommandRChatProperties properties, FunctionCallbackContext functionCallbackContext, + List toolFunctionCallbacks) { + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { + properties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + } + + return new BedrockCohereCommandRChatModel(properties.getModel(), converseApi, properties.getOptions(), + functionCallbackContext); + } + + @Bean + @ConditionalOnMissingBean + public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { + FunctionCallbackContext manager = new FunctionCallbackContext(); + manager.setApplicationContext(context); + return manager; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatProperties.java new file mode 100644 index 0000000000..96ef609f2e --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatProperties.java @@ -0,0 +1,72 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.cohere; + +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel.CohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Bedrock Cohere Command R Chat autoconfiguration properties. + * + * @author Wei Jiang + * @since 1.0.0 + */ +@ConfigurationProperties(BedrockCohereCommandRChatProperties.CONFIG_PREFIX) +public class BedrockCohereCommandRChatProperties { + + public static final String CONFIG_PREFIX = "spring.ai.bedrock.coherecommandr.chat"; + + /** + * Enable Bedrock Cohere Command R Chat Client. False by default. + */ + private boolean enabled = false; + + /** + * Bedrock Cohere Command R Chat generative name. Defaults to + * 'cohere.command-r-plus-v1:0'. 
+ */ + private String model = CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(); + + @NestedConfigurationProperty + private BedrockCohereCommandRChatOptions options = BedrockCohereCommandRChatOptions.builder().build(); + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public BedrockCohereCommandRChatOptions getOptions() { + return this.options; + } + + public void setOptions(BedrockCohereCommandRChatOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java index dabc8126e2..026d03559c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfiguration.java @@ -16,11 +16,12 @@ package org.springframework.ai.autoconfigure.bedrock.cohere; import com.fasterxml.jackson.databind.ObjectMapper; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import 
org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingModel; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -31,6 +32,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Embedding Model. @@ -39,7 +41,7 @@ * @author Wei Jiang * @since 0.8.0 */ -@AutoConfiguration +@AutoConfiguration(after = SpringAiRetryAutoConfiguration.class) @ConditionalOnClass(CohereEmbeddingBedrockApi.class) @EnableConfigurationProperties({ BedrockCohereEmbeddingProperties.class, BedrockAwsConnectionProperties.class }) @ConditionalOnProperty(prefix = BedrockCohereEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") @@ -48,21 +50,20 @@ public class BedrockCohereEmbeddingAutoConfiguration { @Bean @ConditionalOnMissingBean - @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) - public CohereEmbeddingBedrockApi cohereEmbeddingApi(AwsCredentialsProvider credentialsProvider, - AwsRegionProvider regionProvider, BedrockCohereEmbeddingProperties properties, - BedrockAwsConnectionProperties awsProperties) { - return new CohereEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), - new ObjectMapper(), awsProperties.getTimeout()); + @ConditionalOnBean({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class }) + public CohereEmbeddingBedrockApi cohereEmbeddingApi(BedrockCohereEmbeddingProperties properties, + BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) { + return new CohereEmbeddingBedrockApi(properties.getModel(), 
bedrockRuntimeClient, bedrockRuntimeAsyncClient, + new ObjectMapper()); } @Bean @ConditionalOnMissingBean @ConditionalOnBean(CohereEmbeddingBedrockApi.class) public BedrockCohereEmbeddingModel cohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingApi, - BedrockCohereEmbeddingProperties properties) { + BedrockCohereEmbeddingProperties properties, RetryTemplate retryTemplate) { - return new BedrockCohereEmbeddingModel(cohereEmbeddingApi, properties.getOptions()); + return new BedrockCohereEmbeddingModel(cohereEmbeddingApi, properties.getOptions(), retryTemplate); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java index 8ad3c0bb1a..6cb50096a6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatAutoConfiguration.java @@ -16,31 +16,31 @@ package org.springframework.ai.autoconfigure.bedrock.jurrasic2; -import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatModel; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; import 
org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; -import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Jurassic2 Chat Client. * + * Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}. + * * @author Ahmed Yousri * @author Wei Jiang * @since 1.0.0 */ -@AutoConfiguration -@ConditionalOnClass(Ai21Jurassic2ChatBedrockApi.class) +@AutoConfiguration(after = BedrockConverseApiAutoConfiguration.class) +@ConditionalOnClass(BedrockConverseApi.class) @EnableConfigurationProperties({ BedrockAi21Jurassic2ChatProperties.class, BedrockAwsConnectionProperties.class }) @ConditionalOnProperty(prefix = BedrockAi21Jurassic2ChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") @@ -48,23 +48,11 @@ public class BedrockAi21Jurassic2ChatAutoConfiguration { @Bean - @ConditionalOnMissingBean - @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) - public Ai21Jurassic2ChatBedrockApi ai21Jurassic2ChatBedrockApi(AwsCredentialsProvider credentialsProvider, - AwsRegionProvider regionProvider, BedrockAi21Jurassic2ChatProperties properties, - BedrockAwsConnectionProperties awsProperties) { - return new Ai21Jurassic2ChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), - new ObjectMapper(), awsProperties.getTimeout()); - } - - @Bean - @ConditionalOnBean(Ai21Jurassic2ChatBedrockApi.class) - public 
BedrockAi21Jurassic2ChatModel jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi ai21Jurassic2ChatBedrockApi, + @ConditionalOnBean(BedrockConverseApi.class) + public BedrockAi21Jurassic2ChatModel jurassic2ChatModel(BedrockConverseApi converseApi, BedrockAi21Jurassic2ChatProperties properties) { - return BedrockAi21Jurassic2ChatModel.builder(ai21Jurassic2ChatBedrockApi) - .withOptions(properties.getOptions()) - .build(); + return new BedrockAi21Jurassic2ChatModel(properties.getModel(), converseApi, properties.getOptions()); } } \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java index eccd7e0c9e..a29656e958 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/jurrasic2/BedrockAi21Jurassic2ChatProperties.java @@ -17,14 +17,16 @@ package org.springframework.ai.autoconfigure.bedrock.jurrasic2; import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatOptions; -import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel; +import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatModel.Ai21Jurassic2ChatModel; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; /** * Configuration properties for Bedrock Ai21Jurassic2. 
* * @author Ahmed Yousri + * @author Wei Jiang * @since 1.0.0 */ @ConfigurationProperties(BedrockAi21Jurassic2ChatProperties.CONFIG_PREFIX) @@ -44,10 +46,7 @@ public class BedrockAi21Jurassic2ChatProperties { private String model = Ai21Jurassic2ChatModel.AI21_J2_MID_V1.id(); @NestedConfigurationProperty - private BedrockAi21Jurassic2ChatOptions options = BedrockAi21Jurassic2ChatOptions.builder() - .withTemperature(0.7f) - .withMaxTokens(500) - .build(); + private BedrockAi21Jurassic2ChatOptions options = BedrockAi21Jurassic2ChatOptions.builder().build(); public boolean isEnabled() { return this.enabled; @@ -70,6 +69,8 @@ public BedrockAi21Jurassic2ChatOptions getOptions() { } public void setOptions(BedrockAi21Jurassic2ChatOptions options) { + Assert.notNull(options, "BedrockAi21Jurassic2ChatOptions must not be null"); + this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java index 6e105b8f26..5f2719b03a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfiguration.java @@ -15,18 +15,16 @@ */ package org.springframework.ai.autoconfigure.bedrock.llama; -import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import 
org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; -import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; @@ -41,27 +39,18 @@ * @author Wei Jiang * @since 0.8.0 */ -@AutoConfiguration -@ConditionalOnClass(LlamaChatBedrockApi.class) +@AutoConfiguration(after = BedrockConverseApiAutoConfiguration.class) +@ConditionalOnClass(BedrockConverseApi.class) @EnableConfigurationProperties({ BedrockLlamaChatProperties.class, BedrockAwsConnectionProperties.class }) @ConditionalOnProperty(prefix = BedrockLlamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") @Import(BedrockAwsConnectionConfiguration.class) public class BedrockLlamaChatAutoConfiguration { @Bean - @ConditionalOnMissingBean - @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) - public LlamaChatBedrockApi llamaApi(AwsCredentialsProvider credentialsProvider, AwsRegionProvider regionProvider, - BedrockLlamaChatProperties properties, BedrockAwsConnectionProperties awsProperties) { - return new LlamaChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), - new ObjectMapper(), awsProperties.getTimeout()); - } - - @Bean - @ConditionalOnBean(LlamaChatBedrockApi.class) - public BedrockLlamaChatModel llamaChatModel(LlamaChatBedrockApi llamaApi, BedrockLlamaChatProperties properties) { + 
@ConditionalOnBean(BedrockConverseApi.class) + public BedrockLlamaChatModel llamaChatModel(BedrockConverseApi converseApi, BedrockLlamaChatProperties properties) { - return new BedrockLlamaChatModel(llamaApi, properties.getOptions()); + return new BedrockLlamaChatModel(properties.getModel(), converseApi, properties.getOptions()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java index 048b7dde2b..8f3eb3d666 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatProperties.java @@ -16,14 +16,16 @@ package org.springframework.ai.autoconfigure.bedrock.llama; import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel.LlamaChatModel; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; /** * Configuration properties for Bedrock Llama. 
* * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @ConfigurationProperties(BedrockLlamaChatProperties.CONFIG_PREFIX) @@ -42,10 +44,7 @@ public class BedrockLlamaChatProperties { private String model = LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(); @NestedConfigurationProperty - private BedrockLlamaChatOptions options = BedrockLlamaChatOptions.builder() - .withTemperature(0.7f) - .withMaxGenLen(300) - .build(); + private BedrockLlamaChatOptions options = BedrockLlamaChatOptions.builder().build(); public boolean isEnabled() { return this.enabled; @@ -68,6 +67,8 @@ public BedrockLlamaChatOptions getOptions() { } public void setOptions(BedrockLlamaChatOptions options) { + Assert.notNull(options, "BedrockLlamaChatOptions must not be null"); + this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfiguration.java new file mode 100644 index 0000000000..cafda0096b --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfiguration.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.bedrock.mistral; + +import java.util.List; + +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; +import org.springframework.util.CollectionUtils; + +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; + +/** + * {@link AutoConfiguration Auto-configuration} for Bedrock Mistral Chat Client. + * + * Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}. 
+ * + * @author Wei Jiang + * @since 1.0.0 + */ +@AutoConfiguration(after = BedrockConverseApiAutoConfiguration.class) +@ConditionalOnClass(BedrockConverseApi.class) +@EnableConfigurationProperties({ BedrockMistralChatProperties.class, BedrockAwsConnectionProperties.class }) +@ConditionalOnProperty(prefix = BedrockMistralChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") +@Import(BedrockAwsConnectionConfiguration.class) +public class BedrockMistralChatAutoConfiguration { + + @Bean + @ConditionalOnBean(BedrockConverseApi.class) + public BedrockMistralChatModel mistralChatModel(BedrockConverseApi converseApi, + BedrockMistralChatProperties properties, FunctionCallbackContext functionCallbackContext, + List toolFunctionCallbacks) { + if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { + properties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); + } + + return new BedrockMistralChatModel(properties.getModel(), converseApi, properties.getOptions(), + functionCallbackContext); + } + + @Bean + @ConditionalOnMissingBean + public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { + FunctionCallbackContext manager = new FunctionCallbackContext(); + manager.setApplicationContext(context); + return manager; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatProperties.java new file mode 100644 index 0000000000..ecd88b8cd2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatProperties.java @@ -0,0 +1,72 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.mistral; + +import org.springframework.ai.bedrock.mistral.BedrockMistralChatOptions; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel.MistralChatModel; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Configuration properties for Bedrock Mistral. + * + * @author Wei Jiang + * @since 1.0.0 + */ +@ConfigurationProperties(BedrockMistralChatProperties.CONFIG_PREFIX) +public class BedrockMistralChatProperties { + + public static final String CONFIG_PREFIX = "spring.ai.bedrock.mistral.chat"; + + /** + * Enable Bedrock Mistral Chat Client. False by default. + */ + private boolean enabled = false; + + /** + * Bedrock Mistral Chat generative name. Defaults to + * 'mistral.mistral-large-2402-v1:0'. 
+ */ + private String model = MistralChatModel.MISTRAL_LARGE.id(); + + @NestedConfigurationProperty + private BedrockMistralChatOptions options = BedrockMistralChatOptions.builder().build(); + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public BedrockMistralChatOptions getOptions() { + return this.options; + } + + public void setOptions(BedrockMistralChatOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java index 0115967fe5..39c0730c19 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfiguration.java @@ -15,52 +15,41 @@ */ package org.springframework.ai.autoconfigure.bedrock.titan; -import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.ai.bedrock.titan.BedrockTitanChatModel; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import 
org.springframework.boot.autoconfigure.condition.ConditionalOnClass; -import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Titan Chat Client. * + * Leverages the Spring Cloud AWS to resolve the {@link AwsCredentialsProvider}. + * * @author Christian Tzolov * @author Wei Jiang * @since 0.8.0 */ -@AutoConfiguration -@ConditionalOnClass(TitanChatBedrockApi.class) +@AutoConfiguration(after = BedrockConverseApiAutoConfiguration.class) +@ConditionalOnClass(BedrockConverseApi.class) @EnableConfigurationProperties({ BedrockTitanChatProperties.class, BedrockAwsConnectionProperties.class }) @ConditionalOnProperty(prefix = BedrockTitanChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") @Import(BedrockAwsConnectionConfiguration.class) public class BedrockTitanChatAutoConfiguration { @Bean - @ConditionalOnMissingBean - @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) - public TitanChatBedrockApi titanChatBedrockApi(AwsCredentialsProvider credentialsProvider, - AwsRegionProvider regionProvider, BedrockTitanChatProperties properties, - BedrockAwsConnectionProperties awsProperties) { - return new TitanChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), - new ObjectMapper(), awsProperties.getTimeout()); - } - - @Bean - @ConditionalOnBean(TitanChatBedrockApi.class) - public BedrockTitanChatModel titanChatModel(TitanChatBedrockApi titanChatApi, - BedrockTitanChatProperties properties) { + 
@ConditionalOnBean(BedrockConverseApi.class) + public BedrockTitanChatModel titanChatModel(BedrockConverseApi converseApi, BedrockTitanChatProperties properties) { - return new BedrockTitanChatModel(titanChatApi, properties.getOptions()); + return new BedrockTitanChatModel(properties.getModel(), converseApi, properties.getOptions()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java index b196e9797a..56afb7462b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatProperties.java @@ -16,14 +16,16 @@ package org.springframework.ai.autoconfigure.bedrock.titan; import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; +import org.springframework.ai.bedrock.titan.BedrockTitanChatModel.TitanChatModel; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; /** * Bedrock Titan Chat autoconfiguration properties. 
* * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @ConfigurationProperties(BedrockTitanChatProperties.CONFIG_PREFIX) @@ -42,7 +44,7 @@ public class BedrockTitanChatProperties { private String model = TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(); @NestedConfigurationProperty - private BedrockTitanChatOptions options = BedrockTitanChatOptions.builder().withTemperature(0.7f).build(); + private BedrockTitanChatOptions options = BedrockTitanChatOptions.builder().build(); public boolean isEnabled() { return enabled; @@ -65,6 +67,8 @@ public BedrockTitanChatOptions getOptions() { } public void setOptions(BedrockTitanChatOptions options) { + Assert.notNull(options, "BedrockTitanChatOptions must not be null"); + this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java index bfca436bf1..3620930cea 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfiguration.java @@ -16,11 +16,12 @@ package org.springframework.ai.autoconfigure.bedrock.titan; import com.fasterxml.jackson.databind.ObjectMapper; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.providers.AwsRegionProvider; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; 
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -31,6 +32,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; /** * {@link AutoConfiguration Auto-configuration} for Bedrock Titan Embedding Model. @@ -39,7 +41,7 @@ * @author Wei Jiang * @since 0.8.0 */ -@AutoConfiguration +@AutoConfiguration(after = SpringAiRetryAutoConfiguration.class) @ConditionalOnClass(TitanEmbeddingBedrockApi.class) @EnableConfigurationProperties({ BedrockTitanEmbeddingProperties.class, BedrockAwsConnectionProperties.class }) @ConditionalOnProperty(prefix = BedrockTitanEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") @@ -48,21 +50,21 @@ public class BedrockTitanEmbeddingAutoConfiguration { @Bean @ConditionalOnMissingBean - @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) - public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(AwsCredentialsProvider credentialsProvider, - AwsRegionProvider regionProvider, BedrockTitanEmbeddingProperties properties, - BedrockAwsConnectionProperties awsProperties) { - return new TitanEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), - new ObjectMapper(), awsProperties.getTimeout()); + @ConditionalOnBean({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class }) + public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(BedrockTitanEmbeddingProperties properties, + BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) { + return new TitanEmbeddingBedrockApi(properties.getModel(), 
bedrockRuntimeClient, bedrockRuntimeAsyncClient, + new ObjectMapper()); } @Bean @ConditionalOnMissingBean @ConditionalOnBean(TitanEmbeddingBedrockApi.class) public BedrockTitanEmbeddingModel titanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingApi, - BedrockTitanEmbeddingProperties properties) { + BedrockTitanEmbeddingProperties properties, RetryTemplate retryTemplate) { - return new BedrockTitanEmbeddingModel(titanEmbeddingApi).withInputType(properties.getInputType()); + return new BedrockTitanEmbeddingModel(titanEmbeddingApi, retryTemplate) + .withInputType(properties.getInputType()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 804d5c8fa0..898e44f276 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -5,14 +5,17 @@ org.springframework.ai.autoconfigure.transformers.TransformersEmbeddingModelAuto org.springframework.ai.autoconfigure.huggingface.HuggingfaceChatAutoConfiguration org.springframework.ai.autoconfigure.vertexai.palm2.VertexAiPalm2AutoConfiguration org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfiguration +org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.llama.BedrockLlamaChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereChatAutoConfiguration 
+org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereCommandRChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereEmbeddingAutoConfiguration org.springframework.ai.autoconfigure.bedrock.anthropic.BedrockAnthropicChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.anthropic3.BedrockAnthropic3ChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.titan.BedrockTitanChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.titan.BedrockTitanEmbeddingAutoConfiguration +org.springframework.ai.autoconfigure.bedrock.mistral.BedrockMistralChatAutoConfiguration org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.oracle.OracleVectorStoreAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java index bea58ce80e..801b832d69 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java @@ -31,6 +31,8 @@ import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.regions.providers.AwsRegionProvider; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; /** * @author Wei Jiang @@ -87,6 +89,39 @@ public void autoConfigureWithCustomAWSCredentialAndRegionProvider() { }); } + @Test + public void 
autoConfigureBedrockClients() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id()) + .withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class)) + .run((context) -> { + var bedrockRuntimeClient = context.getBean(BedrockRuntimeClient.class); + var bedrockRuntimeAsyncClient = context.getBean(BedrockRuntimeAsyncClient.class); + + assertThat(bedrockRuntimeClient).isNotNull(); + assertThat(bedrockRuntimeAsyncClient).isNotNull(); + }); + } + + @Test + public void autoConfigureWithCustomBedrockClients() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id()) + .withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class, + CustomBedrockRuntimeClientAutoConfiguration.class)) + .run((context) -> { + var bedrockRuntimeClient = context.getBean(BedrockRuntimeClient.class); + var bedrockRuntimeAsyncClient = context.getBean(BedrockRuntimeAsyncClient.class); + + assertThat(bedrockRuntimeClient).isNotNull(); + assertThat(bedrockRuntimeAsyncClient).isNotNull(); + }); + } + @EnableConfigurationProperties({ BedrockAwsConnectionProperties.class }) @Import(BedrockAwsConnectionConfiguration.class) static class TestAutoConfiguration { @@ -136,4 +171,29 @@ public Region getRegion() { } + @AutoConfiguration + static class CustomBedrockRuntimeClientAutoConfiguration { + + @Bean + public BedrockRuntimeClient bedrockRuntimeClient(AwsCredentialsProvider credentialsProvider, + AwsRegionProvider regionProvider) { + + return BedrockRuntimeClient.builder() + .region(regionProvider.getRegion()) + .credentialsProvider(credentialsProvider) + .build(); + } + + 
@Bean + public BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient(AwsCredentialsProvider credentialsProvider, + AwsRegionProvider regionProvider) { + + return BedrockRuntimeAsyncClient.builder() + .region(regionProvider.getRegion()) + .credentialsProvider(credentialsProvider) + .build(); + } + + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java index 4137e33ce6..293fe48c49 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java @@ -22,12 +22,14 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatModel; +import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatModel.AnthropicChatModel; import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; -import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; @@ -41,6 +43,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * 
@since 0.8.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -51,10 +54,11 @@ public class BedrockAnthropicChatAutoConfigurationIT { .withPropertyValues("spring.ai.bedrock.anthropic.chat.enabled=true", "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), - "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), "spring.ai.bedrock.anthropic.chat.model=" + AnthropicChatModel.CLAUDE_V2.id(), "spring.ai.bedrock.anthropic.chat.options.temperature=0.5") - .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropicChatAutoConfiguration.class)); private final Message systemMessage = new SystemPromptTemplate(""" You are a helpful AI assistant. Your name is {name}. 
@@ -106,7 +110,8 @@ public void propertiesTest() { "spring.ai.bedrock.anthropic.chat.model=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.anthropic.chat.options.temperature=0.55") - .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropicChatAutoConfiguration.class)) .run(context -> { var anthropicChatProperties = context.getBean(BedrockAnthropicChatProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -127,7 +132,8 @@ public void chatCompletionDisabled() { // It is disabled by default new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropicChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAnthropicChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockAnthropicChatModel.class)).isEmpty(); @@ -135,7 +141,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.anthropic.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropicChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAnthropicChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockAnthropicChatModel.class)).isNotEmpty(); @@ -143,7 +150,8 @@ public void chatCompletionDisabled() { // Explicitly disable the chat auto-configuration. 
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.anthropic.chat.enabled=false") - .withConfiguration(AutoConfigurations.of(BedrockAnthropicChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropicChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAnthropicChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockAnthropicChatModel.class)).isEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java index 3defe79b3b..8976cfcdc3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/BedrockAnthropic3ChatAutoConfigurationIT.java @@ -22,12 +22,14 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; +import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel.Anthropic3ChatModel; import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; -import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import 
org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; @@ -41,6 +43,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 1.0.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -52,9 +55,10 @@ public class BedrockAnthropic3ChatAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.anthropic3.chat.model=" + AnthropicChatModel.CLAUDE_V3_SONNET.id(), + "spring.ai.bedrock.anthropic3.chat.model=" + Anthropic3ChatModel.CLAUDE_V3_SONNET.id(), "spring.ai.bedrock.anthropic3.chat.options.temperature=0.5") - .withConfiguration(AutoConfigurations.of(BedrockAnthropic3ChatAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropic3ChatAutoConfiguration.class)); private final Message systemMessage = new SystemPromptTemplate(""" You are a helpful AI assistant. Your name is {name}. 
@@ -106,7 +110,8 @@ public void propertiesTest() { "spring.ai.bedrock.anthropic3.chat.model=MODEL_XYZ", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.anthropic3.chat.options.temperature=0.55") - .withConfiguration(AutoConfigurations.of(BedrockAnthropic3ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropic3ChatAutoConfiguration.class)) .run(context -> { var anthropicChatProperties = context.getBean(BedrockAnthropic3ChatProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -127,7 +132,8 @@ public void chatCompletionDisabled() { // It is disabled by default new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(BedrockAnthropic3ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropic3ChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAnthropic3ChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockAnthropic3ChatModel.class)).isEmpty(); @@ -135,7 +141,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. 
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.anthropic3.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockAnthropic3ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropic3ChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAnthropic3ChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockAnthropic3ChatModel.class)).isNotEmpty(); @@ -143,7 +150,8 @@ public void chatCompletionDisabled() { // Explicitly disable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.anthropic3.chat.enabled=false") - .withConfiguration(AutoConfigurations.of(BedrockAnthropic3ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropic3ChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAnthropic3ChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockAnthropic3ChatModel.class)).isEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/tool/FunctionCallWithFunctionBeanIT.java new file mode 100644 index 0000000000..1770adad78 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/tool/FunctionCallWithFunctionBeanIT.java @@ -0,0 +1,113 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.anthropic3.tool; + +import java.util.List; +import java.util.function.Function; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.bedrock.anthropic3.BedrockAnthropic3ChatAutoConfiguration; +import org.springframework.ai.autoconfigure.bedrock.anthropic3.tool.MockWeatherService.Request; +import org.springframework.ai.autoconfigure.bedrock.anthropic3.tool.MockWeatherService.Response; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.anthropic3.Anthropic3ChatOptions; +import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; +import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel.Anthropic3ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; + +import software.amazon.awssdk.regions.Region; + 
+import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +class FunctionCallWithFunctionBeanIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallWithFunctionBeanIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.anthropic3.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.anthropic3.chat.model=" + Anthropic3ChatModel.CLAUDE_V3_SONNET.id(), + "spring.ai.bedrock.anthropic3.chat.options.temperature=0.5") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropic3ChatAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + + contextRunner.run(context -> { + + BedrockAnthropic3ChatModel chatModel = context.getBean(BedrockAnthropic3ChatModel.class); + + var userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? 
Return the result in Celsius."); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + Anthropic3ChatOptions.builder().withFunction("weatherFunction").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + + response = chatModel.call(new Prompt(List.of(userMessage), + Anthropic3ChatOptions.builder().withFunction("weatherFunction3").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + + }); + } + + @Configuration + static class Config { + + @Bean + @Description("Get the weather in location. Return temperature in 36°F or 36°C format.") + public Function weatherFunction() { + return new MockWeatherService(); + } + + // Relies on the Request's JsonClassDescription annotation to provide the + // function description. + @Bean + public Function weatherFunction3() { + MockWeatherService weatherService = new MockWeatherService(); + return (weatherService::apply); + } + + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/tool/FunctionCallWithPromptFunctionIT.java new file mode 100644 index 0000000000..571a324f4c --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/tool/FunctionCallWithPromptFunctionIT.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.anthropic3.tool; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.autoconfigure.bedrock.anthropic3.BedrockAnthropic3ChatAutoConfiguration; +import org.springframework.ai.autoconfigure.bedrock.anthropic3.tool.FunctionCallWithFunctionBeanIT.Config; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.anthropic3.Anthropic3ChatOptions; +import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel; +import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatModel.Anthropic3ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import software.amazon.awssdk.regions.Region; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = 
"AWS_SECRET_ACCESS_KEY", matches = ".*") +public class FunctionCallWithPromptFunctionIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallWithPromptFunctionIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.anthropic3.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.anthropic3.chat.model=" + Anthropic3ChatModel.CLAUDE_V3_SONNET.id(), + "spring.ai.bedrock.anthropic3.chat.options.temperature=0.5") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAnthropic3ChatAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + contextRunner.run(context -> { + + BedrockAnthropic3ChatModel chatModel = context.getBean(BedrockAnthropic3ChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + var promptOptions = Anthropic3ChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("CurrentWeatherService") + .withDescription("Get the weather in location. 
Return temperature in 36°F or 36°C format.") + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + }); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/tool/MockWeatherService.java new file mode 100644 index 0000000000..711f642e55 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic3/tool/MockWeatherService.java @@ -0,0 +1,91 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.anthropic3.tool; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * Mock 3rd party weather service. 
+ * + * @author Wei Jiang + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Get the weather in location. Return temperature in 36°F or 36°C format.") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. + */ + public record Response(double temperature, double feels_like, double temp_min, double temp_max, int pressure, + int humidity, Unit unit) { + } + + @Override + public Response apply(Request request) { + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/api/BedrockConverseApiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/api/BedrockConverseApiAutoConfigurationIT.java new file mode 100644 index 0000000000..1d4c42c77d --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/api/BedrockConverseApiAutoConfigurationIT.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 - 2024 the original author 
or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.api; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockConverseApiAutoConfigurationIT { + + @Test + public void autoConfigureBedrockConverseApi() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id()) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class)) + .run((context) -> { + var bedrockConverseApi = context.getBean(BedrockConverseApi.class); + + 
assertThat(bedrockConverseApi).isNotNull(); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java index 83b487c901..f419748eb2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java @@ -16,25 +16,24 @@ package org.springframework.ai.autoconfigure.bedrock.cohere; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel; +import org.springframework.ai.bedrock.cohere.BedrockCohereChatModel.CohereChatModel; +import org.springframework.ai.bedrock.cohere.BedrockCohereChatOptions.ReturnLikelihoods; +import org.springframework.ai.bedrock.cohere.BedrockCohereChatOptions.Truncate; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.ReturnLikelihoods; -import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import 
org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -43,6 +42,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -57,14 +57,8 @@ public class BedrockCohereChatAutoConfigurationIT { "spring.ai.bedrock.cohere.chat.model=" + CohereChatModel.COHERE_COMMAND_V14.id(), "spring.ai.bedrock.cohere.chat.options.temperature=0.5", "spring.ai.bedrock.cohere.chat.options.maxTokens=500") - .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)); - - private final Message systemMessage = new SystemPromptTemplate(""" - You are a helpful AI assistant. Your name is {name}. - You are an AI assistant that helps people find information. - Your name is {name} - You should reply to the user's request with your name and also in the style of a {voice}. 
- """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereChatAutoConfiguration.class)); private final UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); @@ -73,7 +67,7 @@ public class BedrockCohereChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockCohereChatModel cohereChatModel = context.getBean(BedrockCohereChatModel.class); - ChatResponse response = cohereChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = cohereChatModel.call(new Prompt(List.of(userMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -84,7 +78,7 @@ public void chatCompletionStreaming() { BedrockCohereChatModel cohereChatModel = context.getBean(BedrockCohereChatModel.class); - Flux response = cohereChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = cohereChatModel.stream(new Prompt(List.of(userMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); @@ -115,7 +109,8 @@ public void propertiesTest() { "spring.ai.bedrock.cohere.chat.options.numGenerations=3", "spring.ai.bedrock.cohere.chat.options.truncate=START", "spring.ai.bedrock.cohere.chat.options.maxTokens=123") - .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereChatAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(BedrockCohereChatProperties.class); var aswProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -143,7 +138,8 @@ public void chatCompletionDisabled() { // It is disabled by default 
new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockCohereChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockCohereChatModel.class)).isEmpty(); @@ -151,7 +147,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.cohere.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockCohereChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockCohereChatModel.class)).isNotEmpty(); @@ -159,7 +156,8 @@ public void chatCompletionDisabled() { // Explicitly disable the chat auto-configuration. 
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.cohere.chat.enabled=false") - .withConfiguration(AutoConfigurations.of(BedrockCohereChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockCohereChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockCohereChatModel.class)).isEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfigurationIT.java new file mode 100644 index 0000000000..d6fd1309d8 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfigurationIT.java @@ -0,0 +1,190 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.bedrock.cohere; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel.CohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatOptions.PromptTruncation; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockCohereCommandRChatAutoConfigurationIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + 
System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.coherecommandr.chat.model=" + CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), + "spring.ai.bedrock.coherecommandr.chat.options.temperature=0.5", + "spring.ai.bedrock.coherecommandr.chat.options.maxTokens=500") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereCommandRChatAutoConfiguration.class)); + + private final Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. + You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. + """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + private final UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); + + @Test + public void chatCompletion() { + contextRunner.run(context -> { + BedrockCohereCommandRChatModel cohereCommandRChatModel = context + .getBean(BedrockCohereCommandRChatModel.class); + ChatResponse response = cohereCommandRChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); + }); + } + + @Test + public void chatCompletionStreaming() { + contextRunner.run(context -> { + + BedrockCohereCommandRChatModel cohereCommandRChatModel = context + .getBean(BedrockCohereCommandRChatModel.class); + + Flux response = cohereCommandRChatModel + .stream(new Prompt(List.of(userMessage, systemMessage))); + + List responses = response.collectList().block(); + assertThat(responses.size()).isGreaterThan(2); + + String stitchedResponseContent = responses.stream() + .map(ChatResponse::getResults) + 
.flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + assertThat(stitchedResponseContent).contains("Blackbeard"); + }); + } + + @Test + public void propertiesTest() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", + "spring.ai.bedrock.coherecommandr.chat.model=MODEL_XYZ", + "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), + "spring.ai.bedrock.coherecommandr.chat.options.searchQueriesOnly=true", + "spring.ai.bedrock.coherecommandr.chat.options.preamble=preamble", + "spring.ai.bedrock.coherecommandr.chat.options.maxTokens=123", + "spring.ai.bedrock.coherecommandr.chat.options.temperature=0.55", + "spring.ai.bedrock.coherecommandr.chat.options.topP=0.55", + "spring.ai.bedrock.coherecommandr.chat.options.topK=10", + "spring.ai.bedrock.coherecommandr.chat.options.promptTruncation=AUTO_PRESERVE_ORDER", + "spring.ai.bedrock.coherecommandr.chat.options.frequencyPenalty=0.55", + "spring.ai.bedrock.coherecommandr.chat.options.presencePenalty=0.66", + "spring.ai.bedrock.coherecommandr.chat.options.seed=555555", + "spring.ai.bedrock.coherecommandr.chat.options.returnPrompt=true", + "spring.ai.bedrock.coherecommandr.chat.options.stopSequences=END1,END2", + "spring.ai.bedrock.coherecommandr.chat.options.rawPrompting=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereCommandRChatAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(BedrockCohereCommandRChatProperties.class); + var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); + + assertThat(chatProperties.isEnabled()).isTrue(); + assertThat(awsProperties.getRegion()).isEqualTo(Region.EU_CENTRAL_1.id()); + 
assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ"); + + assertThat(chatProperties.getOptions().getSearchQueriesOnly()).isTrue(); + assertThat(chatProperties.getOptions().getPreamble()).isEqualTo("preamble"); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getTopK()).isEqualTo(10); + assertThat(chatProperties.getOptions().getPromptTruncation()) + .isEqualTo(PromptTruncation.AUTO_PRESERVE_ORDER); + assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getPresencePenalty()).isEqualTo(0.66f); + assertThat(chatProperties.getOptions().getSeed()).isEqualTo(555555); + assertThat(chatProperties.getOptions().getReturnPrompt()).isTrue(); + assertThat(chatProperties.getOptions().getStopSequences()).isEqualTo(List.of("END1", "END2")); + assertThat(chatProperties.getOptions().getRawPrompting()).isTrue(); + + assertThat(awsProperties.getAccessKey()).isEqualTo("ACCESS_KEY"); + assertThat(awsProperties.getSecretKey()).isEqualTo("SECRET_KEY"); + }); + } + + @Test + public void chatCompletionDisabled() { + + // It is disabled by default + new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereCommandRChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockCohereCommandRChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockCohereCommandRChatModel.class)).isEmpty(); + }); + + // Explicitly enable the chat auto-configuration. 
+ new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereCommandRChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockCohereCommandRChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(BedrockCohereCommandRChatModel.class)).isNotEmpty(); + }); + + // Explicitly disable the chat auto-configuration. + new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereCommandRChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockCohereCommandRChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockCohereCommandRChatModel.class)).isEmpty(); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java index 14d3889551..6c0a9f8a10 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import 
org.springframework.ai.bedrock.cohere.BedrockCohereEmbeddingModel; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingModel; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest; @@ -47,7 +48,8 @@ public class BedrockCohereEmbeddingAutoConfigurationIT { "spring.ai.bedrock.cohere.embedding.model=" + CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id(), "spring.ai.bedrock.cohere.embedding.options.inputType=SEARCH_DOCUMENT", "spring.ai.bedrock.cohere.embedding.options.truncate=NONE") - .withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereEmbeddingAutoConfiguration.class)); @Test public void singleEmbedding() { @@ -91,7 +93,8 @@ public void propertiesTest() { "spring.ai.bedrock.cohere.embedding.model=MODEL_XYZ", "spring.ai.bedrock.cohere.embedding.options.inputType=CLASSIFICATION", "spring.ai.bedrock.cohere.embedding.options.truncate=START") - .withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereEmbeddingAutoConfiguration.class)) .run(context -> { var properties = context.getBean(BedrockCohereEmbeddingProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -113,7 +116,8 @@ public void embeddingDisabled() { // It is disabled by default new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereEmbeddingAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockCohereEmbeddingProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockCohereEmbeddingModel.class)).isEmpty(); @@ -121,7 +125,8 @@ 
public void embeddingDisabled() { // Explicitly enable the embedding auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.cohere.embedding.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereEmbeddingAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockCohereEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockCohereEmbeddingModel.class)).isNotEmpty(); @@ -129,7 +134,8 @@ public void embeddingDisabled() { // Explicitly disable the embedding auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.cohere.embedding.enabled=false") - .withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereEmbeddingAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockCohereEmbeddingProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockCohereEmbeddingModel.class)).isEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/FunctionCallWithFunctionBeanIT.java new file mode 100644 index 0000000000..cbcf63b444 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/FunctionCallWithFunctionBeanIT.java @@ -0,0 +1,106 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.cohere.tool; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.function.Function; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereCommandRChatAutoConfiguration; +import org.springframework.ai.autoconfigure.bedrock.cohere.tool.MockWeatherService.Request; +import org.springframework.ai.autoconfigure.bedrock.cohere.tool.MockWeatherService.Response; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel.CohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatOptions; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import 
org.springframework.context.annotation.Description; + +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +class FunctionCallWithFunctionBeanIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallWithFunctionBeanIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.coherecommandr.chat.model=" + CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), + "spring.ai.bedrock.coherecommandr.chat.options.temperature=0.5") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereCommandRChatAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + + contextRunner.run(context -> { + + BedrockCohereCommandRChatModel chatModel = context.getBean(BedrockCohereCommandRChatModel.class); + + var userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + BedrockCohereCommandRChatOptions.builder().withFunction("weatherFunction").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + + }); + } + + @Configuration + static class Config { + + @Bean + @Description("Get the weather in location. 
Return temperature in 36°F or 36°C format.") + public Function weatherFunction() { + return new MockWeatherService(); + } + + // Relies on the Request's JsonClassDescription annotation to provide the + // function description. + @Bean + public Function weatherFunction3() { + MockWeatherService weatherService = new MockWeatherService(); + return (weatherService::apply); + } + + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/FunctionCallWithPromptFunctionIT.java new file mode 100644 index 0000000000..8c38a1713a --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/FunctionCallWithPromptFunctionIT.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.bedrock.cohere.tool; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereCommandRChatAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatOptions; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel.CohereCommandRChatModel; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class FunctionCallWithPromptFunctionIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallWithPromptFunctionIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + 
"spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.coherecommandr.chat.model=" + CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), + "spring.ai.bedrock.coherecommandr.chat.options.temperature=0.5") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockCohereCommandRChatAutoConfiguration.class)); + + @Test + void functionCallTest() { + contextRunner.run(context -> { + + BedrockCohereCommandRChatModel chatModel = context.getBean(BedrockCohereCommandRChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + var promptOptions = BedrockCohereCommandRChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("CurrentWeatherService") + .withDescription("Get the weather in location. Return temperature in 36°F or 36°C format.") + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + }); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/MockWeatherService.java new file mode 100644 index 0000000000..12972e3bdb --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/tool/MockWeatherService.java @@ -0,0 +1,92 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.cohere.tool; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * Mock 3rd party weather service. + * + * @author Wei Jiang + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request( + @JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city, example: San Francisco") Object location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. 
+ */ + public record Response(double temperature, double feels_like, double temp_min, double temp_max, int pressure, + int humidity, Unit unit) { + } + + @Override + public Response apply(Request request) { + double temperature = 0; + if (request.location().toString().contains("Paris")) { + temperature = 15; + } + else if (request.location().toString().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().toString().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java index ace30a03d1..3ff218ddb6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/jurassic2/BedrockAi21Jurassic2ChatAutoConfigurationIT.java @@ -19,26 +19,26 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; import org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatAutoConfiguration; import org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatModel; -import 
org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; +import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatModel.Ai21Jurassic2ChatModel; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import software.amazon.awssdk.regions.Region; import java.util.List; -import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; /** * @author Ahmed Yousri + * @author Wei Jiang * @since 1.0.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -50,18 +50,11 @@ public class BedrockAi21Jurassic2ChatAutoConfigurationIT { "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), - "spring.ai.bedrock.jurassic2.chat.model=" - + Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel.AI21_J2_ULTRA_V1.id(), + "spring.ai.bedrock.jurassic2.chat.model=" + Ai21Jurassic2ChatModel.AI21_J2_ULTRA_V1.id(), "spring.ai.bedrock.jurassic2.chat.options.temperature=0.5", "spring.ai.bedrock.jurassic2.chat.options.maxGenLen=500") - .withConfiguration(AutoConfigurations.of(BedrockAi21Jurassic2ChatAutoConfiguration.class)); - - private final Message systemMessage = new SystemPromptTemplate(""" - You are a helpful AI assistant. Your name is {name}. - You are an AI assistant that helps people find information. - Your name is {name} - You should reply to the user's request with your name and also in the style of a {voice}. 
- """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAi21Jurassic2ChatAutoConfiguration.class)); private final UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); @@ -70,7 +63,7 @@ public class BedrockAi21Jurassic2ChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockAi21Jurassic2ChatModel ai21Jurassic2ChatModel = context.getBean(BedrockAi21Jurassic2ChatModel.class); - ChatResponse response = ai21Jurassic2ChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = ai21Jurassic2ChatModel.call(new Prompt(List.of(userMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -85,7 +78,8 @@ public void propertiesTest() { "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), "spring.ai.bedrock.jurassic2.chat.options.temperature=0.55", "spring.ai.bedrock.jurassic2.chat.options.maxTokens=123") - .withConfiguration(AutoConfigurations.of(BedrockAi21Jurassic2ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAi21Jurassic2ChatAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(BedrockAi21Jurassic2ChatProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -107,7 +101,8 @@ public void chatCompletionDisabled() { // It is disabled by default new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(BedrockAi21Jurassic2ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAi21Jurassic2ChatAutoConfiguration.class)) .run(context -> { 
assertThat(context.getBeansOfType(BedrockAi21Jurassic2ChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockAi21Jurassic2ChatModel.class)).isEmpty(); @@ -115,7 +110,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.jurassic2.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockAi21Jurassic2ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAi21Jurassic2ChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAi21Jurassic2ChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockAi21Jurassic2ChatModel.class)).isNotEmpty(); @@ -123,7 +119,8 @@ public void chatCompletionDisabled() { // Explicitly disable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.jurassic2.chat.enabled=false") - .withConfiguration(AutoConfigurations.of(BedrockAi21Jurassic2ChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockAi21Jurassic2ChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockAi21Jurassic2ChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockAi21Jurassic2ChatModel.class)).isEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java index f1ed73b8b1..2b68bea15d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java +++ 
b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama/BedrockLlamaChatAutoConfigurationIT.java @@ -22,13 +22,15 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel; +import org.springframework.ai.bedrock.llama.BedrockLlamaChatModel.LlamaChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; -import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatModel; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; @@ -56,7 +58,8 @@ public class BedrockLlamaChatAutoConfigurationIT { "spring.ai.bedrock.llama.chat.model=" + LlamaChatModel.LLAMA3_70B_INSTRUCT_V1.id(), "spring.ai.bedrock.llama.chat.options.temperature=0.5", "spring.ai.bedrock.llama.chat.options.maxGenLen=500") - .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockLlamaChatAutoConfiguration.class)); private final Message systemMessage = new SystemPromptTemplate(""" You are a helpful AI assistant. Your name is {name}. 
@@ -109,7 +112,8 @@ public void propertiesTest() { "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.llama.chat.options.temperature=0.55", "spring.ai.bedrock.llama.chat.options.maxGenLen=123") - .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockLlamaChatAutoConfiguration.class)) .run(context -> { var llamaChatProperties = context.getBean(BedrockLlamaChatProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -130,7 +134,9 @@ public void propertiesTest() { public void chatCompletionDisabled() { // It is disabled by default - new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) + new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockLlamaChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockLlamaChatModel.class)).isEmpty(); @@ -138,7 +144,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockLlamaChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockLlamaChatModel.class)).isNotEmpty(); @@ -146,7 +153,8 @@ public void chatCompletionDisabled() { // Explicitly disable the chat auto-configuration. 
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.llama.chat.enabled=false") - .withConfiguration(AutoConfigurations.of(BedrockLlamaChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockLlamaChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockLlamaChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockLlamaChatModel.class)).isEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfigurationIT.java new file mode 100644 index 0000000000..7d0b59842b --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfigurationIT.java @@ -0,0 +1,160 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.bedrock.mistral; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel.MistralChatModel; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockMistralChatAutoConfigurationIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + 
Region.US_EAST_1.id(), + "spring.ai.bedrock.mistral.chat.model=" + MistralChatModel.MISTRAL_SMALL.id(), + "spring.ai.bedrock.mistral.chat.options.temperature=0.5") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockMistralChatAutoConfiguration.class)); + + private final Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. + You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. + """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + private final UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); + + @Test + public void chatCompletion() { + contextRunner.run(context -> { + BedrockMistralChatModel mistralChatModel = context.getBean(BedrockMistralChatModel.class); + ChatResponse response = mistralChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); + }); + } + + @Test + public void chatCompletionStreaming() { + contextRunner.run(context -> { + + BedrockMistralChatModel mistralChatModel = context.getBean(BedrockMistralChatModel.class); + + Flux response = mistralChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + + List responses = response.collectList().block(); + assertThat(responses.size()).isGreaterThan(2); + + String stitchedResponseContent = responses.stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + assertThat(stitchedResponseContent).contains("Blackbeard"); + }); + } + + @Test + public void propertiesTest() { + + new ApplicationContextRunner() + 
.withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", + "spring.ai.bedrock.mistral.chat.model=MODEL_XYZ", + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.mistral.chat.options.temperature=0.55") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockMistralChatAutoConfiguration.class)) + .run(context -> { + var mistralChatProperties = context.getBean(BedrockMistralChatProperties.class); + var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); + + assertThat(mistralChatProperties.isEnabled()).isTrue(); + assertThat(awsProperties.getRegion()).isEqualTo(Region.US_EAST_1.id()); + + assertThat(mistralChatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(mistralChatProperties.getModel()).isEqualTo("MODEL_XYZ"); + + assertThat(awsProperties.getAccessKey()).isEqualTo("ACCESS_KEY"); + assertThat(awsProperties.getSecretKey()).isEqualTo("SECRET_KEY"); + }); + } + + @Test + public void chatCompletionDisabled() { + + // It is disabled by default + new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockMistralChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockMistralChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockMistralChatModel.class)).isEmpty(); + }); + + // Explicitly enable the chat auto-configuration. 
+ new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockMistralChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockMistralChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(BedrockMistralChatModel.class)).isNotEmpty(); + }); + + // Explicitly disable the chat auto-configuration. + new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockMistralChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockMistralChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockMistralChatModel.class)).isEmpty(); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/tool/FunctionCallWithFunctionBeanIT.java new file mode 100644 index 0000000000..084985d374 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/tool/FunctionCallWithFunctionBeanIT.java @@ -0,0 +1,106 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.mistral.tool; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.function.Function; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.bedrock.mistral.BedrockMistralChatAutoConfiguration; +import org.springframework.ai.autoconfigure.bedrock.mistral.tool.MockWeatherService.Request; +import org.springframework.ai.autoconfigure.bedrock.mistral.tool.MockWeatherService.Response; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel.MistralChatModel; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatOptions; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; + +import 
software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +class FunctionCallWithFunctionBeanIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallWithFunctionBeanIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.mistral.chat.model=" + MistralChatModel.MISTRAL_LARGE.id(), + "spring.ai.bedrock.mistral.chat.options.temperature=0.5") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockMistralChatAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @Test + void functionCallTest() { + + contextRunner.run(context -> { + + BedrockMistralChatModel chatModel = context.getBean(BedrockMistralChatModel.class); + + var userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + BedrockMistralChatOptions.builder().withFunction("weatherFunction").build())); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + + }); + } + + @Configuration + static class Config { + + @Bean + @Description("Get the weather in location. 
Return temperature in 36°F or 36°C format.") + public Function weatherFunction() { + return new MockWeatherService(); + } + + // Relies on the Request's JsonClassDescription annotation to provide the + // function description. + @Bean + public Function weatherFunction3() { + MockWeatherService weatherService = new MockWeatherService(); + return (weatherService::apply); + } + + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/tool/FunctionCallWithPromptFunctionIT.java new file mode 100644 index 0000000000..0e0bfcfcbb --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/tool/FunctionCallWithPromptFunctionIT.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.bedrock.mistral.tool; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.bedrock.mistral.BedrockMistralChatAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel.MistralChatModel; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatOptions; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class FunctionCallWithPromptFunctionIT { + + private final Logger logger = LoggerFactory.getLogger(FunctionCallWithPromptFunctionIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + 
"spring.ai.bedrock.mistral.chat.model=" + MistralChatModel.MISTRAL_LARGE.id(), + "spring.ai.bedrock.mistral.chat.options.temperature=0.5") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockMistralChatAutoConfiguration.class)); + + @Test + void functionCallTest() { + contextRunner.run(context -> { + + BedrockMistralChatModel chatModel = context.getBean(BedrockMistralChatModel.class); + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + var promptOptions = BedrockMistralChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("CurrentWeatherService") + .withDescription("Get the weather in location. Return temperature in 36°F or 36°C format.") + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + }); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/tool/MockWeatherService.java new file mode 100644 index 0000000000..0199dc589c --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/tool/MockWeatherService.java @@ -0,0 +1,91 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.mistral.tool; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * Mock 3rd party weather service. + * + * @author Wei Jiang + */ +public class MockWeatherService implements Function<MockWeatherService.Request, MockWeatherService.Response> { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Get the weather in location. Return temperature in 36°F or 36°C format.") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. 
+ */ + public record Response(double temperature, double feels_like, double temp_min, double temp_max, int pressure, + int humidity, Unit unit) { + } + + @Override + public Response apply(Request request) { + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + } + +} \ No newline at end of file diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java index 94a2fda1b6..2724e12667 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java @@ -16,7 +16,6 @@ package org.springframework.ai.autoconfigure.bedrock.titan; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -27,12 +26,12 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.bedrock.api.BedrockConverseApiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.titan.BedrockTitanChatModel; -import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; +import org.springframework.ai.bedrock.titan.BedrockTitanChatModel.TitanChatModel; import 
org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -41,6 +40,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -55,14 +55,8 @@ public class BedrockTitanChatAutoConfigurationIT { "spring.ai.bedrock.titan.chat.model=" + TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), "spring.ai.bedrock.titan.chat.options.temperature=0.5", "spring.ai.bedrock.titan.chat.options.maxTokenCount=500") - .withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)); - - private final Message systemMessage = new SystemPromptTemplate(""" - You are a helpful AI assistant. Your name is {name}. - You are an AI assistant that helps people find information. - Your name is {name} - You should reply to the user's request with your name and also in the style of a {voice}. 
- """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockTitanChatAutoConfiguration.class)); private final UserMessage userMessage = new UserMessage( "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); @@ -71,7 +65,7 @@ public class BedrockTitanChatAutoConfigurationIT { public void chatCompletion() { contextRunner.run(context -> { BedrockTitanChatModel chatModel = context.getBean(BedrockTitanChatModel.class); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage))); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage))); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); }); } @@ -82,7 +76,7 @@ public void chatCompletionStreaming() { BedrockTitanChatModel chatModel = context.getBean(BedrockTitanChatModel.class); - Flux<ChatResponse> response = chatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + Flux<ChatResponse> response = chatModel.stream(new Prompt(List.of(userMessage))); List<ChatResponse> responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(1); @@ -110,7 +104,8 @@ public void propertiesTest() { "spring.ai.bedrock.titan.chat.options.topP=0.55", "spring.ai.bedrock.titan.chat.options.stopSequences=END1,END2", "spring.ai.bedrock.titan.chat.options.maxTokenCount=123") - .withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockTitanChatAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(BedrockTitanChatProperties.class); var aswProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -134,7 +129,9 @@ public void propertiesTest() { public void chatCompletionDisabled() { // It is disabled by default - new 
ApplicationContextRunner().withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)) + new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockTitanChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockTitanChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockTitanChatModel.class)).isEmpty(); @@ -142,7 +139,8 @@ public void chatCompletionDisabled() { // Explicitly enable the chat auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.titan.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockTitanChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockTitanChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockTitanChatModel.class)).isNotEmpty(); @@ -150,7 +148,8 @@ public void chatCompletionDisabled() { // Explicitly disable the chat auto-configuration. 
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.titan.chat.enabled=false") - .withConfiguration(AutoConfigurations.of(BedrockTitanChatAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockConverseApiAutoConfiguration.class, BedrockTitanChatAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockTitanChatProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockTitanChatModel.class)).isEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java index 5a5a2ad4c1..e709a3d093 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java @@ -23,6 +23,7 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel; import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel.InputType; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel; @@ -47,7 +48,8 @@ public class BedrockTitanEmbeddingAutoConfigurationIT { "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), "spring.ai.bedrock.titan.embedding.model=" + TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id()) - 
.withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanEmbeddingAutoConfiguration.class)); @Test public void singleTextEmbedding() { @@ -87,7 +89,8 @@ public void propertiesTest() { "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), "spring.ai.bedrock.titan.embedding.model=MODEL_XYZ", "spring.ai.bedrock.titan.embedding.inputType=TEXT") - .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanEmbeddingAutoConfiguration.class)) .run(context -> { var properties = context.getBean(BedrockTitanEmbeddingProperties.class); var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); @@ -108,7 +111,8 @@ public void embeddingDisabled() { // It is disabled by default new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanEmbeddingAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockTitanEmbeddingProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockTitanEmbeddingModel.class)).isEmpty(); @@ -116,7 +120,8 @@ public void embeddingDisabled() { // Explicitly enable the embedding auto-configuration. 
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.titan.embedding.enabled=true") - .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanEmbeddingAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockTitanEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(BedrockTitanEmbeddingModel.class)).isNotEmpty(); @@ -124,7 +129,8 @@ public void embeddingDisabled() { // Explicitly disable the embedding auto-configuration. new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.titan.embedding.enabled=false") - .withConfiguration(AutoConfigurations.of(BedrockTitanEmbeddingAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockTitanEmbeddingAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(BedrockTitanEmbeddingProperties.class)).isEmpty(); assertThat(context.getBeansOfType(BedrockTitanEmbeddingModel.class)).isEmpty();