Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Amazon Bedrock Converse API to re-implement Bedrock AI Models. #813

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i
Spring AI supports many AI models. For an overview see here. Specific models currently supported are
* OpenAI
* Azure OpenAI
* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2)
* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2, Mistral)
* Hugging Face
* Google VertexAI (PaLM2, Gemini)
* Mistral AI
Expand Down
1 change: 1 addition & 0 deletions models/spring-ai-bedrock/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
- [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html)
- [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html)
- [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html)
- [Mistral Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-mistral.html)

6 changes: 6 additions & 0 deletions models/spring-ai-bedrock/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-retry</artifactId>
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock;

import java.util.HashMap;

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;

import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetrics;

/**
* {@link ChatResponseMetadata} implementation for {@literal Amazon Bedrock}.
*
* @author Wei Jiang
* @since 1.0.0
*/
public class BedrockChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {

	// Template used by toString(): type, request id, token usage, and latency in ms.
	protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, latency: %4$sms}";

	private final String id;

	private final Usage usage;

	private final Long latencyMs;

	/**
	 * Builds metadata from a synchronous {@literal Converse} response, carrying over
	 * the AWS request id, token usage, and reported latency.
	 * @param response the non-streaming Converse API response
	 * @return the populated metadata instance
	 */
	public static BedrockChatResponseMetadata from(ConverseResponse response) {
		ConverseMetrics responseMetrics = response.metrics();
		BedrockUsage tokenUsage = BedrockUsage.from(response.usage());
		String awsRequestId = response.responseMetadata().requestId();
		return new BedrockChatResponseMetadata(awsRequestId, tokenUsage, responseMetrics.latencyMs());
	}

	/**
	 * Builds metadata from a streaming {@literal Converse} metadata event. Streaming
	 * events carry no request id, so {@link #getId()} will return {@code null}.
	 * @param converseStreamMetadataEvent the final metadata event of a stream
	 * @return the populated metadata instance
	 */
	public static BedrockChatResponseMetadata from(ConverseStreamMetadataEvent converseStreamMetadataEvent) {
		ConverseStreamMetrics streamMetrics = converseStreamMetadataEvent.metrics();
		BedrockUsage tokenUsage = BedrockUsage.from(converseStreamMetadataEvent.usage());
		return new BedrockChatResponseMetadata(null, tokenUsage, streamMetrics.latencyMs());
	}

	protected BedrockChatResponseMetadata(String id, BedrockUsage usage, Long latencyMs) {
		this.id = id;
		this.usage = usage;
		this.latencyMs = latencyMs;
	}

	/** Returns the AWS request id, or {@code null} for streaming responses. */
	public String getId() {
		return this.id;
	}

	/** Returns the model invocation latency reported by Bedrock, in milliseconds. */
	public Long getLatencyMs() {
		return this.latencyMs;
	}

	/** Returns the token usage, falling back to an {@link EmptyUsage} when absent. */
	@Override
	public Usage getUsage() {
		if (this.usage == null) {
			return new EmptyUsage();
		}
		return this.usage;
	}

	@Override
	public String toString() {
		return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getUsage(), getLatencyMs());
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock;

import org.springframework.ai.chat.metadata.ChatGenerationMetadata;

import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;

/**
* Amazon Bedrock Chat model converse interface generation metadata, encapsulating
* information on the completion.
*
* @author Wei Jiang
* @since 1.0.0
*/
public class BedrockConverseChatGenerationMetadata implements ChatGenerationMetadata {

	// Bedrock's stop reason string; null for stream events other than MessageStopEvent.
	private String stopReason;

	// Present only when built from a synchronous Converse response.
	private Message message;

	// Present only when built from a streaming Converse event.
	private ConverseStreamOutput event;

	public BedrockConverseChatGenerationMetadata(String stopReason, ConverseStreamOutput event) {
		super();

		this.stopReason = stopReason;
		this.event = event;
	}

	public BedrockConverseChatGenerationMetadata(String stopReason, Message message) {
		super();

		this.stopReason = stopReason;
		this.message = message;
	}

	/**
	 * Creates metadata from a synchronous Converse response and its generated message.
	 * @param response the Converse API response supplying the stop reason
	 * @param message the generated assistant message
	 * @return the populated metadata instance
	 */
	public static BedrockConverseChatGenerationMetadata from(ConverseResponse response, Message message) {
		return new BedrockConverseChatGenerationMetadata(response.stopReasonAsString(), message);
	}

	/**
	 * Creates metadata from a streaming Converse output event. Only a
	 * {@link MessageStopEvent} carries a stop reason; all other event types yield
	 * a {@code null} finish reason.
	 * @param event the streaming output event
	 * @return the populated metadata instance
	 */
	public static BedrockConverseChatGenerationMetadata from(ConverseStreamOutput event) {
		String reason = (event instanceof MessageStopEvent stopEvent) ? stopEvent.stopReasonAsString() : null;
		return new BedrockConverseChatGenerationMetadata(reason, event);
	}

	/** Always {@code null}: the Converse API exposes no content-filter metadata here. */
	@Override
	public <T> T getContentFilterMetadata() {
		return null;
	}

	@Override
	public String getFinishReason() {
		return this.stopReason;
	}

	/** Returns the generated message, or {@code null} when built from a stream event. */
	public Message getMessage() {
		return this.message;
	}

	/** Returns the stream event, or {@code null} when built from a full response. */
	public ConverseStreamOutput getEvent() {
		return this.event;
	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,49 @@
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;

import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;

/**
* {@link Usage} implementation for Bedrock API.
*
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
public class BedrockUsage implements Usage {

public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) {
return new BedrockUsage(usage);
return new BedrockUsage(usage.inputTokenCount().longValue(), usage.outputTokenCount().longValue());
}

private final AmazonBedrockInvocationMetrics usage;
public static BedrockUsage from(TokenUsage usage) {
Assert.notNull(usage, "'TokenUsage' must not be null.");

protected BedrockUsage(AmazonBedrockInvocationMetrics usage) {
Assert.notNull(usage, "OpenAI Usage must not be null");
this.usage = usage;
return new BedrockUsage(usage.inputTokens().longValue(), usage.outputTokens().longValue());
}

protected AmazonBedrockInvocationMetrics getUsage() {
return this.usage;
private final Long inputTokens;

private final Long outputTokens;

protected BedrockUsage(Long inputTokens, Long outputTokens) {
this.inputTokens = inputTokens;
this.outputTokens = outputTokens;
}

@Override
public Long getPromptTokens() {
return getUsage().inputTokenCount().longValue();
return inputTokens;
}

@Override
public Long getGenerationTokens() {
return getUsage().outputTokenCount().longValue();
return outputTokens;
}

@Override
public String toString() {
return getUsage().toString();
return "BedrockUsage [inputTokens=" + inputTokens + ", outputTokens=" + outputTokens + "]";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
import com.fasterxml.jackson.annotation.JsonProperty;

/**
* Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options.
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
*
* @author Christian Tzolov
* @author Wei Jiang
*/
@JsonInclude(Include.NON_NULL)
public class AnthropicChatOptions implements ChatOptions {
Expand All @@ -44,7 +48,7 @@ public class AnthropicChatOptions implements ChatOptions {
* reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate. We
* recommend a limit of 4,000 tokens for optimal performance.
*/
private @JsonProperty("max_tokens_to_sample") Integer maxTokensToSample;
private @JsonProperty("max_tokens") Integer maxTokens;

/**
* Specify the number of token choices the generative uses to generate the next token.
Expand All @@ -62,11 +66,6 @@ public class AnthropicChatOptions implements ChatOptions {
* generating further tokens. The returned text doesn't contain the stop sequence.
*/
private @JsonProperty("stop_sequences") List<String> stopSequences;

/**
* The version of the generative to use. The default value is bedrock-2023-05-31.
*/
private @JsonProperty("anthropic_version") String anthropicVersion;
// @formatter:on

public static Builder builder() {
Expand All @@ -82,8 +81,8 @@ public Builder withTemperature(Float temperature) {
return this;
}

public Builder withMaxTokensToSample(Integer maxTokensToSample) {
this.options.setMaxTokensToSample(maxTokensToSample);
public Builder withMaxTokens(Integer maxTokens) {
this.options.setMaxTokens(maxTokens);
return this;
}

Expand All @@ -102,11 +101,6 @@ public Builder withStopSequences(List<String> stopSequences) {
return this;
}

public Builder withAnthropicVersion(String anthropicVersion) {
this.options.setAnthropicVersion(anthropicVersion);
return this;
}

public AnthropicChatOptions build() {
return this.options;
}
Expand All @@ -122,12 +116,12 @@ public void setTemperature(Float temperature) {
this.temperature = temperature;
}

public Integer getMaxTokensToSample() {
return this.maxTokensToSample;
public Integer getMaxTokens() {
return maxTokens;
}

public void setMaxTokensToSample(Integer maxTokensToSample) {
this.maxTokensToSample = maxTokensToSample;
public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}

@Override
Expand Down Expand Up @@ -156,21 +150,12 @@ public void setStopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
}

public String getAnthropicVersion() {
return this.anthropicVersion;
}

public void setAnthropicVersion(String anthropicVersion) {
this.anthropicVersion = anthropicVersion;
}

public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withMaxTokensToSample(fromOptions.getMaxTokensToSample())
.withMaxTokens(fromOptions.getMaxTokens())
.withTopK(fromOptions.getTopK())
.withTopP(fromOptions.getTopP())
.withStopSequences(fromOptions.getStopSequences())
.withAnthropicVersion(fromOptions.getAnthropicVersion())
.build();
}

Expand Down
Loading