
Commit 57518b7

Author: wmz7year, committed Jun 2, 2024
Add Amazon Bedrock Mistral model support.
1 parent fa53e2a commit 57518b7

File tree

20 files changed: +1506 -2 lines

‎README.md

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i
 Spring AI supports many AI models. For an overview see here. Specific models currently supported are
 * OpenAI
 * Azure OpenAI
-* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2)
+* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2, Mistral)
 * HuggingFace
 * Google VertexAI (PaLM2, Gemini)
 * Mistral AI

‎models/spring-ai-bedrock/README.md

Lines changed: 1 addition & 0 deletions
@@ -8,4 +8,5 @@
 - [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html)
 - [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html)
 - [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html)
+- [Mistral Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-mistral.html)
 

‎models/spring-ai-bedrock/pom.xml

Lines changed: 6 additions & 0 deletions
@@ -29,6 +29,12 @@
 			<version>${project.parent.version}</version>
 		</dependency>
 
+		<dependency>
+			<groupId>org.springframework.ai</groupId>
+			<artifactId>spring-ai-retry</artifactId>
+			<version>${project.parent.version}</version>
+		</dependency>
+
 		<dependency>
 			<groupId>org.springframework</groupId>
 			<artifactId>spring-web</artifactId>
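
The spring-ai-retry dependency added above backs the RetryTemplate that the new BedrockMistralChatModel (added later in this commit) wraps around its Bedrock calls. The snippet below is a hypothetical usage sketch, not part of the commit: it assumes an already configured MistralChatBedrockApi instance named api, omits imports and the surrounding class, and uses the standard Spring Retry builder API to supply a custom template through the three-argument constructor.

// Hypothetical sketch (not part of this commit): custom retry policy for Bedrock calls.
// 'api' is an assumed, pre-configured MistralChatBedrockApi instance.
RetryTemplate retryTemplate = RetryTemplate.builder()
    .maxAttempts(5)                       // retry each Bedrock call up to five times
    .exponentialBackoff(1000, 2.0, 30000) // 1s initial back-off, doubling, capped at 30s
    .build();

var chatModel = new BedrockMistralChatModel(api,
    BedrockMistralChatOptions.builder().build(), retryTemplate);

The one- and two-argument constructors shown below fall back to RetryUtils.DEFAULT_RETRY_TEMPLATE instead.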
‎models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java

Lines changed: 146 additions & 0 deletions

@@ -0,0 +1,146 @@
/*
 * Copyright 2023 - 2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.ai.bedrock.mistral;

import java.util.List;

import org.springframework.ai.bedrock.BedrockUsage;
import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi;
import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatRequest;
import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatResponse;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

import reactor.core.publisher.Flux;

/**
 * @author Wei Jiang
 * @since 1.0.0
 */
public class BedrockMistralChatModel implements ChatModel, StreamingChatModel {

	private final MistralChatBedrockApi chatApi;

	private final BedrockMistralChatOptions defaultOptions;

	/**
	 * The retry template used to retry the Bedrock API calls.
	 */
	private final RetryTemplate retryTemplate;

	public BedrockMistralChatModel(MistralChatBedrockApi chatApi) {
		this(chatApi, BedrockMistralChatOptions.builder().build());
	}

	public BedrockMistralChatModel(MistralChatBedrockApi chatApi, BedrockMistralChatOptions options) {
		this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
	}

	public BedrockMistralChatModel(MistralChatBedrockApi chatApi, BedrockMistralChatOptions options,
			RetryTemplate retryTemplate) {
		Assert.notNull(chatApi, "MistralChatBedrockApi must not be null");
		Assert.notNull(options, "BedrockMistralChatOptions must not be null");
		Assert.notNull(retryTemplate, "RetryTemplate must not be null");

		this.chatApi = chatApi;
		this.defaultOptions = options;
		this.retryTemplate = retryTemplate;
	}

	@Override
	public ChatResponse call(Prompt prompt) {

		MistralChatRequest request = createRequest(prompt);

		return this.retryTemplate.execute(ctx -> {
			MistralChatResponse response = this.chatApi.chatCompletion(request);

			List<Generation> generations = response.outputs().stream().map(g -> {
				return new Generation(g.text());
			}).toList();

			return new ChatResponse(generations);
		});
	}

	public Flux<ChatResponse> stream(Prompt prompt) {

		MistralChatRequest request = createRequest(prompt);

		return this.retryTemplate.execute(ctx -> {
			return this.chatApi.chatCompletionStream(request).map(g -> {
				List<Generation> generations = g.outputs().stream().map(output -> {
					Generation generation = new Generation(output.text());

					if (g.amazonBedrockInvocationMetrics() != null) {
						Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics());
						generation.withGenerationMetadata(ChatGenerationMetadata.from(output.stopReason(), usage));
					}

					return generation;
				}).toList();

				return new ChatResponse(generations);
			});
		});
	}

	/**
	 * Test access.
	 */
	MistralChatRequest createRequest(Prompt prompt) {
		final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());

		var request = MistralChatRequest.builder(promptValue)
			.withTemperature(this.defaultOptions.getTemperature())
			.withTopP(this.defaultOptions.getTopP())
			.withTopK(this.defaultOptions.getTopK())
			.withMaxTokens(this.defaultOptions.getMaxTokens())
			.withStopSequences(this.defaultOptions.getStopSequences())
			.build();

		if (prompt.getOptions() != null) {
			if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
				BedrockMistralChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
						ChatOptions.class, BedrockMistralChatOptions.class);
				request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, MistralChatRequest.class);
			}
			else {
				throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
						+ prompt.getOptions().getClass().getSimpleName());
			}
		}

		return request;
	}

	@Override
	public ChatOptions getDefaultOptions() {
		return defaultOptions;
	}

}
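
A short usage sketch for the model above (hypothetical, not part of the commit): it again assumes a configured MistralChatBedrockApi named api, omits imports and the surrounding class, and relies on the Spring AI accessor chain ChatResponse.getResult().getOutput().getContent() for the generated text.

// Hypothetical sketch: blocking and streaming calls against the new chat model.
// 'api' is an assumed, pre-configured MistralChatBedrockApi instance.
var chatModel = new BedrockMistralChatModel(api,
    BedrockMistralChatOptions.builder()
        .withTemperature(0.6f)
        .withMaxTokens(512)
        .build());

// Blocking call
ChatResponse response = chatModel.call(new Prompt("Tell me a joke about llamas."));
System.out.println(response.getResult().getOutput().getContent());

// Streaming call: print chunks as they arrive
chatModel.stream(new Prompt("Write a haiku about Bedrock."))
    .mapNotNull(r -> r.getResult().getOutput().getContent())
    .doOnNext(System.out::print)
    .blockLast();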
‎models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java

Lines changed: 155 additions & 0 deletions

@@ -0,0 +1,155 @@
/*
 * Copyright 2023 - 2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.ai.bedrock.mistral;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonInclude.Include;

import org.springframework.ai.chat.prompt.ChatOptions;

/**
 * @author Wei Jiang
 * @since 1.0.0
 */
@JsonInclude(Include.NON_NULL)
public class BedrockMistralChatOptions implements ChatOptions {

	/**
	 * The temperature value controls the randomness of the generated text. Use a lower
	 * value to decrease randomness in the response.
	 */
	private @JsonProperty("temperature") Float temperature;

	/**
	 * (optional) The maximum cumulative probability of tokens to consider when sampling.
	 * The model uses combined Top-k and nucleus sampling. Nucleus sampling considers
	 * the smallest set of tokens whose probability sum is at least topP.
	 */
	private @JsonProperty("top_p") Float topP;

	/**
	 * (optional) Specify the number of token choices the model uses to generate the
	 * next token.
	 */
	private @JsonProperty("top_k") Integer topK;

	/**
	 * (optional) Specify the maximum number of tokens to use in the generated response.
	 */
	private @JsonProperty("max_tokens") Integer maxTokens;

	/**
	 * (optional) Configure up to four sequences that the model recognizes. After a
	 * stop sequence, the model stops generating further tokens. The returned text
	 * doesn't contain the stop sequence.
	 */
	private @JsonProperty("stop") List<String> stopSequences;

	public static Builder builder() {
		return new Builder();
	}

	public static class Builder {

		private final BedrockMistralChatOptions options = new BedrockMistralChatOptions();

		public Builder withTemperature(Float temperature) {
			this.options.setTemperature(temperature);
			return this;
		}

		public Builder withTopP(Float topP) {
			this.options.setTopP(topP);
			return this;
		}

		public Builder withTopK(Integer topK) {
			this.options.setTopK(topK);
			return this;
		}

		public Builder withMaxTokens(Integer maxTokens) {
			this.options.setMaxTokens(maxTokens);
			return this;
		}

		public Builder withStopSequences(List<String> stopSequences) {
			this.options.setStopSequences(stopSequences);
			return this;
		}

		public BedrockMistralChatOptions build() {
			return this.options;
		}

	}

	public void setTemperature(Float temperature) {
		this.temperature = temperature;
	}

	@Override
	public Float getTemperature() {
		return this.temperature;
	}

	public void setTopP(Float topP) {
		this.topP = topP;
	}

	@Override
	public Float getTopP() {
		return this.topP;
	}

	public void setTopK(Integer topK) {
		this.topK = topK;
	}

	@Override
	public Integer getTopK() {
		return this.topK;
	}

	public Integer getMaxTokens() {
		return maxTokens;
	}

	public void setMaxTokens(Integer maxTokens) {
		this.maxTokens = maxTokens;
	}

	public List<String> getStopSequences() {
		return stopSequences;
	}

	public void setStopSequences(List<String> stopSequences) {
		this.stopSequences = stopSequences;
	}

	public static BedrockMistralChatOptions fromOptions(BedrockMistralChatOptions fromOptions) {
		return builder().withTemperature(fromOptions.getTemperature())
			.withTopP(fromOptions.getTopP())
			.withTopK(fromOptions.getTopK())
			.withMaxTokens(fromOptions.getMaxTokens())
			.withStopSequences(fromOptions.getStopSequences())
			.build();
	}

}
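
Because createRequest() in BedrockMistralChatModel merges per-prompt ChatOptions over these defaults, individual settings can be overridden for a single request. A minimal, hypothetical sketch (not part of the commit; it assumes the Prompt(String, ChatOptions) constructor and the chatModel instance from the earlier sketch):

// Hypothetical sketch: per-request options override the defaults configured on the model.
var runtimeOptions = BedrockMistralChatOptions.builder()
    .withTemperature(0.2f) // more deterministic output for this one request
    .withTopK(50)
    .build();

Prompt prompt = new Prompt("Summarize the Bedrock Mistral support added in this commit.", runtimeOptions);
ChatResponse response = chatModel.call(prompt);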
