From a3ad6d52517918b0d977913fa3f094141cf127e4 Mon Sep 17 00:00:00 2001 From: wmz7year Date: Wed, 5 Jun 2024 08:35:11 +0800 Subject: [PATCH] Amazon Bedrock Chat adds tool support. --- .../anthropic3/Anthropic3ChatOptions.java | 76 ++++++++++++- .../BedrockAnthropic3ChatModel.java | 62 +++++++++- .../ai/bedrock/api/BedrockConverseApi.java | 60 ++++++++++ .../bedrock/api/BedrockConverseApiUtils.java | 107 ++++++++++++++---- .../ai/bedrock/MockWeatherService.java | 89 +++++++++++++++ .../BedrockAnthropic3ChatModelIT.java | 27 +++++ 6 files changed, 392 insertions(+), 29 deletions(-) create mode 100644 models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/MockWeatherService.java diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java index 98fe80bb8d0..208925f6d62 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java @@ -15,12 +15,20 @@ */ package org.springframework.ai.bedrock.anthropic3; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; /** * Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options. @@ -31,7 +39,7 @@ * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) -public class Anthropic3ChatOptions implements ChatOptions { +public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions { // @formatter:off /** @@ -66,6 +74,31 @@ public class Anthropic3ChatOptions implements ChatOptions { */ private @JsonProperty("stop_sequences") List stopSequences; + /** + * Tool Function Callbacks to register with the ChatModel. For Prompt + * Options the functionCallbacks are automatically enabled for the duration of the + * prompt execution. For Default Options the functionCallbacks are registered but + * disabled by default. Use the enableFunctions to set the functions from the registry + * to be used by the ChatModel chat completion requests. + */ + @NestedConfigurationProperty + @JsonIgnore + private List functionCallbacks = new ArrayList<>(); + + /** + * List of functions, identified by their names, to configure for function calling in + * the chat completion requests. Functions with those names must exist in the + * functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions + * are automatically enabled for the duration of the prompt execution. + * + * Note that function enabled with the default options are enabled for all chat + * completion requests. This could impact the token count and the billing. If the + * functions is set in a prompt options, then the enabled functions are only active + * for the duration of this prompt execution. + */ + @NestedConfigurationProperty + @JsonIgnore + private Set functions = new HashSet<>(); // @formatter:on public static Builder builder() { @@ -101,6 +134,23 @@ public Builder withStopSequences(List stopSequences) { return this; } + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + public Anthropic3ChatOptions build() { return this.options; } @@ -150,12 +200,36 @@ public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + @Override + public void setFunctions(Set functions) { + Assert.notNull(functions, "Function must not be null"); + this.functions = functions; + } + public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) { return builder().withTemperature(fromOptions.getTemperature()) .withMaxTokens(fromOptions.getMaxTokens()) .withTopK(fromOptions.getTopK()) .withTopP(fromOptions.getTopP()) .withStopSequences(fromOptions.getStopSequences()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) .build(); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java index bb7d1b96905..fcaa4dd47e8 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java @@ -16,17 +16,22 @@ package org.springframework.ai.bedrock.anthropic3; import reactor.core.publisher.Flux; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; +import java.util.List; + import org.springframework.ai.bedrock.api.BedrockConverseApi; +import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest; import org.springframework.ai.bedrock.api.BedrockConverseApiUtils; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelDescription; +import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.util.Assert; /** @@ -38,7 +43,9 @@ * @author Wei Jiang * @since 1.0.0 */ -public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel { +public class BedrockAnthropic3ChatModel + extends AbstractFunctionCallSupport + implements ChatModel, StreamingChatModel { private final String modelId; @@ -56,6 +63,13 @@ public BedrockAnthropic3ChatModel(BedrockConverseApi converseApi, Anthropic3Chat } public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options) { + this(modelId, converseApi, options, null); + } + + public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options, + FunctionCallbackContext functionCallbackContext) { + super(functionCallbackContext); + Assert.notNull(modelId, "modelId must not be null."); Assert.notNull(converseApi, "BedrockConverseApi must not be null."); Assert.notNull(options, "Anthropic3ChatOptions must not be null."); @@ -69,17 +83,16 @@ public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi public ChatResponse call(Prompt prompt) { Assert.notNull(prompt, "Prompt must not be null."); - var request = BedrockConverseApiUtils.createConverseRequest(modelId, prompt, defaultOptions); - - ConverseResponse response = this.converseApi.converse(request); + var request = BedrockConverseApiUtils.createBedrockConverseRequest(modelId, prompt, defaultOptions); - return BedrockConverseApiUtils.convertConverseResponse(response); + return this.callWithFunctionSupport(request); } @Override public Flux stream(Prompt prompt) { Assert.notNull(prompt, "Prompt must not be null."); + // TODO var request = BedrockConverseApiUtils.createConverseStreamRequest(modelId, prompt, defaultOptions); Flux fluxResponse = this.converseApi.converseStream(request); @@ -92,6 +105,43 @@ public ChatOptions getDefaultOptions() { return Anthropic3ChatOptions.fromOptions(this.defaultOptions); } + @Override + protected BedrockConverseRequest doCreateToolResponseRequest(BedrockConverseRequest previousRequest, + Generation responseMessage, List conversationHistory) { + // TODO Auto-generated method stub + return null; + } + + @Override + protected List doGetUserMessages(BedrockConverseRequest request) { + // TODO Auto-generated method stub + return null; + } + + @Override + protected Generation doGetToolResponseMessage(ChatResponse response) { + // TODO Auto-generated method stub + return null; + } + + @Override + protected ChatResponse doChatCompletion(BedrockConverseRequest request) { + // TODO Auto-generated method stub + return null; + } + + @Override + protected Flux doChatCompletionStream(BedrockConverseRequest request) { + // TODO Auto-generated method stub + return null; + } + + @Override + protected boolean isToolFunctionCall(ChatResponse response) { + // TODO Auto-generated method stub + return false; + } + /** * Anthropic3 models version. */ diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApi.java index b1f69b5284e..eba5c39e739 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApi.java @@ -17,9 +17,11 @@ package org.springframework.ai.bedrock.api; import java.time.Duration; +import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.retry.RetryUtils; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -30,6 +32,7 @@ import reactor.core.publisher.Sinks.EmitResult; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; +import software.amazon.awssdk.core.document.Document; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; @@ -38,6 +41,8 @@ import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock; /** * Amazon Bedrock Converse API, It provides the basic functionality to invoke the Bedrock @@ -177,6 +182,41 @@ public Region getRegion() { return this.region; } + /** + * BedrockConverseRequest encapsulates the request parameters for the Amazon Bedrock + * Converse Api. + * + * @param modelId The Amazon Bedrock Model Id. + * @param messages The messages that you want to send to the model. + * @param systemMessages A system prompt to pass to the model. + * @param additionalModelRequestFields Additional inference parameters that the model + * supports, beyond the base set of inference parameters that Converse supports in the + * inferenceConfig field. + */ + public record BedrockConverseRequest(String modelId, List messages, + List systemMessages, Document additionalModelRequestFields) { + + } + + /** + * Invoke the model and return the response. + * + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse + * @param bedrockConverseRequest Model invocation request. + * @return The model invocation response. + */ + public ChatResponse converse(BedrockConverseRequest bedrockConverseRequest) { + Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null"); + + ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(bedrockConverseRequest); + + ConverseResponse converseResponse = converse(converseRequest); + + return BedrockConverseApiUtils.convertConverseResponse(converseResponse); + } + /** * Invoke the model and return the response. * @@ -194,6 +234,26 @@ public ConverseResponse converse(ConverseRequest converseRequest) { }); } + /** + * Invoke the model and return the response stream. + * + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream + * @param bedrockConverseRequest Model invocation request. + * @return The model invocation response stream. + */ + public Flux converseStream(BedrockConverseRequest bedrockConverseRequest) { + Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null"); + + ConverseStreamRequest converseStreamRequest = BedrockConverseApiUtils + .createConverseStreamRequest(bedrockConverseRequest); + + return converseStream(converseStreamRequest) + .map(output -> BedrockConverseApiUtils.convertConverseStreamOutput(output)); + + } + /** * Invoke the model and return the response stream. * diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtils.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtils.java index 8cb921cd646..12fc2d8a581 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtils.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApiUtils.java @@ -25,6 +25,7 @@ import java.util.stream.Collectors; import org.springframework.ai.bedrock.BedrockChatResponseMetadata; +import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.model.ChatResponse; @@ -64,6 +65,42 @@ public class BedrockConverseApiUtils { private static final ObjectMapper objectMapper = new ObjectMapper(); + /** + * Convert {@link Prompt} to {@link ConverseRequest} with model id and options. It + * will merge default options and runtime options to converse inference parameters. + * + * @param modelId The Amazon Bedrock Model Id. + * @param prompt The prompt that needs to convert. + * @return Amazon Bedrock Converse encapsulates request. + */ + public static BedrockConverseRequest createBedrockConverseRequest(String modelId, Prompt prompt) { + return createBedrockConverseRequest(modelId, prompt, null); + } + + /** + * Convert {@link Prompt} to {@link ConverseRequest} with model id and options. It + * will merge default options and runtime options to converse inference parameters. + * + * @param modelId The Amazon Bedrock Model Id. + * @param prompt The prompt that needs to convert. + * @param defaultOptions The default options needs to convert. + * @return Amazon Bedrock Converse encapsulates request. + */ + public static BedrockConverseRequest createBedrockConverseRequest(String modelId, Prompt prompt, + ChatOptions defaultOptions) { + Assert.notNull(modelId, "'modelId' must not be null."); + Assert.notNull(prompt, "'prompt' must not be null."); + + List messages = getPromptMessages(prompt); + + List systemMessages = getPromptSystemContentBlocks(prompt); + + Document additionalModelRequestFields = getChatOptionsAdditionalModelRequestFields(defaultOptions, + prompt.getOptions()); + + return new BedrockConverseRequest(modelId, messages, systemMessages, additionalModelRequestFields); + } + /** * Convert {@link Prompt} to {@link ConverseRequest} with model id and options. It * will merge default options and runtime options to converse inference parameters. @@ -88,24 +125,44 @@ public static ConverseRequest createConverseRequest(String modelId, Prompt promp * @return Amazon Bedrock Converse request. */ public static ConverseRequest createConverseRequest(String modelId, Prompt prompt, ChatOptions defaultOptions) { - Assert.notNull(modelId, "'modelId' must not be null."); - Assert.notNull(prompt, "'prompt' must not be null."); + BedrockConverseRequest bedrockConverseRequest = createBedrockConverseRequest(modelId, prompt, defaultOptions); - List systemMessages = getPromptSystemContentBlocks(prompt); - - List userMessages = getPromptMessages(prompt); + return createConverseRequest(bedrockConverseRequest); + } - Document additionalModelRequestFields = getChatOptionsAdditionalModelRequestFields(defaultOptions, - prompt.getOptions()); + /** + * Convert {@link Prompt} to {@link ConverseRequest} with model id and options. It + * will merge default options and runtime options to converse inference parameters. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax + * + * @param bedrockConverseRequest The Amazon Bedrock Converse encapsulates request. + * @return Amazon Bedrock Converse request. + */ + public static ConverseRequest createConverseRequest(BedrockConverseRequest bedrockConverseRequest) { + Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null."); return ConverseRequest.builder() - .modelId(modelId) - .messages(userMessages) - .system(systemMessages) - .additionalModelRequestFields(additionalModelRequestFields) + .modelId(bedrockConverseRequest.modelId()) + .messages(bedrockConverseRequest.messages()) + .system(bedrockConverseRequest.systemMessages()) + .additionalModelRequestFields(bedrockConverseRequest.additionalModelRequestFields()) .build(); } + /** + * Convert {@link Prompt} to {@link ConverseStreamRequest} with model id and options. + * It will merge default options and runtime options to converse inference parameters. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html + * + * @param modelId The Amazon Bedrock Model Id. + * @param prompt The prompt that needs to convert. + * @param defaultOptions The default options needs to convert. + * @return Amazon Bedrock Converse stream request. + */ + public static ConverseStreamRequest createConverseStreamRequest(String modelId, Prompt prompt) { + return createConverseStreamRequest(modelId, prompt, null); + } + /** * Convert {@link Prompt} to {@link ConverseStreamRequest} with model id and options. * It will merge default options and runtime options to converse inference parameters. @@ -118,21 +175,27 @@ public static ConverseRequest createConverseRequest(String modelId, Prompt promp */ public static ConverseStreamRequest createConverseStreamRequest(String modelId, Prompt prompt, ChatOptions defaultOptions) { - Assert.notNull(modelId, "'modelId' must not be null."); - Assert.notNull(prompt, "'prompt' must not be null."); + BedrockConverseRequest bedrockConverseRequest = createBedrockConverseRequest(modelId, prompt, defaultOptions); - List systemMessages = getPromptSystemContentBlocks(prompt); - - List userMessages = getPromptMessages(prompt); + return createConverseStreamRequest(bedrockConverseRequest); + } - Document additionalModelRequestFields = getChatOptionsAdditionalModelRequestFields(defaultOptions, - prompt.getOptions()); + /** + * Convert {@link Prompt} to {@link ConverseStreamRequest} with model id and options. + * It will merge default options and runtime options to converse inference parameters. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html + * + * @param bedrockConverseRequest The Amazon Bedrock Converse encapsulates request. + * @return Amazon Bedrock Converse stream request. + */ + public static ConverseStreamRequest createConverseStreamRequest(BedrockConverseRequest bedrockConverseRequest) { + Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null."); return ConverseStreamRequest.builder() - .modelId(modelId) - .messages(userMessages) - .system(systemMessages) - .additionalModelRequestFields(additionalModelRequestFields) + .modelId(bedrockConverseRequest.modelId()) + .messages(bedrockConverseRequest.messages()) + .system(bedrockConverseRequest.systemMessages()) + .additionalModelRequestFields(bedrockConverseRequest.additionalModelRequestFields()) .build(); } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/MockWeatherService.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/MockWeatherService.java new file mode 100644 index 00000000000..78f41f210a9 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/MockWeatherService.java @@ -0,0 +1,89 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * @author Christian Tzolov + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. + */ + public record Response(double temp, Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, Unit.C); + } + +} \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java index 75441e76967..3540d6a7054 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.time.Duration; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -30,6 +31,7 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.MockWeatherService; import org.springframework.ai.bedrock.api.BedrockConverseApi; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -44,6 +46,7 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -246,6 +249,30 @@ void chatOptions() { assertThat(content).isNotNull(); } + @Test + void functionCallTest() { + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = Anthropic3ChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location. Return temperature in 36°F or 36°C format.") + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + Generation generation = response.getResult(); + assertThat(generation.getOutput().getContent()).containsAnyOf("30.0", "30"); + assertThat(generation.getOutput().getContent()).containsAnyOf("10.0", "10"); + assertThat(generation.getOutput().getContent()).containsAnyOf("15.0", "15"); + } + @SpringBootConfiguration public static class TestConfiguration {