Skip to content

Commit

Permalink
Amazon Bedrock Chat adds tool support.
Browse files Browse the repository at this point in the history
  • Loading branch information
wmz7year committed Jun 6, 2024
1 parent 49b3326 commit b03129d
Show file tree
Hide file tree
Showing 7 changed files with 811 additions and 78 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock;

import java.util.List;

import org.springframework.ai.bedrock.api.BedrockConverseApiUtils;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;

import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;

/**
* Amazon Bedrock Chat model converse interface generation metadata, encapsulating
* information on the completion.
*
* @author Wei Jiang
* @since 1.0.0
*/
public class BedrockConverseChatGenerationMetadata implements ChatGenerationMetadata {

private String stopReason;

private Message message;

private ConverseStreamOutput event;

public BedrockConverseChatGenerationMetadata(String stopReason, ConverseStreamOutput event) {
super();

this.stopReason = stopReason;
this.event = event;
}

public BedrockConverseChatGenerationMetadata(String stopReason, Message message) {
super();

this.stopReason = stopReason;
this.message = message;
}

public static BedrockConverseChatGenerationMetadata from(ConverseResponse response, Message message) {
return new BedrockConverseChatGenerationMetadata(response.stopReasonAsString(), message);
}

public static BedrockConverseChatGenerationMetadata from(ConverseStreamOutput event) {
String stopReason = null;

if (event instanceof MessageStopEvent messageStopEvent) {
stopReason = messageStopEvent.stopReasonAsString();
}

return new BedrockConverseChatGenerationMetadata(stopReason, event);
}

@Override
public <T> T getContentFilterMetadata() {
return null;
}

@Override
public String getFinishReason() {
return stopReason;
}

public Message getMessage() {
return message;
}

public void setMessage(Message message) {
this.message = message;
}

public ConverseStreamOutput getEvent() {
return event;
}

public boolean isToolUseEvent() {
if (event instanceof ContentBlockStartEvent startEvent) {
if (startEvent.start().toolUse() != null) {
return false;
}
}

if (event instanceof ContentBlockDeltaEvent deltaEvent) {
if (deltaEvent.delta().toolUse() != null) {
return false;
}
}

if (event instanceof MessageStopEvent stopEvent) {
return stopEvent.stopReason() != StopReason.TOOL_USE;
}

return true;
}

public static void generateEventMessage(ChatGenerationMetadata chatGenerationMetadata) {
if (chatGenerationMetadata instanceof BedrockConverseChatGenerationMetadata metadata) {
ConverseStreamOutput event = metadata.getEvent();

if (event instanceof ContentBlockDeltaEvent deltaEvent) {
Message message = BedrockConverseApiUtils.createMessage(
List.of(ContentBlock.builder().text(deltaEvent.delta().text()).build()),
ConversationRole.ASSISTANT);
metadata.setMessage(message);
}
else {
metadata.setMessage(BedrockConverseApiUtils.EMPTY_MESSAGE);
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@
*/
package org.springframework.ai.bedrock.anthropic3;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options.
Expand All @@ -31,7 +39,7 @@
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class Anthropic3ChatOptions implements ChatOptions {
public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions {

// @formatter:off
/**
Expand Down Expand Up @@ -66,6 +74,31 @@ public class Anthropic3ChatOptions implements ChatOptions {
*/
private @JsonProperty("stop_sequences") List<String> stopSequences;

/**
* Tool Function Callbacks to register with the ChatModel. For Prompt
* Options the functionCallbacks are automatically enabled for the duration of the
* prompt execution. For Default Options the functionCallbacks are registered but
* disabled by default. Use the enableFunctions to set the functions from the registry
* to be used by the ChatModel chat completion requests.
*/
@NestedConfigurationProperty
@JsonIgnore
private List<FunctionCallback> functionCallbacks = new ArrayList<>();

/**
* List of functions, identified by their names, to configure for function calling in
* the chat completion requests. Functions with those names must exist in the
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
* are automatically enabled for the duration of the prompt execution.
*
* Note that function enabled with the default options are enabled for all chat
* completion requests. This could impact the token count and the billing. If the
* functions is set in a prompt options, then the enabled functions are only active
* for the duration of this prompt execution.
*/
@NestedConfigurationProperty
@JsonIgnore
private Set<String> functions = new HashSet<>();
// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -101,6 +134,23 @@ public Builder withStopSequences(List<String> stopSequences) {
return this;
}

public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
return this;
}

public Builder withFunctions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
return this;
}

public Builder withFunction(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
return this;
}

public Anthropic3ChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -150,12 +200,36 @@ public void setStopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
}

@Override
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
}

@Override
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
this.functionCallbacks = functionCallbacks;
}

@Override
public Set<String> getFunctions() {
return this.functions;
}

@Override
public void setFunctions(Set<String> functions) {
Assert.notNull(functions, "Function must not be null");
this.functions = functions;
}

public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withMaxTokens(fromOptions.getMaxTokens())
.withTopK(fromOptions.getTopK())
.withTopP(fromOptions.getTopP())
.withStopSequences(fromOptions.getStopSequences())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.build();
}

Expand Down
Loading

0 comments on commit b03129d

Please sign in to comment.