Skip to content

Commit

Permalink
Re-implementing Amazon Bedrock Chat model with Amazon Bedrock Converse API.
Browse files Browse the repository at this point in the history
  • Loading branch information
wmz7year committed Jun 4, 2024
1 parent e4ee01a commit a96e2b9
Show file tree
Hide file tree
Showing 44 changed files with 1,231 additions and 980 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock;

import java.util.HashMap;

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;

import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetrics;

/**
 * {@link ChatResponseMetadata} implementation for {@literal Amazon Bedrock}.
 *
 * <p>
 * Holds the request id, token {@link Usage} and latency taken from a Bedrock Converse
 * response or stream metadata event. Each value may be {@code null} when the source
 * object does not provide it; {@link #getUsage()} substitutes an {@link EmptyUsage}
 * in that case.
 *
 * @author Wei Jiang
 * @since 1.0.0
 */
public class BedrockChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {

    protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, latency: %4$sms}";

    private final String id;

    private final Usage usage;

    private final Long latencyMs;

    /**
     * Builds the metadata from a (non-streaming) Converse response.
     * <p>
     * Missing response metadata, usage or metrics are tolerated and mapped to
     * {@code null} instead of failing, so that a partially populated response still
     * yields usable metadata.
     * @param response the Converse response to read from, must not be {@code null}
     * @return the metadata extracted from {@code response}
     */
    public static BedrockChatResponseMetadata from(ConverseResponse response) {
        var responseMetadata = response.responseMetadata();
        String requestId = (responseMetadata != null) ? responseMetadata.requestId() : null;

        // BedrockUsage.from rejects null, so only convert when usage is present.
        var tokenUsage = response.usage();
        BedrockUsage usage = (tokenUsage != null) ? BedrockUsage.from(tokenUsage) : null;

        ConverseMetrics metrics = response.metrics();
        Long latencyMs = (metrics != null) ? metrics.latencyMs() : null;

        return new BedrockChatResponseMetadata(requestId, usage, latencyMs);
    }

    /**
     * Builds the metadata from a Converse stream metadata event.
     * <p>
     * No request id is read from the event, so the resulting id is always
     * {@code null}. Missing usage or metrics are mapped to {@code null}.
     * @param converseStreamMetadataEvent the stream metadata event to read from, must
     * not be {@code null}
     * @return the metadata extracted from the event
     */
    public static BedrockChatResponseMetadata from(ConverseStreamMetadataEvent converseStreamMetadataEvent) {
        // BedrockUsage.from rejects null, so only convert when usage is present.
        var tokenUsage = converseStreamMetadataEvent.usage();
        BedrockUsage usage = (tokenUsage != null) ? BedrockUsage.from(tokenUsage) : null;

        ConverseStreamMetrics metrics = converseStreamMetadataEvent.metrics();
        Long latencyMs = (metrics != null) ? metrics.latencyMs() : null;

        return new BedrockChatResponseMetadata(null, usage, latencyMs);
    }

    /**
     * Creates the metadata from already extracted values.
     * @param id the request id, may be {@code null}
     * @param usage the token usage, may be {@code null}
     * @param latencyMs the latency in milliseconds, may be {@code null}
     */
    protected BedrockChatResponseMetadata(String id, BedrockUsage usage, Long latencyMs) {
        this.id = id;
        this.usage = usage;
        this.latencyMs = latencyMs;
    }

    /**
     * Returns the request id.
     * @return the request id, or {@code null} if none was available
     */
    public String getId() {
        return this.id;
    }

    /**
     * Returns the latency in milliseconds.
     * @return the latency, or {@code null} if no metrics were available
     */
    public Long getLatencyMs() {
        return this.latencyMs;
    }

    /**
     * Returns the token usage.
     * @return the stored usage, or a new {@link EmptyUsage} when none was stored
     */
    @Override
    public Usage getUsage() {
        return (this.usage != null) ? this.usage : new EmptyUsage();
    }

    @Override
    public String toString() {
        return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getUsage(), getLatencyMs());
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,49 @@
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;

import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;

/**
* {@link Usage} implementation for Bedrock API.
*
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
// NOTE(review): this span appears to be a diff hunk rendered without +/- markers, so
// removed and added lines are interleaved (e.g. consecutive return statements).
// Confirm against the actual file before relying on any single line.
public class BedrockUsage implements Usage {

/**
* Creates a {@link BedrockUsage} from the given invocation metrics.
* <p>
* Two bodies are shown: one wraps the metrics object itself, the other copies
* {@code inputTokenCount()} and {@code outputTokenCount()} as long values.
*/
public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) {
return new BedrockUsage(usage);
return new BedrockUsage(usage.inputTokenCount().longValue(), usage.outputTokenCount().longValue());
}

// Wrapped invocation metrics; presumably superseded by the inputTokens/outputTokens
// fields declared further down - TODO confirm.
private final AmazonBedrockInvocationMetrics usage;
/**
* Creates a {@link BedrockUsage} from the given {@code TokenUsage}, after asserting
* that it is not null, by copying {@code inputTokens()} and {@code outputTokens()} as
* long values.
*/
public static BedrockUsage from(TokenUsage usage) {
Assert.notNull(usage, "'TokenUsage' must not be null.");

// NOTE(review): the constructor header and the two statements after it look like
// removed lines interleaved with the factory body above - confirm.
protected BedrockUsage(AmazonBedrockInvocationMetrics usage) {
Assert.notNull(usage, "OpenAI Usage must not be null");
this.usage = usage;
return new BedrockUsage(usage.inputTokens().longValue(), usage.outputTokens().longValue());
}

// Accessor for the wrapped metrics object; used by the getUsage()-based return
// statements below.
protected AmazonBedrockInvocationMetrics getUsage() {
return this.usage;
// Token counts stored directly on this instance.
private final Long inputTokens;

private final Long outputTokens;

/**
* Creates a usage holding the given token counts as-is (no null check is performed
* here).
*/
protected BedrockUsage(Long inputTokens, Long outputTokens) {
this.inputTokens = inputTokens;
this.outputTokens = outputTokens;
}

/**
* Returns the prompt token count: either read from the wrapped metrics'
* {@code inputTokenCount()} or taken from the {@code inputTokens} field.
*/
@Override
public Long getPromptTokens() {
return getUsage().inputTokenCount().longValue();
return inputTokens;
}

/**
* Returns the generation token count: either read from the wrapped metrics'
* {@code outputTokenCount()} or taken from the {@code outputTokens} field.
*/
@Override
public Long getGenerationTokens() {
return getUsage().outputTokenCount().longValue();
return outputTokens;
}

/**
* Returns a textual form: either the wrapped metrics' own {@code toString()} or a
* string listing {@code inputTokens} and {@code outputTokens}.
*/
@Override
public String toString() {
return getUsage().toString();
return "BedrockUsage [inputTokens=" + inputTokens + ", outputTokens=" + outputTokens + "]";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
import org.springframework.util.Assert;

/**
* @deprecated Use {@link BedrockConverseApi} instead.
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
// @formatter:off
@Deprecated
public class AnthropicChatBedrockApi extends
AbstractBedrockApi<AnthropicChatRequest, AnthropicChatResponse, AnthropicChatResponse> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ public class Anthropic3ChatOptions implements ChatOptions {
*/
private @JsonProperty("stop_sequences") List<String> stopSequences;

/**
* The version of the generative to use. The default value is bedrock-2023-05-31.
*/
private @JsonProperty("anthropic_version") String anthropicVersion;
// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -101,11 +97,6 @@ public Builder withStopSequences(List<String> stopSequences) {
return this;
}

public Builder withAnthropicVersion(String anthropicVersion) {
this.options.setAnthropicVersion(anthropicVersion);
return this;
}

public Anthropic3ChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -155,21 +146,12 @@ public void setStopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
}

public String getAnthropicVersion() {
return this.anthropicVersion;
}

public void setAnthropicVersion(String anthropicVersion) {
this.anthropicVersion = anthropicVersion;
}

public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withMaxTokens(fromOptions.getMaxTokens())
.withTopK(fromOptions.getTopK())
.withTopP(fromOptions.getTopP())
.withStopSequences(fromOptions.getStopSequences())
.withAnthropicVersion(fromOptions.getAnthropicVersion())
.build();
}

Expand Down
Loading

0 comments on commit a96e2b9

Please sign in to comment.