Skip to content

Commit

Permalink
Amazon Bedrock Chat adds tool support.
Browse files Browse the repository at this point in the history
  • Loading branch information
wmz7year committed Jun 5, 2024
1 parent 49b3326 commit a3ad6d5
Show file tree
Hide file tree
Showing 6 changed files with 392 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@
*/
package org.springframework.ai.bedrock.anthropic3;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options.
Expand All @@ -31,7 +39,7 @@
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class Anthropic3ChatOptions implements ChatOptions {
public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions {

// @formatter:off
/**
Expand Down Expand Up @@ -66,6 +74,31 @@ public class Anthropic3ChatOptions implements ChatOptions {
*/
private @JsonProperty("stop_sequences") List<String> stopSequences;

/**
* Tool Function Callbacks to register with the ChatModel. For Prompt
* Options the functionCallbacks are automatically enabled for the duration of the
* prompt execution. For Default Options the functionCallbacks are registered but
* disabled by default. Use the enableFunctions to set the functions from the registry
* to be used by the ChatModel chat completion requests.
*/
@NestedConfigurationProperty
@JsonIgnore
private List<FunctionCallback> functionCallbacks = new ArrayList<>();

/**
* List of functions, identified by their names, to configure for function calling in
* the chat completion requests. Functions with those names must exist in the
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
* are automatically enabled for the duration of the prompt execution.
*
* Note that function enabled with the default options are enabled for all chat
* completion requests. This could impact the token count and the billing. If the
* functions is set in a prompt options, then the enabled functions are only active
* for the duration of this prompt execution.
*/
@NestedConfigurationProperty
@JsonIgnore
private Set<String> functions = new HashSet<>();
// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -101,6 +134,23 @@ public Builder withStopSequences(List<String> stopSequences) {
return this;
}

public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
return this;
}

public Builder withFunctions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
return this;
}

public Builder withFunction(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
return this;
}

public Anthropic3ChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -150,12 +200,36 @@ public void setStopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
}

@Override
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
}

@Override
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
this.functionCallbacks = functionCallbacks;
}

@Override
public Set<String> getFunctions() {
return this.functions;
}

@Override
public void setFunctions(Set<String> functions) {
Assert.notNull(functions, "Function must not be null");
this.functions = functions;
}

public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withMaxTokens(fromOptions.getMaxTokens())
.withTopK(fromOptions.getTopK())
.withTopP(fromOptions.getTopP())
.withStopSequences(fromOptions.getStopSequences())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@
package org.springframework.ai.bedrock.anthropic3;

import reactor.core.publisher.Flux;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;

import java.util.List;

import org.springframework.ai.bedrock.api.BedrockConverseApi;
import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest;
import org.springframework.ai.bedrock.api.BedrockConverseApiUtils;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;

/**
Expand All @@ -38,7 +43,9 @@
* @author Wei Jiang
* @since 1.0.0
*/
public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel {
public class BedrockAnthropic3ChatModel
extends AbstractFunctionCallSupport<Generation, BedrockConverseRequest, ChatResponse>
implements ChatModel, StreamingChatModel {

private final String modelId;

Expand All @@ -56,6 +63,13 @@ public BedrockAnthropic3ChatModel(BedrockConverseApi converseApi, Anthropic3Chat
}

public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options) {
this(modelId, converseApi, options, null);
}

public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options,
FunctionCallbackContext functionCallbackContext) {
super(functionCallbackContext);

Assert.notNull(modelId, "modelId must not be null.");
Assert.notNull(converseApi, "BedrockConverseApi must not be null.");
Assert.notNull(options, "Anthropic3ChatOptions must not be null.");
Expand All @@ -69,17 +83,16 @@ public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi
public ChatResponse call(Prompt prompt) {
Assert.notNull(prompt, "Prompt must not be null.");

var request = BedrockConverseApiUtils.createConverseRequest(modelId, prompt, defaultOptions);

ConverseResponse response = this.converseApi.converse(request);
var request = BedrockConverseApiUtils.createBedrockConverseRequest(modelId, prompt, defaultOptions);

return BedrockConverseApiUtils.convertConverseResponse(response);
return this.callWithFunctionSupport(request);
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
Assert.notNull(prompt, "Prompt must not be null.");

// TODO
var request = BedrockConverseApiUtils.createConverseStreamRequest(modelId, prompt, defaultOptions);

Flux<ConverseStreamOutput> fluxResponse = this.converseApi.converseStream(request);
Expand All @@ -92,6 +105,43 @@ public ChatOptions getDefaultOptions() {
return Anthropic3ChatOptions.fromOptions(this.defaultOptions);
}

@Override
protected BedrockConverseRequest doCreateToolResponseRequest(BedrockConverseRequest previousRequest,
Generation responseMessage, List<Generation> conversationHistory) {
// TODO Auto-generated method stub
return null;
}

@Override
protected List<Generation> doGetUserMessages(BedrockConverseRequest request) {
// TODO Auto-generated method stub
return null;
}

@Override
protected Generation doGetToolResponseMessage(ChatResponse response) {
// TODO Auto-generated method stub
return null;
}

@Override
protected ChatResponse doChatCompletion(BedrockConverseRequest request) {
// TODO Auto-generated method stub
return null;
}

@Override
protected Flux<ChatResponse> doChatCompletionStream(BedrockConverseRequest request) {
// TODO Auto-generated method stub
return null;
}

@Override
protected boolean isToolFunctionCall(ChatResponse response) {
// TODO Auto-generated method stub
return false;
}

/**
* Anthropic3 models version.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
package org.springframework.ai.bedrock.api;

import java.time.Duration;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
Expand All @@ -30,6 +32,7 @@
import reactor.core.publisher.Sinks.EmitResult;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
Expand All @@ -38,6 +41,8 @@
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;

/**
* Amazon Bedrock Converse API, It provides the basic functionality to invoke the Bedrock
Expand Down Expand Up @@ -177,6 +182,41 @@ public Region getRegion() {
return this.region;
}

/**
* BedrockConverseRequest encapsulates the request parameters for the Amazon Bedrock
* Converse Api.
*
* @param modelId The Amazon Bedrock Model Id.
* @param messages The messages that you want to send to the model.
* @param systemMessages A system prompt to pass to the model.
* @param additionalModelRequestFields Additional inference parameters that the model
* supports, beyond the base set of inference parameters that Converse supports in the
* inferenceConfig field.
*/
public record BedrockConverseRequest(String modelId, List<Message> messages,
List<SystemContentBlock> systemMessages, Document additionalModelRequestFields) {

}

/**
* Invoke the model and return the response.
*
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse
* @param bedrockConverseRequest Model invocation request.
* @return The model invocation response.
*/
public ChatResponse converse(BedrockConverseRequest bedrockConverseRequest) {
Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null");

ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(bedrockConverseRequest);

ConverseResponse converseResponse = converse(converseRequest);

return BedrockConverseApiUtils.convertConverseResponse(converseResponse);
}

/**
* Invoke the model and return the response.
*
Expand All @@ -194,6 +234,26 @@ public ConverseResponse converse(ConverseRequest converseRequest) {
});
}

/**
* Invoke the model and return the response stream.
*
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
* @param bedrockConverseRequest Model invocation request.
* @return The model invocation response stream.
*/
public Flux<ChatResponse> converseStream(BedrockConverseRequest bedrockConverseRequest) {
Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null");

ConverseStreamRequest converseStreamRequest = BedrockConverseApiUtils
.createConverseStreamRequest(bedrockConverseRequest);

return converseStream(converseStreamRequest)
.map(output -> BedrockConverseApiUtils.convertConverseStreamOutput(output));

}

/**
* Invoke the model and return the response stream.
*
Expand Down
Loading

0 comments on commit a3ad6d5

Please sign in to comment.