-
Notifications
You must be signed in to change notification settings - Fork 843
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Bedrock Cohere Command R model support.
- Loading branch information
wmz7year
committed
May 2, 2024
1 parent
7252ba1
commit 9a0feaa
Showing
15 changed files
with
1,962 additions
and
0 deletions.
There are no files selected for viewing
117 changes: 117 additions & 0 deletions
117
.../src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
/* | ||
* Copyright 2023 - 2024 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.springframework.ai.bedrock.cohere; | ||
|
||
import java.util.List; | ||
|
||
import org.springframework.ai.bedrock.BedrockUsage; | ||
import org.springframework.ai.bedrock.MessageToPromptConverter; | ||
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi; | ||
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest; | ||
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatResponse; | ||
import org.springframework.ai.chat.ChatClient; | ||
import org.springframework.ai.chat.ChatResponse; | ||
import org.springframework.ai.chat.Generation; | ||
import org.springframework.ai.chat.StreamingChatClient; | ||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata; | ||
import org.springframework.ai.chat.metadata.Usage; | ||
import org.springframework.ai.chat.prompt.ChatOptions; | ||
import org.springframework.ai.chat.prompt.Prompt; | ||
import org.springframework.ai.model.ModelOptionsUtils; | ||
import org.springframework.util.Assert; | ||
|
||
import reactor.core.publisher.Flux; | ||
|
||
/** | ||
* @author Wei Jiang | ||
* @since 1.0.0 | ||
*/ | ||
public class BedrockCohereCommandRChatClient implements ChatClient, StreamingChatClient { | ||
|
||
private final CohereCommandRChatBedrockApi chatApi; | ||
|
||
private final BedrockCohereCommandRChatOptions defaultOptions; | ||
|
||
public BedrockCohereCommandRChatClient(CohereCommandRChatBedrockApi chatApi) { | ||
this(chatApi, BedrockCohereCommandRChatOptions.builder().build()); | ||
} | ||
|
||
public BedrockCohereCommandRChatClient(CohereCommandRChatBedrockApi chatApi, | ||
BedrockCohereCommandRChatOptions options) { | ||
Assert.notNull(chatApi, "CohereCommandRChatBedrockApi must not be null"); | ||
Assert.notNull(options, "BedrockCohereCommandRChatOptions must not be null"); | ||
|
||
this.chatApi = chatApi; | ||
this.defaultOptions = options; | ||
} | ||
|
||
@Override | ||
public ChatResponse call(Prompt prompt) { | ||
CohereCommandRChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt)); | ||
|
||
Generation generation = new Generation(response.text()); | ||
|
||
return new ChatResponse(List.of(generation)); | ||
} | ||
|
||
@Override | ||
public Flux<ChatResponse> stream(Prompt prompt) { | ||
return this.chatApi.chatCompletionStream(this.createRequest(prompt)).map(g -> { | ||
if (g.isFinished()) { | ||
String finishReason = g.finishReason().name(); | ||
Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics()); | ||
return new ChatResponse(List | ||
.of(new Generation("").withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage)))); | ||
} | ||
return new ChatResponse(List.of(new Generation(g.text()))); | ||
}); | ||
} | ||
|
||
CohereCommandRChatRequest createRequest(Prompt prompt) { | ||
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); | ||
|
||
var request = CohereCommandRChatRequest.builder(promptValue) | ||
.withSearchQueriesOnly(this.defaultOptions.getSearchQueriesOnly()) | ||
.withPreamble(this.defaultOptions.getPreamble()) | ||
.withMaxTokens(this.defaultOptions.getMaxTokens()) | ||
.withTemperature(this.defaultOptions.getTemperature()) | ||
.withTopP(this.defaultOptions.getTopP()) | ||
.withTopK(this.defaultOptions.getTopK()) | ||
.withPromptTruncation(this.defaultOptions.getPromptTruncation()) | ||
.withFrequencyPenalty(this.defaultOptions.getFrequencyPenalty()) | ||
.withPresencePenalty(this.defaultOptions.getPresencePenalty()) | ||
.withSeed(this.defaultOptions.getSeed()) | ||
.withReturnPrompt(this.defaultOptions.getReturnPrompt()) | ||
.withStopSequences(this.defaultOptions.getStopSequences()) | ||
.withRawPrompting(this.defaultOptions.getRawPrompting()) | ||
.build(); | ||
|
||
if (prompt.getOptions() != null) { | ||
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { | ||
BedrockCohereCommandRChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, | ||
ChatOptions.class, BedrockCohereCommandRChatOptions.class); | ||
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, CohereCommandRChatRequest.class); | ||
} | ||
else { | ||
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " | ||
+ prompt.getOptions().getClass().getSimpleName()); | ||
} | ||
} | ||
|
||
return request; | ||
} | ||
|
||
} |
276 changes: 276 additions & 0 deletions
276
...src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,276 @@ | ||
/* | ||
* Copyright 2023 - 2024 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.springframework.ai.bedrock.cohere; | ||
|
||
import java.util.List; | ||
|
||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import com.fasterxml.jackson.annotation.JsonInclude.Include; | ||
|
||
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest.PromptTruncation; | ||
import org.springframework.ai.chat.prompt.ChatOptions; | ||
|
||
/** | ||
* @author Wei Jiang | ||
* @since 1.0.0 | ||
*/ | ||
@JsonInclude(Include.NON_NULL) | ||
public class BedrockCohereCommandRChatOptions implements ChatOptions { | ||
|
||
// @formatter:off | ||
/** | ||
* (optional) When enabled, it will only generate potential search queries without performing | ||
* searches or providing a response. | ||
*/ | ||
@JsonProperty("search_queries_only") Boolean searchQueriesOnly; | ||
/** | ||
* (optional) Overrides the default preamble for search query generation. | ||
*/ | ||
@JsonProperty("preamble") String preamble; | ||
/** | ||
* (optional) Specify the maximum number of tokens to use in the generated response. | ||
*/ | ||
@JsonProperty("max_tokens") Integer maxTokens; | ||
/** | ||
* (optional) Use a lower value to decrease randomness in the response. | ||
*/ | ||
@JsonProperty("temperature") Float temperature; | ||
/** | ||
* Top P. Use a lower value to ignore less probable options. Set to 0 or 1.0 to disable. | ||
*/ | ||
@JsonProperty("p") Float topP; | ||
/** | ||
* Top K. Specify the number of token choices the model uses to generate the next token. | ||
*/ | ||
@JsonProperty("k") Integer topK; | ||
/** | ||
* (optional) Dictates how the prompt is constructed. | ||
*/ | ||
@JsonProperty("prompt_truncation") PromptTruncation promptTruncation; | ||
/** | ||
* (optional) Used to reduce repetitiveness of generated tokens. | ||
*/ | ||
@JsonProperty("frequency_penalty") Float frequencyPenalty; | ||
/** | ||
* (optional) Used to reduce repetitiveness of generated tokens. | ||
*/ | ||
@JsonProperty("presence_penalty") Float presencePenalty; | ||
/** | ||
* (optional) Specify the best effort to sample tokens deterministically. | ||
*/ | ||
@JsonProperty("seed") Integer seed; | ||
/** | ||
* (optional) Specify true to return the full prompt that was sent to the model. | ||
*/ | ||
@JsonProperty("return_prompt") Boolean returnPrompt; | ||
/** | ||
* (optional) A list of stop sequences. | ||
*/ | ||
@JsonProperty("stop_sequences") List<String> stopSequences; | ||
/** | ||
* (optional) Specify true, to send the user’s message to the model without any preprocessing. | ||
*/ | ||
@JsonProperty("raw_prompting") Boolean rawPrompting; | ||
// @formatter:on | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public static class Builder { | ||
|
||
private final BedrockCohereCommandRChatOptions options = new BedrockCohereCommandRChatOptions(); | ||
|
||
public Builder withSearchQueriesOnly(Boolean searchQueriesOnly) { | ||
options.setSearchQueriesOnly(searchQueriesOnly); | ||
return this; | ||
} | ||
|
||
public Builder withPreamble(String preamble) { | ||
options.setPreamble(preamble); | ||
return this; | ||
} | ||
|
||
public Builder withMaxTokens(Integer maxTokens) { | ||
options.setMaxTokens(maxTokens); | ||
return this; | ||
} | ||
|
||
public Builder withTemperature(Float temperature) { | ||
options.setTemperature(temperature); | ||
return this; | ||
} | ||
|
||
public Builder withTopP(Float topP) { | ||
options.setTopP(topP); | ||
return this; | ||
} | ||
|
||
public Builder withTopK(Integer topK) { | ||
options.setTopK(topK); | ||
return this; | ||
} | ||
|
||
public Builder withPromptTruncation(PromptTruncation promptTruncation) { | ||
options.setPromptTruncation(promptTruncation); | ||
return this; | ||
} | ||
|
||
public Builder withFrequencyPenalty(Float frequencyPenalty) { | ||
options.setFrequencyPenalty(frequencyPenalty); | ||
return this; | ||
} | ||
|
||
public Builder withPresencePenalty(Float presencePenalty) { | ||
options.setPresencePenalty(presencePenalty); | ||
return this; | ||
} | ||
|
||
public Builder withSeed(Integer seed) { | ||
options.setSeed(seed); | ||
return this; | ||
} | ||
|
||
public Builder withReturnPrompt(Boolean returnPrompt) { | ||
options.setReturnPrompt(returnPrompt); | ||
return this; | ||
} | ||
|
||
public Builder withStopSequences(List<String> stopSequences) { | ||
options.setStopSequences(stopSequences); | ||
return this; | ||
} | ||
|
||
public Builder withRawPrompting(Boolean rawPrompting) { | ||
options.setRawPrompting(rawPrompting); | ||
return this; | ||
} | ||
|
||
public BedrockCohereCommandRChatOptions build() { | ||
return this.options; | ||
} | ||
|
||
} | ||
|
||
public Boolean getSearchQueriesOnly() { | ||
return searchQueriesOnly; | ||
} | ||
|
||
public void setSearchQueriesOnly(Boolean searchQueriesOnly) { | ||
this.searchQueriesOnly = searchQueriesOnly; | ||
} | ||
|
||
public String getPreamble() { | ||
return preamble; | ||
} | ||
|
||
public void setPreamble(String preamble) { | ||
this.preamble = preamble; | ||
} | ||
|
||
public Integer getMaxTokens() { | ||
return maxTokens; | ||
} | ||
|
||
public void setMaxTokens(Integer maxTokens) { | ||
this.maxTokens = maxTokens; | ||
} | ||
|
||
@Override | ||
public Float getTemperature() { | ||
return temperature; | ||
} | ||
|
||
public void setTemperature(Float temperature) { | ||
this.temperature = temperature; | ||
} | ||
|
||
@Override | ||
public Float getTopP() { | ||
return topP; | ||
} | ||
|
||
public void setTopP(Float topP) { | ||
this.topP = topP; | ||
} | ||
|
||
@Override | ||
public Integer getTopK() { | ||
return topK; | ||
} | ||
|
||
public void setTopK(Integer topK) { | ||
this.topK = topK; | ||
} | ||
|
||
public PromptTruncation getPromptTruncation() { | ||
return promptTruncation; | ||
} | ||
|
||
public void setPromptTruncation(PromptTruncation promptTruncation) { | ||
this.promptTruncation = promptTruncation; | ||
} | ||
|
||
public Float getFrequencyPenalty() { | ||
return frequencyPenalty; | ||
} | ||
|
||
public void setFrequencyPenalty(Float frequencyPenalty) { | ||
this.frequencyPenalty = frequencyPenalty; | ||
} | ||
|
||
public Float getPresencePenalty() { | ||
return presencePenalty; | ||
} | ||
|
||
public void setPresencePenalty(Float presencePenalty) { | ||
this.presencePenalty = presencePenalty; | ||
} | ||
|
||
public Integer getSeed() { | ||
return seed; | ||
} | ||
|
||
public void setSeed(Integer seed) { | ||
this.seed = seed; | ||
} | ||
|
||
public Boolean getReturnPrompt() { | ||
return returnPrompt; | ||
} | ||
|
||
public void setReturnPrompt(Boolean returnPrompt) { | ||
this.returnPrompt = returnPrompt; | ||
} | ||
|
||
public List<String> getStopSequences() { | ||
return stopSequences; | ||
} | ||
|
||
public void setStopSequences(List<String> stopSequences) { | ||
this.stopSequences = stopSequences; | ||
} | ||
|
||
public Boolean getRawPrompting() { | ||
return rawPrompting; | ||
} | ||
|
||
public void setRawPrompting(Boolean rawPrompting) { | ||
this.rawPrompting = rawPrompting; | ||
} | ||
|
||
} |
Oops, something went wrong.