Skip to content

Commit

Permalink
Add Amazon Bedrock Mistral model support.
Browse files Browse the repository at this point in the history
  • Loading branch information
wmz7year committed Jun 2, 2024
1 parent fa53e2a commit 57518b7
Show file tree
Hide file tree
Showing 20 changed files with 1,506 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i
Spring AI supports many AI models. For an overview see here. Specific models currently supported are
* OpenAI
* Azure OpenAI
* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2)
* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2, Mistral)
* HuggingFace
* Google VertexAI (PaLM2, Gemini)
* Mistral AI
Expand Down
1 change: 1 addition & 0 deletions models/spring-ai-bedrock/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
- [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html)
- [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html)
- [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html)
- [Mistral Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-mistral.html)

6 changes: 6 additions & 0 deletions models/spring-ai-bedrock/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-retry</artifactId>
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.mistral;

import java.util.List;

import org.springframework.ai.bedrock.BedrockUsage;
import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi;
import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatRequest;
import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatResponse;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

import reactor.core.publisher.Flux;

/**
* @author Wei Jiang
* @since 1.0.0
*/
public class BedrockMistralChatModel implements ChatModel, StreamingChatModel {

	/**
	 * The low-level Bedrock API client used to invoke the Mistral model.
	 */
	private final MistralChatBedrockApi chatApi;

	/**
	 * Default options applied to every request; runtime {@link Prompt} options are
	 * merged on top of these in {@link #createRequest(Prompt)}.
	 */
	private final BedrockMistralChatOptions defaultOptions;

	/**
	 * The retry template used to retry the Bedrock API calls.
	 */
	private final RetryTemplate retryTemplate;

	public BedrockMistralChatModel(MistralChatBedrockApi chatApi) {
		this(chatApi, BedrockMistralChatOptions.builder().build());
	}

	public BedrockMistralChatModel(MistralChatBedrockApi chatApi, BedrockMistralChatOptions options) {
		this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
	}

	public BedrockMistralChatModel(MistralChatBedrockApi chatApi, BedrockMistralChatOptions options,
			RetryTemplate retryTemplate) {
		Assert.notNull(chatApi, "MistralChatBedrockApi must not be null");
		Assert.notNull(options, "BedrockMistralChatOptions must not be null");
		Assert.notNull(retryTemplate, "RetryTemplate must not be null");

		this.chatApi = chatApi;
		this.defaultOptions = options;
		this.retryTemplate = retryTemplate;
	}

	/**
	 * Performs a blocking chat-completion call, retried per the configured
	 * {@link RetryTemplate}. One {@link Generation} is produced per model output.
	 */
	@Override
	public ChatResponse call(Prompt prompt) {

		MistralChatRequest request = createRequest(prompt);

		return this.retryTemplate.execute(ctx -> {
			MistralChatResponse response = this.chatApi.chatCompletion(request);

			// NOTE(review): unlike stream(), the synchronous path does not attach
			// usage/stop-reason metadata to the generations — confirm whether the
			// non-streaming response carries invocation metrics worth surfacing.
			List<Generation> generations = response.outputs().stream().map(g -> {
				return new Generation(g.text());
			}).toList();

			return new ChatResponse(generations);
		});
	}

	/**
	 * Performs a streaming chat-completion call. Usage metrics, when present on a
	 * streamed chunk, are attached to that chunk's generations as metadata.
	 */
	@Override
	public Flux<ChatResponse> stream(Prompt prompt) {

		MistralChatRequest request = createRequest(prompt);

		return this.retryTemplate.execute(ctx -> {
			return this.chatApi.chatCompletionStream(request).map(g -> {
				List<Generation> generations = g.outputs().stream().map(output -> {
					Generation generation = new Generation(output.text());

					// Bedrock reports invocation metrics only on the final chunk.
					if (g.amazonBedrockInvocationMetrics() != null) {
						Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics());
						generation.withGenerationMetadata(ChatGenerationMetadata.from(output.stopReason(), usage));
					}

					return generation;
				}).toList();

				return new ChatResponse(generations);
			});
		});
	}

	/**
	 * Builds the Bedrock request from the prompt: the default options are applied
	 * first, then any runtime {@link ChatOptions} carried by the prompt are merged
	 * on top of them. Package-private for test access.
	 * @param prompt the prompt, whose instructions are flattened to a single string
	 * @return the request to send to the Bedrock Mistral model
	 * @throws IllegalArgumentException if the prompt options are not {@link ChatOptions}
	 */
	MistralChatRequest createRequest(Prompt prompt) {
		final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());

		var request = MistralChatRequest.builder(promptValue)
			.withTemperature(this.defaultOptions.getTemperature())
			.withTopP(this.defaultOptions.getTopP())
			.withTopK(this.defaultOptions.getTopK())
			.withMaxTokens(this.defaultOptions.getMaxTokens())
			.withStopSequences(this.defaultOptions.getStopSequences())
			.build();

		if (prompt.getOptions() != null) {
			if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
				BedrockMistralChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
						ChatOptions.class, BedrockMistralChatOptions.class);
				request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, MistralChatRequest.class);
			}
			else {
				throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
						+ prompt.getOptions().getClass().getSimpleName());
			}
		}

		return request;
	}

	@Override
	public ChatOptions getDefaultOptions() {
		return defaultOptions;
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.mistral;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonInclude.Include;

import org.springframework.ai.chat.prompt.ChatOptions;

/**
* @author Wei Jiang
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class BedrockMistralChatOptions implements ChatOptions {

	/**
	 * The temperature value controls the randomness of the generated text. Use a lower
	 * value to decrease randomness in the response.
	 */
	private @JsonProperty("temperature") Float temperature;

	/**
	 * (optional) The maximum cumulative probability of tokens to consider when sampling.
	 * The model uses combined Top-k and nucleus sampling. Nucleus sampling considers
	 * the smallest set of tokens whose probability sum is at least topP.
	 */
	private @JsonProperty("top_p") Float topP;

	/**
	 * (optional) Specify the number of token choices the model uses to generate the
	 * next token.
	 */
	// BUG FIX: was annotated @JsonProperty("top_p"), colliding with topP above —
	// Jackson rejects duplicate property definitions and the Bedrock Mistral request
	// schema names this field "top_k".
	private @JsonProperty("top_k") Integer topK;

	/**
	 * (optional) Specify the maximum number of tokens to use in the generated response.
	 */
	private @JsonProperty("max_tokens") Integer maxTokens;

	/**
	 * (optional) Configure up to four sequences that the model recognizes. After a
	 * stop sequence, the model stops generating further tokens. The returned text
	 * doesn't contain the stop sequence.
	 */
	private @JsonProperty("stop") List<String> stopSequences;

	public static Builder builder() {
		return new Builder();
	}

	/**
	 * Fluent builder for {@link BedrockMistralChatOptions}.
	 */
	public static class Builder {

		private final BedrockMistralChatOptions options = new BedrockMistralChatOptions();

		public Builder withTemperature(Float temperature) {
			this.options.setTemperature(temperature);
			return this;
		}

		public Builder withTopP(Float topP) {
			this.options.setTopP(topP);
			return this;
		}

		public Builder withTopK(Integer topK) {
			this.options.setTopK(topK);
			return this;
		}

		public Builder withMaxTokens(Integer maxTokens) {
			this.options.setMaxTokens(maxTokens);
			return this;
		}

		public Builder withStopSequences(List<String> stopSequences) {
			this.options.setStopSequences(stopSequences);
			return this;
		}

		public BedrockMistralChatOptions build() {
			return this.options;
		}

	}

	public void setTemperature(Float temperature) {
		this.temperature = temperature;
	}

	@Override
	public Float getTemperature() {
		return this.temperature;
	}

	public void setTopP(Float topP) {
		this.topP = topP;
	}

	@Override
	public Float getTopP() {
		return this.topP;
	}

	public void setTopK(Integer topK) {
		this.topK = topK;
	}

	@Override
	public Integer getTopK() {
		return this.topK;
	}

	public Integer getMaxTokens() {
		return maxTokens;
	}

	public void setMaxTokens(Integer maxTokens) {
		this.maxTokens = maxTokens;
	}

	public List<String> getStopSequences() {
		return stopSequences;
	}

	public void setStopSequences(List<String> stopSequences) {
		this.stopSequences = stopSequences;
	}

	/**
	 * Returns a new options instance copied field-by-field from the given one.
	 * @param fromOptions the options to copy; must not be null
	 * @return a fresh, independent copy
	 */
	public static BedrockMistralChatOptions fromOptions(BedrockMistralChatOptions fromOptions) {
		return builder().withTemperature(fromOptions.getTemperature())
			.withTopP(fromOptions.getTopP())
			.withTopK(fromOptions.getTopK())
			.withMaxTokens(fromOptions.getMaxTokens())
			.withStopSequences(fromOptions.getStopSequences())
			.build();
	}

}
Loading

0 comments on commit 57518b7

Please sign in to comment.