diff --git a/models/spring-ai-bedrock/pom.xml b/models/spring-ai-bedrock/pom.xml index e3b79d30bd..ae7913822a 100644 --- a/models/spring-ai-bedrock/pom.xml +++ b/models/spring-ai-bedrock/pom.xml @@ -29,6 +29,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + org.springframework spring-web diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java new file mode 100644 index 0000000000..fe341015e5 --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java @@ -0,0 +1,144 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.cohere; + +import java.util.List; + +import org.springframework.ai.bedrock.BedrockUsage; +import org.springframework.ai.bedrock.MessageToPromptConverter; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import reactor.core.publisher.Flux; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockCohereCommandRChatModel implements ChatModel, StreamingChatModel { + + private final CohereCommandRChatBedrockApi chatApi; + + private final BedrockCohereCommandRChatOptions defaultOptions; + + /** + * The retry template used to retry the Bedrock API calls. 
+ */ + private final RetryTemplate retryTemplate; + + public BedrockCohereCommandRChatModel(CohereCommandRChatBedrockApi chatApi) { + this(chatApi, BedrockCohereCommandRChatOptions.builder().build()); + } + + public BedrockCohereCommandRChatModel(CohereCommandRChatBedrockApi chatApi, + BedrockCohereCommandRChatOptions options) { + this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockCohereCommandRChatModel(CohereCommandRChatBedrockApi chatApi, + BedrockCohereCommandRChatOptions options, RetryTemplate retryTemplate) { + Assert.notNull(chatApi, "CohereCommandRChatBedrockApi must not be null"); + Assert.notNull(options, "BedrockCohereCommandRChatOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + + this.chatApi = chatApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public ChatResponse call(Prompt prompt) { + CohereCommandRChatRequest request = this.createRequest(prompt); + + return this.retryTemplate.execute(ctx -> { + CohereCommandRChatResponse response = this.chatApi.chatCompletion(request); + + Generation generation = new Generation(response.text()); + + return new ChatResponse(List.of(generation)); + }); + } + + @Override + public Flux stream(Prompt prompt) { + CohereCommandRChatRequest request = this.createRequest(prompt); + + return this.retryTemplate.execute(ctx -> { + return this.chatApi.chatCompletionStream(request).map(g -> { + if (g.isFinished()) { + String finishReason = g.finishReason().name(); + Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics()); + return new ChatResponse(List.of(new Generation("") + .withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage)))); + } + return new ChatResponse(List.of(new Generation(g.text()))); + }); + }); + } + + CohereCommandRChatRequest createRequest(Prompt prompt) { + final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); + + 
var request = CohereCommandRChatRequest.builder(promptValue) + .withSearchQueriesOnly(this.defaultOptions.getSearchQueriesOnly()) + .withPreamble(this.defaultOptions.getPreamble()) + .withMaxTokens(this.defaultOptions.getMaxTokens()) + .withTemperature(this.defaultOptions.getTemperature()) + .withTopP(this.defaultOptions.getTopP()) + .withTopK(this.defaultOptions.getTopK()) + .withPromptTruncation(this.defaultOptions.getPromptTruncation()) + .withFrequencyPenalty(this.defaultOptions.getFrequencyPenalty()) + .withPresencePenalty(this.defaultOptions.getPresencePenalty()) + .withSeed(this.defaultOptions.getSeed()) + .withReturnPrompt(this.defaultOptions.getReturnPrompt()) + .withStopSequences(this.defaultOptions.getStopSequences()) + .withRawPrompting(this.defaultOptions.getRawPrompting()) + .build(); + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { + BedrockCohereCommandRChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, + ChatOptions.class, BedrockCohereCommandRChatOptions.class); + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, CohereCommandRChatRequest.class); + } + else { + throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + + prompt.getOptions().getClass().getSimpleName()); + } + } + + return request; + } + + @Override + public ChatOptions getDefaultOptions() { + return BedrockCohereCommandRChatOptions.fromOptions(defaultOptions); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java new file mode 100644 index 0000000000..3c5d66ab85 --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java @@ -0,0 +1,293 @@ +/* + * Copyright 2023 - 2024 the original 
author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.cohere; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonInclude.Include; + +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest.PromptTruncation; +import org.springframework.ai.chat.prompt.ChatOptions; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@JsonInclude(Include.NON_NULL) +public class BedrockCohereCommandRChatOptions implements ChatOptions { + + // @formatter:off + /** + * (optional) When enabled, it will only generate potential search queries without performing + * searches or providing a response. + */ + @JsonProperty("search_queries_only") Boolean searchQueriesOnly; + /** + * (optional) Overrides the default preamble for search query generation. + */ + @JsonProperty("preamble") String preamble; + /** + * (optional) Specify the maximum number of tokens to use in the generated response. + */ + @JsonProperty("max_tokens") Integer maxTokens; + /** + * (optional) Use a lower value to decrease randomness in the response. + */ + @JsonProperty("temperature") Float temperature; + /** + * Top P. Use a lower value to ignore less probable options. Set to 0 or 1.0 to disable. + */ + @JsonProperty("p") Float topP; + /** + * Top K. 
Specify the number of token choices the model uses to generate the next token. + */ + @JsonProperty("k") Integer topK; + /** + * (optional) Dictates how the prompt is constructed. + */ + @JsonProperty("prompt_truncation") PromptTruncation promptTruncation; + /** + * (optional) Used to reduce repetitiveness of generated tokens. + */ + @JsonProperty("frequency_penalty") Float frequencyPenalty; + /** + * (optional) Used to reduce repetitiveness of generated tokens. + */ + @JsonProperty("presence_penalty") Float presencePenalty; + /** + * (optional) Specify the best effort to sample tokens deterministically. + */ + @JsonProperty("seed") Integer seed; + /** + * (optional) Specify true to return the full prompt that was sent to the model. + */ + @JsonProperty("return_prompt") Boolean returnPrompt; + /** + * (optional) A list of stop sequences. + */ + @JsonProperty("stop_sequences") List stopSequences; + /** + * (optional) Specify true, to send the user’s message to the model without any preprocessing. 
+ */ + @JsonProperty("raw_prompting") Boolean rawPrompting; + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final BedrockCohereCommandRChatOptions options = new BedrockCohereCommandRChatOptions(); + + public Builder withSearchQueriesOnly(Boolean searchQueriesOnly) { + options.setSearchQueriesOnly(searchQueriesOnly); + return this; + } + + public Builder withPreamble(String preamble) { + options.setPreamble(preamble); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + options.setMaxTokens(maxTokens); + return this; + } + + public Builder withTemperature(Float temperature) { + options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Float topP) { + options.setTopP(topP); + return this; + } + + public Builder withTopK(Integer topK) { + options.setTopK(topK); + return this; + } + + public Builder withPromptTruncation(PromptTruncation promptTruncation) { + options.setPromptTruncation(promptTruncation); + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + options.setFrequencyPenalty(frequencyPenalty); + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + options.setPresencePenalty(presencePenalty); + return this; + } + + public Builder withSeed(Integer seed) { + options.setSeed(seed); + return this; + } + + public Builder withReturnPrompt(Boolean returnPrompt) { + options.setReturnPrompt(returnPrompt); + return this; + } + + public Builder withStopSequences(List stopSequences) { + options.setStopSequences(stopSequences); + return this; + } + + public Builder withRawPrompting(Boolean rawPrompting) { + options.setRawPrompting(rawPrompting); + return this; + } + + public BedrockCohereCommandRChatOptions build() { + return this.options; + } + + } + + public Boolean getSearchQueriesOnly() { + return searchQueriesOnly; + } + + public void setSearchQueriesOnly(Boolean 
searchQueriesOnly) { + this.searchQueriesOnly = searchQueriesOnly; + } + + public String getPreamble() { + return preamble; + } + + public void setPreamble(String preamble) { + this.preamble = preamble; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Float getTemperature() { + return temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + public PromptTruncation getPromptTruncation() { + return promptTruncation; + } + + public void setPromptTruncation(PromptTruncation promptTruncation) { + this.promptTruncation = promptTruncation; + } + + public Float getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Float getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Integer getSeed() { + return seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public Boolean getReturnPrompt() { + return returnPrompt; + } + + public void setReturnPrompt(Boolean returnPrompt) { + this.returnPrompt = returnPrompt; + } + + public List getStopSequences() { + return stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + public Boolean getRawPrompting() { + return rawPrompting; + } + + public void setRawPrompting(Boolean rawPrompting) { + this.rawPrompting = rawPrompting; + } + + public static BedrockCohereCommandRChatOptions 
fromOptions(BedrockCohereCommandRChatOptions fromOptions) { + return builder().withSearchQueriesOnly(fromOptions.getSearchQueriesOnly()) + .withPreamble(fromOptions.getPreamble()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withPromptTruncation(fromOptions.getPromptTruncation()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withSeed(fromOptions.getSeed()) + .withReturnPrompt(fromOptions.getReturnPrompt()) + .withStopSequences(fromOptions.getStopSequences()) + .withRawPrompting(fromOptions.getRawPrompting()) + .build(); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereCommandRChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereCommandRChatBedrockApi.java new file mode 100644 index 0000000000..f7ef5b5ba6 --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereCommandRChatBedrockApi.java @@ -0,0 +1,502 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// @formatter:off +package org.springframework.ai.bedrock.cohere.api; + +import java.time.Duration; +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.databind.ObjectMapper; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.bedrock.api.AbstractBedrockApi; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatResponse; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatStreamingResponse; + +/** + * Java client for the Bedrock Cohere command R chat model. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html + * + * @author Wei Jiang + * @since 1.0.0 + */ +public class CohereCommandRChatBedrockApi + extends AbstractBedrockApi { + + /** + * Create a new CohereCommandRChatBedrockApi instance using the default credentials provider chain and the default object + * mapper. + * + * @param modelId The model id to use. See the {@link CohereCommandRChatModel} for the supported models. + * @param region The AWS region to use. + */ + public CohereCommandRChatBedrockApi(String modelId, String region) { + super(modelId, region); + } + + /** + * Create a new CohereCommandRChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. See the {@link CohereCommandRChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. 
+ * @param objectMapper The object mapper to use for JSON serialization and deserialization. + */ + public CohereCommandRChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, + ObjectMapper objectMapper) { + super(modelId, credentialsProvider, region, objectMapper); + } + + /** + * Create a new CohereCommandRChatBedrockApi instance using the default credentials provider chain, the default object + * mapper and the provided timeout. + * + * @param modelId The model id to use. See the {@link CohereCommandRChatModel} for the supported models. + * @param region The AWS region to use. + * @param timeout The timeout to use. + */ + public CohereCommandRChatBedrockApi(String modelId, String region, Duration timeout) { + super(modelId, region, timeout); + } + + /** + * Create a new CohereCommandRChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. See the {@link CohereCommandRChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. + */ + public CohereCommandRChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + + /** + * Create a new CohereCommandRChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. See the {@link CohereCommandRChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. 
+ */ + public CohereCommandRChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + + /** + * CohereCommandRChatRequest encapsulates the request parameters for the Cohere command R model. + * + * @param message Text input for the model to respond to. + * @param chatHistory (optional) A list of previous messages between the user and the model. + * @param documents (optional) A list of texts that the model can cite to generate a more accurate reply. + * @param searchQueriesOnly (optional) When enabled, it will only generate potential search queries without performing + * searches or providing a response. + * @param preamble (optional) Overrides the default preamble for search query generation. + * @param maxTokens (optional) Specify the maximum number of tokens to use in the generated response. + * @param temperature (optional) Use a lower value to decrease randomness in the response. + * @param topP (optional) Top P. Use a lower value to ignore less probable options. Set to 0 or 1.0 to disable. + * @param topK (optional) Top K. Specify the number of token choices the model uses to generate the next token. + * @param promptTruncation (optional) Dictates how the prompt is constructed. + * @param frequencyPenalty (optional) Used to reduce repetitiveness of generated tokens. + * @param presencePenalty (optional) Used to reduce repetitiveness of generated tokens. + * @param seed (optional) Specify the best effort to sample tokens deterministically. + * @param returnPrompt (optional) Specify true to return the full prompt that was sent to the model. + * @param stopSequences (optional) A list of stop sequences. + * @param rawPrompting (optional) Specify true, to send the user’s message to the model without any preprocessing. 
+ */ + @JsonInclude(Include.NON_NULL) + public record CohereCommandRChatRequest( + @JsonProperty("message") String message, + @JsonProperty("chat_history") List chatHistory, + @JsonProperty("documents") List documents, + @JsonProperty("search_queries_only") Boolean searchQueriesOnly, + @JsonProperty("preamble") String preamble, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("temperature") Float temperature, + @JsonProperty("p") Float topP, + @JsonProperty("k") Integer topK, + @JsonProperty("prompt_truncation") PromptTruncation promptTruncation, + @JsonProperty("frequency_penalty") Float frequencyPenalty, + @JsonProperty("presence_penalty") Float presencePenalty, + @JsonProperty("seed") Integer seed, + @JsonProperty("return_prompt") Boolean returnPrompt, + @JsonProperty("stop_sequences") List stopSequences, + @JsonProperty("raw_prompting") Boolean rawPrompting) { + + /** + * The text that the model can cite to generate a more accurate reply. + * + * @param title The title of the document. + * @param snippet The snippet of the document. + */ + @JsonInclude(Include.NON_NULL) + public record Document( + @JsonProperty("title") String title, + @JsonProperty("snippet") String snippet) {} + + /** + * Specifies how the prompt is constructed. + */ + public enum PromptTruncation { + /** + * Some elements from chat_history and documents will be dropped to construct a prompt + * that fits within the model's context length limit. + */ + AUTO_PRESERVE_ORDER, + /** + * (Default) No elements will be dropped. + */ + OFF + } + + /** + * Get CohereCommandRChatRequest builder. + * + * @param message Compulsory request prompt parameter. + * @return CohereCommandRChatRequest builder. + */ + public static Builder builder(String message) { + return new Builder(message); + } + + /** + * Builder for the CohereCommandRChatRequest. 
+ */ + public static class Builder { + private final String message; + private List chatHistory; + private List documents; + private Boolean searchQueriesOnly; + private String preamble; + private Integer maxTokens; + private Float temperature; + private Float topP; + private Integer topK; + private PromptTruncation promptTruncation; + private Float frequencyPenalty; + private Float presencePenalty; + private Integer seed; + private Boolean returnPrompt; + private List stopSequences; + private Boolean rawPrompting; + + public Builder(String message) { + this.message = message; + } + + public Builder withChatHistory(List chatHistory) { + this.chatHistory = chatHistory; + return this; + } + + public Builder withDocuments(List documents) { + this.documents = documents; + return this; + } + + public Builder withSearchQueriesOnly(Boolean searchQueriesOnly) { + this.searchQueriesOnly = searchQueriesOnly; + return this; + } + + public Builder withPreamble(String preamble) { + this.preamble = preamble; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder withTemperature(Float temperature) { + this.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + this.topP = topP; + return this; + } + + public Builder withTopK(Integer topK) { + this.topK = topK; + return this; + } + + public Builder withPromptTruncation(PromptTruncation promptTruncation) { + this.promptTruncation = promptTruncation; + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public Builder withSeed(Integer seed) { + this.seed = seed; + return this; + } + + public Builder withReturnPrompt(Boolean returnPrompt) { + this.returnPrompt = returnPrompt; + return this; + } + 
+ public Builder withStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + public Builder withRawPrompting(Boolean rawPrompting) { + this.rawPrompting = rawPrompting; + return this; + } + + public CohereCommandRChatRequest build() { + return new CohereCommandRChatRequest( + message, + chatHistory, + documents, + searchQueriesOnly, + preamble, + maxTokens, + temperature, + topP, + topK, + promptTruncation, + frequencyPenalty, + presencePenalty, + seed, + returnPrompt, + stopSequences, + rawPrompting + + ); + } + } + } + + /** + * CohereCommandRChatResponse encapsulates the response parameters for the Cohere command R model. + * + * @param id Unique identifier for chat completion. + * @param text The model’s response to chat message input. + * @param generationId Unique identifier for chat completion, used with Feedback endpoint on Cohere’s platform. + * @param finishReason The reason why the model stopped generating output. + * @param chatHistory A list of previous messages between the user and the model. + * @param metadata API usage data. + */ + @JsonInclude(Include.NON_NULL) + public record CohereCommandRChatResponse( + @JsonProperty("response_id") String id, + @JsonProperty("text") String text, + @JsonProperty("generation_id") String generationId, + @JsonProperty("finish_reason") FinishReason finishReason, + @JsonProperty("chat_history") List chatHistory, + @JsonProperty("meta") Metadata metadata) { + + /** + * API usage data. + * + * @param apiVersion The API version. + * @param billedUnits The billed units. + * @param tokens The tokens units. + */ + @JsonInclude(Include.NON_NULL) + public record Metadata( + @JsonProperty("api_version") ApiVersion apiVersion, + @JsonProperty("billed_units") BilledUnits billedUnits, + @JsonProperty("tokens") Tokens tokens) { + + /** + * The API version. + * + * @param version The API version. 
+ */ + @JsonInclude(Include.NON_NULL) + public record ApiVersion(@JsonProperty("version") String version) {} + + /** + * The billed units. + * + * @param inputTokens The number of input tokens that were billed. + * @param outputTokens The number of output tokens that were billed. + */ + @JsonInclude(Include.NON_NULL) + public record BilledUnits( + @JsonProperty("input_tokens") Integer inputTokens, + @JsonProperty("output_tokens") Integer outputTokens) {} + + /** + * The tokens units. + * + * @param inputTokens The number of input tokens. + * @param outputTokens The number of output tokens. + */ + @JsonInclude(Include.NON_NULL) + public record Tokens( + @JsonProperty("input_tokens") Integer inputTokens, + @JsonProperty("output_tokens") Integer outputTokens) {} + + } + + } + + /** + * CohereCommandRChatStreamingResponse encapsulates the streaming response parameters for the Cohere command R model. + * https://docs.cohere.com/docs/streaming#stream-events + * + * @param eventType The event type of stream response. + * @param text The model’s response to chat message input. + * @param isFinished Specify whether the streaming session is finished + * @param finishReason The reason why the model stopped generating output. + * @param response The final response about this stream invocation. + * @param amazonBedrockInvocationMetrics Metrics about the model invocation. + */ + @JsonInclude(Include.NON_NULL) + public record CohereCommandRChatStreamingResponse( + @JsonProperty("event_type") String eventType, + @JsonProperty("text") String text, + @JsonProperty("is_finished") Boolean isFinished, + @JsonProperty("finish_reason") FinishReason finishReason, + @JsonProperty("response") CohereCommandRChatResponse response, + @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) {} + + /** + * Previous messages between the user and the model. + * + * @param role The role for the message. Valid values are USER or CHATBOT. 
+ * @param message Text contents of the message. + */ + @JsonInclude(Include.NON_NULL) + public record ChatHistory( + @JsonProperty("role") Role role, + @JsonProperty("message") String message) { + + /** + * The role for the message. + */ + public enum Role { + + /** + * User message. + */ + USER, + + /** + * Chatbot message. + */ + CHATBOT + + } + + } + + /** + * The reason why the model stopped generating output. + */ + public enum FinishReason { + + /** + * The completion reached the end of generation token. Ensure this is the finish reason for best performance. + */ + COMPLETE, + + /** + * The generation could not be completed due to our content filters. + */ + ERROR_TOXIC, + + /** + * The generation could not be completed because the model’s context limit was reached. + */ + ERROR_LIMIT, + + /** + * The generation could not be completed due to an error. + */ + ERROR, + + /** + * The generation could not be completed because it was stopped by the user. + */ + USER_CANCEL, + + /** + * The generation could not be completed because the user specified a max_tokens limit in the request and this limit was reached. May not result in best performance. + */ + MAX_TOKENS + + } + + /** + * Cohere Command R model versions. + */ + public enum CohereCommandRChatModel { + + /** + * cohere.command-r-v1:0 + */ + COHERE_COMMAND_R_V1("cohere.command-r-v1:0"), + + /** + * cohere.command-r-plus-v1:0 + */ + COHERE_COMMAND_R_PLUS_V1("cohere.command-r-plus-v1:0"); + + private final String id; + + /** + * @return The model id. 
+ */ + public String id() { + return id; + } + + CohereCommandRChatModel(String value) { + this.id = value; + } + } + + @Override + public CohereCommandRChatResponse chatCompletion(CohereCommandRChatRequest request) { + return this.internalInvocation(request, CohereCommandRChatResponse.class); + } + + @Override + public Flux chatCompletionStream(CohereCommandRChatRequest request) { + return this.internalInvocationStream(request, CohereCommandRChatStreamingResponse.class); + } + +} +//@formatter:on diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatCreateRequestTests.java new file mode 100644 index 0000000000..3c4c4cc00d --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatCreateRequestTests.java @@ -0,0 +1,113 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.cohere; + +import java.time.Duration; +import java.util.List; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest.PromptTruncation; +import org.springframework.ai.chat.prompt.Prompt; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + */ +public class BedrockCohereCommandRChatCreateRequestTests { + + private CohereCommandRChatBedrockApi chatApi = new CohereCommandRChatBedrockApi( + CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), EnvironmentVariableCredentialsProvider.create(), + Region.US_EAST_1.id(), new ObjectMapper(), Duration.ofMinutes(2)); + + @Test + public void createRequestWithChatOptions() { + + var client = new BedrockCohereCommandRChatModel(chatApi, + BedrockCohereCommandRChatOptions.builder() + .withSearchQueriesOnly(true) + .withPreamble("preamble") + .withMaxTokens(678) + .withTemperature(66.6f) + .withTopK(66) + .withTopP(0.66f) + .withPromptTruncation(PromptTruncation.OFF) + .withFrequencyPenalty(0.1f) + .withPresencePenalty(0.2f) + .withSeed(1000) + .withReturnPrompt(false) + .withStopSequences(List.of("stop1", "stop2")) + .withRawPrompting(false) + .build()); + + CohereCommandRChatRequest request = client.createRequest(new Prompt("Test message content")); + + assertThat(request.message()).isNotEmpty(); + assertThat(request.searchQueriesOnly()).isTrue(); + 
assertThat(request.preamble()).isEqualTo("preamble"); + assertThat(request.maxTokens()).isEqualTo(678); + assertThat(request.temperature()).isEqualTo(66.6f); + assertThat(request.topK()).isEqualTo(66); + assertThat(request.topP()).isEqualTo(0.66f); + assertThat(request.promptTruncation()).isEqualTo(PromptTruncation.OFF); + assertThat(request.frequencyPenalty()).isEqualTo(0.1f); + assertThat(request.presencePenalty()).isEqualTo(0.2f); + assertThat(request.seed()).isEqualTo(1000); + assertThat(request.returnPrompt()).isEqualTo(false); + assertThat(request.stopSequences()).containsExactly("stop1", "stop2"); + assertThat(request.rawPrompting()).isEqualTo(false); + + request = client.createRequest(new Prompt("Test message content", + BedrockCohereCommandRChatOptions.builder() + .withSearchQueriesOnly(false) + .withPreamble("preamble") + .withMaxTokens(999) + .withTemperature(99.9f) + .withTopK(99) + .withTopP(0.99f) + .withPromptTruncation(PromptTruncation.OFF) + .withFrequencyPenalty(0.9f) + .withPresencePenalty(0.9f) + .withSeed(9999) + .withReturnPrompt(true) + .withStopSequences(List.of("stop1", "stop2")) + .withRawPrompting(true) + .build())); + + assertThat(request.message()).isNotEmpty(); + assertThat(request.searchQueriesOnly()).isFalse(); + assertThat(request.preamble()).isEqualTo("preamble"); + assertThat(request.maxTokens()).isEqualTo(999); + assertThat(request.temperature()).isEqualTo(99.9f); + assertThat(request.topK()).isEqualTo(99); + assertThat(request.topP()).isEqualTo(0.99f); + assertThat(request.promptTruncation()).isEqualTo(PromptTruncation.OFF); + assertThat(request.frequencyPenalty()).isEqualTo(0.9f); + assertThat(request.presencePenalty()).isEqualTo(0.9f); + assertThat(request.seed()).isEqualTo(9999); + assertThat(request.returnPrompt()).isEqualTo(true); + assertThat(request.stopSequences()).containsExactly("stop1", "stop2"); + assertThat(request.rawPrompting()).isEqualTo(true); + } + +} diff --git 
a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModelIT.java new file mode 100644 index 0000000000..56221cfc49 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModelIT.java @@ -0,0 +1,218 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.cohere; + +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.messages.AssistantMessage; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") 
+@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +class BedrockCohereCommandRChatModelIT { + + @Autowired + private BedrockCohereCommandRChatModel chatModel; + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Test + void multipleStreamAttempts() { + + Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + + String joke1 = joke1Stream.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + String joke2 = joke2Stream.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + assertThat(joke1).isNotBlank(); + assertThat(joke2).isNotBlank(); + } + + @Test + void roleTest() { + String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; + String name = "Bob"; + String voice = "pirate"; + UserMessage userMessage = new UserMessage(request); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = chatModel.call(prompt); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate 
promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors.", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getContent()); + assertThat(list).hasSize(5); + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Remove Markdown code blocks from the output. + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getContent()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. 
+ """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = chatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + System.out.println(actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public CohereCommandRChatBedrockApi cohereApi() { + return new CohereCommandRChatBedrockApi(CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(2)); + } + + @Bean + public BedrockCohereCommandRChatModel cohereCommandRChatModel(CohereCommandRChatBedrockApi cohereApi) { + return new BedrockCohereCommandRChatModel(cohereApi); + } + + } + +} diff --git 
a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereCommandRChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereCommandRChatBedrockApiIT.java new file mode 100644 index 0000000000..536a632f5e --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereCommandRChatBedrockApiIT.java @@ -0,0 +1,129 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.cohere.api; + +import java.time.Duration; +import java.util.List; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.ChatHistory; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.ChatHistory.Role; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest.Document; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest.PromptTruncation; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatResponse; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatStreamingResponse; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.FinishReason; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class CohereCommandRChatBedrockApiIT { + + private CohereCommandRChatBedrockApi cohereChatApi = new CohereCommandRChatBedrockApi( + CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), EnvironmentVariableCredentialsProvider.create(), + Region.US_EAST_1.id(), new ObjectMapper(), Duration.ofMinutes(2)); + + @Test + public void 
requestBuilder() { + + CohereCommandRChatRequest request1 = new CohereCommandRChatRequest( + "What is the capital of Bulgaria and what is the size? What it the national anthem?", + List.of(new ChatHistory(Role.CHATBOT, "message")), List.of(new Document("title", "snippet")), false, + "preamble", 100, 0.5f, 0.6f, 15, PromptTruncation.AUTO_PRESERVE_ORDER, 0.8f, 0.9f, 5050, false, + List.of("stop_sequence"), false); + + var request2 = CohereCommandRChatRequest + .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?") + .withChatHistory(List.of(new ChatHistory(Role.CHATBOT, "message"))) + .withDocuments(List.of(new Document("title", "snippet"))) + .withSearchQueriesOnly(false) + .withPreamble("preamble") + .withMaxTokens(100) + .withTemperature(0.5f) + .withTopP(0.6f) + .withTopK(15) + .withPromptTruncation(PromptTruncation.AUTO_PRESERVE_ORDER) + .withFrequencyPenalty(0.8f) + .withPresencePenalty(0.9f) + .withSeed(5050) + .withReturnPrompt(false) + .withStopSequences(List.of("stop_sequence")) + .withRawPrompting(false) + .build(); + + assertThat(request1).isEqualTo(request2); + } + + @Test + public void chatCompletion() { + + var request = CohereCommandRChatRequest + .builder("What is the capital of Bulgaria and what is the size? 
What it the national anthem?") + .withTemperature(0.5f) + .withTopP(0.8f) + .withTopK(15) + .withMaxTokens(2000) + .build(); + + CohereCommandRChatResponse response = cohereChatApi.chatCompletion(request); + + assertThat(response).isNotNull(); + assertThat(response.finishReason()).isEqualTo(FinishReason.COMPLETE); + assertThat(response.text()).isNotEmpty(); + assertThat(response.id()).isNotEmpty(); + assertThat(response.generationId()).isNotEmpty(); + assertThat(response.chatHistory()).isNotNull(); + assertThat(response.chatHistory().size()).isEqualTo(2); + } + + @Test + public void chatCompletionStream() { + + var request = CohereCommandRChatRequest + .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?") + .withTemperature(0.5f) + .withTopP(0.8f) + .withTopK(15) + .withMaxTokens(50) + .withStopSequences(List.of("END")) + .build(); + + Flux responseStream = cohereChatApi.chatCompletionStream(request); + List responses = responseStream.collectList().block(); + + assertThat(responses).isNotNull(); + assertThat(responses).hasSizeGreaterThan(10); + assertThat(responses.get(0).text()).isNotEmpty(); + CohereCommandRChatStreamingResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.text()).isNull(); + assertThat(lastResponse.isFinished()).isTrue(); + assertThat(lastResponse.finishReason()).isEqualTo(FinishReason.MAX_TOKENS); + assertThat(lastResponse.amazonBedrockInvocationMetrics()).isNotNull(); + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 120a8e5c8b..0287e7a822 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -14,6 +14,7 @@ **** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic2] **** xref:api/chat/bedrock/bedrock-llama.adoc[Llama] **** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere] +**** 
xref:api/chat/bedrock/bedrock-coherecommandr.adoc[CohereCommandR] **** xref:api/chat/bedrock/bedrock-titan.adoc[Titan] **** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2] *** xref:api/chat/huggingface.adoc[HuggingFace] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc index f8b2b2062a..b8f744f49c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc @@ -104,6 +104,7 @@ For more information, refer to the documentation below for each supported model. * xref:api/chat/bedrock/bedrock-anthropic3.adoc[Spring AI Bedrock Anthropic 3 Chat]: `spring.ai.bedrock.anthropic.chat.enabled=true` * xref:api/chat/bedrock/bedrock-llama.adoc[Spring AI Bedrock Llama Chat]: `spring.ai.bedrock.llama.chat.enabled=true` * xref:api/chat/bedrock/bedrock-cohere.adoc[Spring AI Bedrock Cohere Chat]: `spring.ai.bedrock.cohere.chat.enabled=true` +* xref:api/chat/bedrock/bedrock-coherecommandr.adoc[Spring AI Bedrock Cohere Command R Chat]: `spring.ai.bedrock.coherecommandr.chat.enabled=true` * xref:api/embeddings/bedrock-cohere-embedding.adoc[Spring AI Bedrock Cohere Embeddings]: `spring.ai.bedrock.cohere.embedding.enabled=true` * xref:api/chat/bedrock/bedrock-titan.adoc[Spring AI Bedrock Titan Chat]: `spring.ai.bedrock.titan.chat.enabled=true` * xref:api/embeddings/bedrock-titan-embedding.adoc[Spring AI Bedrock Titan Embeddings]: `spring.ai.bedrock.titan.embedding.enabled=true` diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-coherecommandr.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-coherecommandr.adoc new file mode 100644 index 0000000000..8a69e8fdbe --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-coherecommandr.adoc @@ -0,0 +1,279 @@ += Cohere Command R Chat + +Provides 
Bedrock Cohere Command R Chat model. +Integrate generative AI capabilities into essential apps and workflows that improve business outcomes. + +The https://aws.amazon.com/bedrock/cohere-command-embed/[AWS Bedrock Cohere Model Page] and https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html[Amazon Bedrock User Guide] contain detailed information on how to use the AWS hosted model. + +== Prerequisites + +Refer to the xref:api/bedrock.adoc[Spring AI documentation on Amazon Bedrock] for setting up API access. + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + +== Auto-configuration + +Add the `spring-ai-bedrock-ai-spring-boot-starter` dependency to your project's Maven `pom.xml` file: + +[source,xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-bedrock-ai-spring-boot-starter</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,gradle] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-bedrock-ai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Enable Cohere Command R Chat Support + +By default, the Cohere Command R model is disabled. +To enable it, set the `spring.ai.bedrock.coherecommandr.chat.enabled` property to `true`. 
+Exporting an environment variable is one way to set this configuration property: + +[source,shell] +---- +export SPRING_AI_BEDROCK_COHERECOMMANDR_CHAT_ENABLED=true +---- + +=== Chat Properties + +The prefix `spring.ai.bedrock.aws` is the property prefix to configure the connection to AWS Bedrock. + +[cols="3,3,3"] +|==== +| Property | Description | Default + +| spring.ai.bedrock.aws.region | AWS region to use. | us-east-1 +| spring.ai.bedrock.aws.timeout | AWS timeout to use. | 5m +| spring.ai.bedrock.aws.access-key | AWS access key. | - +| spring.ai.bedrock.aws.secret-key | AWS secret key. | - +|==== + +The prefix `spring.ai.bedrock.coherecommandr.chat` is the property prefix that configures the chat model implementation for Cohere Command R. + +[cols="2,5,1"] +|==== +| Property | Description | Default + +| spring.ai.bedrock.coherecommandr.chat.enabled | Enable or disable support for Cohere Command R | false +| spring.ai.bedrock.coherecommandr.chat.model | The model id to use. See the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereCommandRChatBedrockApi.java#L465C14-L489C29[CohereCommandRChatModel] for the supported models. | cohere.command-r-plus-v1:0 +| spring.ai.bedrock.coherecommandr.chat.options.searchQueriesOnly | When enabled, it will only generate potential search queries without performing searches or providing a response. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.preamble | Overrides the default preamble for search query generation. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.maxTokens | Specify the maximum number of tokens to use in the generated response. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.temperature | Controls the randomness of the output. 
Values can range over [0.0,1.0] | 0.7 +| spring.ai.bedrock.coherecommandr.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.topK | Specify the number of token choices the model uses to generate the next token | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.promptTruncation | Dictates how the prompt is constructed. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.frequencyPenalty | Used to reduce repetitiveness of generated tokens. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.presencePenalty | Used to reduce repetitiveness of generated tokens. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.seed | Specify the best effort to sample tokens deterministically. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.returnPrompt | Specify true to return the full prompt that was sent to the model. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.stopSequences | Configure up to four sequences that the model recognizes. | AWS Bedrock default +| spring.ai.bedrock.coherecommandr.chat.options.rawPrompting | Specify true, to send the user’s message to the model without any preprocessing. | AWS Bedrock default +|==== + +Look at the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereCommandRChatBedrockApi.java#L465C14-L489C29[CohereCommandRChatModel] for other model IDs. +Supported values are: `cohere.command-r-plus-v1:0` and `cohere.command-r-v1:0`. +Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html[AWS Bedrock documentation for base model IDs]. 
+ +TIP: All properties prefixed with `spring.ai.bedrock.coherecommandr.chat.options` can be overridden at runtime by adding a request-specific <<chat-options>> to the `Prompt` call. + +== Runtime Options [[chat-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java[BedrockCohereCommandRChatOptions.java] provides model configurations, such as temperature, topK, topP, etc. + +On start-up, the default options can be configured with the `BedrockCohereCommandRChatModel(api, options)` constructor or the `spring.ai.bedrock.coherecommandr.chat.options.*` properties. + +At run-time, you can override the default options by adding new, request-specific options to the `Prompt` call. +For example, to override the default temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "Generate the names of 5 famous pirates.", + BedrockCohereCommandRChatOptions.builder() + .withTemperature(0.4f) + .build() + )); +---- + +TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatOptions.java[BedrockCohereCommandRChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. + +== Sample Controller + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-bedrock-ai-spring-boot-starter` to your pom (or gradle) dependencies. 
+ +Add an `application.properties` file, under the `src/main/resources` directory, to enable and configure the Cohere Command R chat model: + +[source] +---- +spring.ai.bedrock.aws.region=eu-central-1 +spring.ai.bedrock.aws.timeout=1000ms +spring.ai.bedrock.aws.access-key=${AWS_ACCESS_KEY_ID} +spring.ai.bedrock.aws.secret-key=${AWS_SECRET_ACCESS_KEY} + +spring.ai.bedrock.coherecommandr.chat.enabled=true +spring.ai.bedrock.coherecommandr.chat.options.temperature=0.8 +---- + +TIP: Replace the `region`, `access-key` and `secret-key` with your AWS credentials. + +This will create a `BedrockCohereCommandRChatModel` implementation that you can inject into your class. +Here is an example of a simple `@RestController` class that uses the chat model for text generations. + +[source,java] +---- +@RestController +public class ChatController { + + private final BedrockCohereCommandRChatModel chatModel; + + @Autowired + public ChatController(BedrockCohereCommandRChatModel chatModel) { + this.chatModel = chatModel; + } + + @GetMapping("/ai/generate") + public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + return Map.of("generation", chatModel.call(message)); + } + + @GetMapping("/ai/generateStream") + public Flux<ChatResponse> generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + Prompt prompt = new Prompt(new UserMessage(message)); + return chatModel.stream(prompt); + } +} +---- + +== Manual Configuration + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java[BedrockCohereCommandRChatModel] implements the `ChatModel` and `StreamingChatModel` and uses the <<low-level-api>> to connect to the Bedrock Cohere Command R service. 
+ +Add the `spring-ai-bedrock` dependency to your project's Maven `pom.xml` file: + +[source,xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-bedrock</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,gradle] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-bedrock' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Next, create a https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereCommandRChatModel.java[BedrockCohereCommandRChatModel] and use it for text generations: + +[source,java] +---- +CohereCommandRChatBedrockApi api = new CohereCommandRChatBedrockApi(CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), + EnvironmentVariableCredentialsProvider.create(), + Region.US_EAST_1.id(), + new ObjectMapper(), + Duration.ofMillis(1000L)); + +BedrockCohereCommandRChatModel chatModel = new BedrockCohereCommandRChatModel(api, + BedrockCohereCommandRChatOptions.builder() + .withTemperature(0.6f) + .withTopK(10) + .withTopP(0.5f) + .withMaxTokens(678) + .build()); + +ChatResponse response = chatModel.call( + new Prompt("Generate the names of 5 famous pirates.")); + +// Or with streaming responses +Flux<ChatResponse> response = chatModel.stream( + new Prompt("Generate the names of 5 famous pirates.")); +---- + +== Low-level CohereCommandRChatBedrockApi Client [[low-level-api]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereCommandRChatBedrockApi.java[CohereCommandRChatBedrockApi] provides a lightweight Java client on top of AWS Bedrock https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html[Cohere Command R models]. 
+ +The following class diagram illustrates the CohereCommandRChatBedrockApi interface and building blocks: + +image::bedrock/bedrock-cohere-chat-low-level-api.jpg[align="center", width="800px"] + +The CohereCommandRChatBedrockApi supports the `cohere.command-r-v1:0` and `cohere.command-r-plus-v1:0` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) requests. + +Here is a simple snippet showing how to use the API programmatically: + +[source,java] +---- +CohereCommandRChatBedrockApi cohereCommandRChatApi = new CohereCommandRChatBedrockApi( + CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), + Region.US_EAST_1.id(), + Duration.ofMillis(1000L)); + +var request = CohereCommandRChatRequest + .builder("What is the capital of Bulgaria and what is the size? What it the national anthem?") + .withChatHistory(List.of(new ChatHistory(Role.CHATBOT, "message"))) + .withDocuments(List.of(new Document("title", "snippet"))) + .withSearchQueriesOnly(false) + .withPreamble("preamble") + .withMaxTokens(100) + .withTemperature(0.5f) + .withTopP(0.6f) + .withTopK(15) + .withPromptTruncation(PromptTruncation.AUTO_PRESERVE_ORDER) + .withFrequencyPenalty(0.8f) + .withPresencePenalty(0.9f) + .withSeed(5050) + .withReturnPrompt(false) + .withStopSequences(List.of("stop_sequence")) + .withRawPrompting(false) + .build(); + +CohereCommandRChatResponse response = cohereCommandRChatApi.chatCompletion(request); + +var request = CohereCommandRChatRequest + .builder("What is the capital of Bulgaria and what is the size? 
What is the national anthem?") + .withChatHistory(List.of(new ChatHistory(Role.CHATBOT, "message"))) + .withDocuments(List.of(new Document("title", "snippet"))) + .withSearchQueriesOnly(false) + .withPreamble("preamble") + .withMaxTokens(100) + .withTemperature(0.5f) + .withTopP(0.6f) + .withTopK(15) + .withPromptTruncation(PromptTruncation.AUTO_PRESERVE_ORDER) + .withFrequencyPenalty(0.8f) + .withPresencePenalty(0.9f) + .withSeed(5050) + .withReturnPrompt(false) + .withStopSequences(List.of("stop_sequence")) + .withRawPrompting(false) + .build(); + +Flux<CohereCommandRChatResponse> responseStream = cohereCommandRChatApi.chatCompletionStream(request); +List<CohereCommandRChatResponse> responses = responseStream.collectList().block(); +---- + + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc index 224df22be4..1378add95c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc @@ -149,6 +149,7 @@ Each of the following sections in the documentation shows which dependencies you ** xref:api/chat/vertexai-gemini-chat.adoc[Google Vertex AI Gemini Chat Completion] (streaming, multi-modality & function-calling support) ** xref:api/bedrock.adoc[Amazon Bedrock] *** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere Chat Completion] +*** xref:api/chat/bedrock/bedrock-coherecommandr.adoc[Cohere Command R Chat Completion] *** xref:api/chat/bedrock/bedrock-llama.adoc[Llama Chat Completion] *** xref:api/chat/bedrock/bedrock-titan.adoc[Titan Chat Completion] *** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic Chat Completion] diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfiguration.java 
b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfiguration.java new file mode 100644 index 0000000000..d559cb5919 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfiguration.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.bedrock.cohere; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; + +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + +/** + * {@link AutoConfiguration Auto-configuration} for Bedrock Cohere Command R Chat Client. 
+ * + * @author Wei Jiang + * @since 1.0.0 + */ +@AutoConfiguration(after = { SpringAiRetryAutoConfiguration.class }) +@ConditionalOnClass(CohereCommandRChatBedrockApi.class) +@EnableConfigurationProperties({ BedrockCohereCommandRChatProperties.class, BedrockAwsConnectionProperties.class }) +@ConditionalOnProperty(prefix = BedrockCohereCommandRChatProperties.CONFIG_PREFIX, name = "enabled", + havingValue = "true") +@Import(BedrockAwsConnectionConfiguration.class) +public class BedrockCohereCommandRChatAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) + public CohereCommandRChatBedrockApi cohereCommandRChatApi(AwsCredentialsProvider credentialsProvider, + AwsRegionProvider regionProvider, BedrockCohereCommandRChatProperties properties, + BedrockAwsConnectionProperties awsProperties) { + return new CohereCommandRChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), + new ObjectMapper(), awsProperties.getTimeout()); + } + + @Bean + @ConditionalOnBean(CohereCommandRChatBedrockApi.class) + public BedrockCohereCommandRChatModel cohereCommandRChatModel(CohereCommandRChatBedrockApi cohereCommandRChatApi, + BedrockCohereCommandRChatProperties properties, RetryTemplate retryTemplate) { + return new BedrockCohereCommandRChatModel(cohereCommandRChatApi, properties.getOptions(), retryTemplate); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatProperties.java new file mode 100644 index 0000000000..851fb9c555 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatProperties.java @@ -0,0 +1,72 @@ +/* + * Copyright 2023 - 2024 the 
original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.cohere; + +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatOptions; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Bedrock Cohere Command R Chat autoconfiguration properties. + * + * @author Wei Jiang + * @since 1.0.0 + */ +@ConfigurationProperties(BedrockCohereCommandRChatProperties.CONFIG_PREFIX) +public class BedrockCohereCommandRChatProperties { + + public static final String CONFIG_PREFIX = "spring.ai.bedrock.coherecommandr.chat"; + + /** + * Enable Bedrock Cohere Command R Chat Client. False by default. + */ + private boolean enabled = false; + + /** + * Bedrock Cohere Command R Chat generative name. Defaults to + * 'cohere.command-r-plus-v1:0'. 
+ */ + private String model = CohereCommandRChatBedrockApi.CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(); + + @NestedConfigurationProperty + private BedrockCohereCommandRChatOptions options = BedrockCohereCommandRChatOptions.builder().build(); + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public BedrockCohereCommandRChatOptions getOptions() { + return this.options; + } + + public void setOptions(BedrockCohereCommandRChatOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index c744be669c..dd008fbbbf 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -8,6 +8,7 @@ org.springframework.ai.autoconfigure.vertexai.gemini.VertexAiGeminiAutoConfigura org.springframework.ai.autoconfigure.bedrock.jurrasic2.BedrockAi21Jurassic2ChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.llama.BedrockLlamaChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereChatAutoConfiguration +org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereCommandRChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereEmbeddingAutoConfiguration org.springframework.ai.autoconfigure.bedrock.anthropic.BedrockAnthropicChatAutoConfiguration 
org.springframework.ai.autoconfigure.bedrock.anthropic3.BedrockAnthropic3ChatAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfigurationIT.java new file mode 100644 index 0000000000..4bec023b46 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereCommandRChatAutoConfigurationIT.java @@ -0,0 +1,189 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.bedrock.cohere; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.cohere.BedrockCohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatModel; +import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest.PromptTruncation; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockCohereCommandRChatAutoConfigurationIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + 
System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.coherecommandr.chat.model=" + CohereCommandRChatModel.COHERE_COMMAND_R_PLUS_V1.id(), + "spring.ai.bedrock.coherecommandr.chat.options.temperature=0.5", + "spring.ai.bedrock.coherecommandr.chat.options.maxTokens=500") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereCommandRChatAutoConfiguration.class)); + + private final Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. + You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. + """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + private final UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); + + @Test + public void chatCompletion() { + contextRunner.run(context -> { + BedrockCohereCommandRChatModel cohereCommandRChatModel = context + .getBean(BedrockCohereCommandRChatModel.class); + ChatResponse response = cohereCommandRChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); + }); + } + + @Test + public void chatCompletionStreaming() { + contextRunner.run(context -> { + + BedrockCohereCommandRChatModel cohereCommandRChatModel = context + .getBean(BedrockCohereCommandRChatModel.class); + + Flux response = cohereCommandRChatModel + .stream(new Prompt(List.of(userMessage, systemMessage))); + + List responses = response.collectList().block(); + assertThat(responses.size()).isGreaterThan(2); + + String stitchedResponseContent = responses.stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + 
.collect(Collectors.joining()); + + assertThat(stitchedResponseContent).contains("Blackbeard"); + }); + } + + @Test + public void propertiesTest() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", + "spring.ai.bedrock.coherecommandr.chat.model=MODEL_XYZ", + "spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(), + "spring.ai.bedrock.coherecommandr.chat.options.searchQueriesOnly=true", + "spring.ai.bedrock.coherecommandr.chat.options.preamble=preamble", + "spring.ai.bedrock.coherecommandr.chat.options.maxTokens=123", + "spring.ai.bedrock.coherecommandr.chat.options.temperature=0.55", + "spring.ai.bedrock.coherecommandr.chat.options.topP=0.55", + "spring.ai.bedrock.coherecommandr.chat.options.topK=10", + "spring.ai.bedrock.coherecommandr.chat.options.promptTruncation=AUTO_PRESERVE_ORDER", + "spring.ai.bedrock.coherecommandr.chat.options.frequencyPenalty=0.55", + "spring.ai.bedrock.coherecommandr.chat.options.presencePenalty=0.66", + "spring.ai.bedrock.coherecommandr.chat.options.seed=555555", + "spring.ai.bedrock.coherecommandr.chat.options.returnPrompt=true", + "spring.ai.bedrock.coherecommandr.chat.options.stopSequences=END1,END2", + "spring.ai.bedrock.coherecommandr.chat.options.rawPrompting=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereCommandRChatAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(BedrockCohereCommandRChatProperties.class); + var aswProperties = context.getBean(BedrockAwsConnectionProperties.class); + + assertThat(chatProperties.isEnabled()).isTrue(); + assertThat(aswProperties.getRegion()).isEqualTo(Region.EU_CENTRAL_1.id()); + assertThat(chatProperties.getModel()).isEqualTo("MODEL_XYZ"); + + assertThat(chatProperties.getOptions().getSearchQueriesOnly()).isTrue(); + 
assertThat(chatProperties.getOptions().getPreamble()).isEqualTo("preamble"); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getTopK()).isEqualTo(10); + assertThat(chatProperties.getOptions().getPromptTruncation()) + .isEqualTo(PromptTruncation.AUTO_PRESERVE_ORDER); + assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getPresencePenalty()).isEqualTo(0.66f); + assertThat(chatProperties.getOptions().getSeed()).isEqualTo(555555); + assertThat(chatProperties.getOptions().getReturnPrompt()).isTrue(); + assertThat(chatProperties.getOptions().getStopSequences()).isEqualTo(List.of("END1", "END2")); + assertThat(chatProperties.getOptions().getRawPrompting()).isTrue(); + + assertThat(aswProperties.getAccessKey()).isEqualTo("ACCESS_KEY"); + assertThat(aswProperties.getSecretKey()).isEqualTo("SECRET_KEY"); + }); + } + + @Test + public void chatCompletionDisabled() { + + // It is disabled by default + new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereCommandRChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockCohereCommandRChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockCohereCommandRChatModel.class)).isEmpty(); + }); + + // Explicitly enable the chat auto-configuration. 
+ new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereCommandRChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockCohereCommandRChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(BedrockCohereCommandRChatModel.class)).isNotEmpty(); + }); + + // Explicitly disable the chat auto-configuration. + new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.coherecommandr.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockCohereCommandRChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockCohereCommandRChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockCohereCommandRChatModel.class)).isEmpty(); + }); + } + +}