From 24839d2a3708a118208a7d83b2bf7305c53ca9af Mon Sep 17 00:00:00 2001 From: GR Date: Wed, 22 May 2024 18:08:24 +0800 Subject: [PATCH] feat: add QianFan model client --- models/spring-ai-qianfan/README.md | 3 + models/spring-ai-qianfan/pom.xml | 59 ++ .../ai/qianfan/QianFanChatClient.java | 199 +++++++ .../ai/qianfan/QianFanChatOptions.java | 292 ++++++++++ .../ai/qianfan/QianFanEmbeddingClient.java | 156 +++++ .../ai/qianfan/QianFanEmbeddingOptions.java | 88 +++ .../ai/qianfan/aot/QianFanRuntimeHints.java | 42 ++ .../ai/qianfan/api/ApiUtils.java | 37 ++ .../ai/qianfan/api/QianFanApi.java | 550 ++++++++++++++++++ .../qianfan/api/auth/AccessTokenResponse.java | 12 + .../qianfan/api/auth/QianFanAccessToken.java | 70 +++ .../api/auth/QianFanAuthenticator.java | 71 +++ .../resources/META-INF/spring/aot.factories | 2 + .../qianfan/ChatCompletionRequestTests.java | 53 ++ .../ai/qianfan/QianFanTestConfiguration.java | 63 ++ .../ai/qianfan/api/QianFanApiIT.java | 81 +++ .../ai/qianfan/api/QianFanRetryTests.java | 172 ++++++ .../test/resources/prompts/system-message.st | 3 + pom.xml | 2 + spring-ai-bom/pom.xml | 11 + .../src/main/antora/modules/ROOT/nav.adoc | 2 + .../ROOT/pages/api/chat/qianfan-chat.adoc | 254 ++++++++ .../api/embeddings/qianfan-embeddings.adoc | 202 +++++++ spring-ai-spring-boot-autoconfigure/pom.xml | 7 + .../qianfan/QianFanAutoConfiguration.java | 102 ++++ .../qianfan/QianFanChatProperties.java | 62 ++ .../qianfan/QianFanConnectionProperties.java | 32 + .../qianfan/QianFanEmbeddingProperties.java | 68 +++ .../qianfan/QianFanParentProperties.java | 53 ++ .../qianfan/QianFanAutoConfigurationIT.java | 97 +++ .../qianfan/QianFanPropertiesTests.java | 299 ++++++++++ .../spring-ai-starter-qianfan/pom.xml | 42 ++ 32 files changed, 3186 insertions(+) create mode 100644 models/spring-ai-qianfan/README.md create mode 100644 models/spring-ai-qianfan/pom.xml create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatClient.java create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingClient.java create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/ApiUtils.java create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java create mode 100644 models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java create mode 100644 models/spring-ai-qianfan/src/main/resources/META-INF/spring/aot.factories create mode 100644 models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java create mode 100644 models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java create mode 100644 models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java create mode 100644 models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java create mode 100644 models/spring-ai-qianfan/src/test/resources/prompts/system-message.st create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java create mode 100644 spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml diff --git a/models/spring-ai-qianfan/README.md b/models/spring-ai-qianfan/README.md new file mode 100644 index 00000000000..4b895a70958 --- /dev/null +++ b/models/spring-ai-qianfan/README.md @@ -0,0 +1,3 @@ +[QianFan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/qianfan-chat.html) + +[QianFan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/qianfan-embeddings.html) \ No newline at end of file diff --git a/models/spring-ai-qianfan/pom.xml b/models/spring-ai-qianfan/pom.xml new file mode 100644 index 00000000000..8326c7f20fd --- /dev/null +++ b/models/spring-ai-qianfan/pom.xml @@ -0,0 +1,59 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-qianfan + jar + Spring AI QianFan + Baidu QianFan support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatClient.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatClient.java new file mode 100644 index 00000000000..f759d31a842 --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatClient.java @@ -0,0 +1,199 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.ChatClient; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.qianfan.api.QianFanApi; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletion; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionChunk; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage.Role; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionRequest; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import reactor.core.publisher.Flux; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * {@link ChatClient} and {@link StreamingChatClient} implementation for + * {@literal QianFan} backed by {@link QianFanApi}. + * + * @author Geng Rong + * @see ChatClient + * @see StreamingChatClient + * @see QianFanApi + */ +public class QianFanChatClient implements ChatClient, StreamingChatClient { + + private static final Logger logger = LoggerFactory.getLogger(QianFanChatClient.class); + + /** + * The default options used for the chat completion requests. + */ + private final QianFanChatOptions defaultOptions; + + /** + * The retry template used to retry the QianFan API calls. + */ + public final RetryTemplate retryTemplate; + + /** + * Low-level access to the QianFan API. + */ + private final QianFanApi qianFanApi; + + /** + * Creates an instance of the QianFanChatClient. + * @param qianFanApi The QianFanApi instance to be used for interacting with the + * QianFan Chat API. + * @throws IllegalArgumentException if QianFanApi is null + */ + public QianFanChatClient(QianFanApi qianFanApi) { + this(qianFanApi, + QianFanChatOptions.builder().withModel(QianFanApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build()); + } + + /** + * Initializes an instance of the QianFanChatClient. + * @param qianFanApi The QianFanApi instance to be used for interacting with the + * QianFan Chat API. + * @param options The QianFanChatOptions to configure the chat client. + */ + public QianFanChatClient(QianFanApi qianFanApi, QianFanChatOptions options) { + this(qianFanApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the QianFanChatClient. + * @param qianFanApi The QianFanApi instance to be used for interacting with the + * QianFan Chat API. + * @param options The QianFanChatOptions to configure the chat client. + * @param retryTemplate The retry template. + */ + public QianFanChatClient(QianFanApi qianFanApi, QianFanChatOptions options, RetryTemplate retryTemplate) { + Assert.notNull(qianFanApi, "QianFanApi must not be null"); + Assert.notNull(options, "Options must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.qianFanApi = qianFanApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public ChatResponse call(Prompt prompt) { + + ChatCompletionRequest request = createRequest(prompt, false); + + return this.retryTemplate.execute(ctx -> { + + ResponseEntity completionEntity = this.doChatCompletion(request); + + var chatCompletion = completionEntity.getBody(); + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + // if (chatCompletion.baseResponse() != null && + // chatCompletion.baseResponse().statusCode() != 0) { + // throw new RuntimeException(chatCompletion.baseResponse().message()); + // } + + var generation = new Generation(chatCompletion.result(), + Map.of("id", chatCompletion.id(), "role", Role.ASSISTANT)); + return new ChatResponse(Collections.singletonList(generation)); + }); + } + + @Override + public Flux stream(Prompt prompt) { + var request = createRequest(prompt, true); + + return retryTemplate.execute(ctx -> { + var completionChunks = this.qianFanApi.chatCompletionStream(request); + + return completionChunks.map(this::toChatCompletion).map(chatCompletion -> { + String id = chatCompletion.id(); + var generation = new Generation(chatCompletion.result(), Map.of("id", id, "role", Role.ASSISTANT)); + return new ChatResponse(Collections.singletonList(generation)); + }); + }); + } + + /** + * Convert the ChatCompletionChunk into a ChatCompletion. + * @param chunk the ChatCompletionChunk to convert + * @return the ChatCompletion + */ + private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) { + return new ChatCompletion(chunk.id(), chunk.object(), chunk.created(), chunk.result(), chunk.usage()); + } + + /** + * Accessible for testing. + */ + public ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + var chatCompletionMessages = prompt.getInstructions() + .stream() + .map(m -> new ChatCompletionMessage(m.getContent(), + ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))) + .toList(); + var systemMessageList = chatCompletionMessages.stream().filter(msg -> msg.role() == Role.SYSTEM).toList(); + + if (systemMessageList.size() > 1) { + throw new IllegalArgumentException("Only one system message is allowed in the prompt"); + } + + var systemMessage = systemMessageList.isEmpty() ? null : systemMessageList.get(0).content(); + + var request = new ChatCompletionRequest(chatCompletionMessages, systemMessage, stream); + + if (this.defaultOptions != null) { + request = ModelOptionsUtils.merge(this.defaultOptions, request, ChatCompletionRequest.class); + } + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { + var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class, + QianFanChatOptions.class); + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class); + } + else { + throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + + prompt.getOptions().getClass().getSimpleName()); + } + } + return request; + } + + private ResponseEntity doChatCompletion(ChatCompletionRequest request) { + return this.qianFanApi.chatCompletionEntity(request); + } + +} diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java new file mode 100644 index 00000000000..ecebdd18fc2 --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java @@ -0,0 +1,292 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.qianfan.api.QianFanApi; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +import java.util.List; + +/** + * QianFanChatOptions represents the options for performing chat completion using the + * QianFan API. It provides methods to set and retrieve various options like model, + * frequency penalty, max tokens, etc. + * + * @author Geng Rong + * @see ChatOptions + */ +@JsonInclude(Include.NON_NULL) +public class QianFanChatOptions implements ChatOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + private @JsonProperty("frequency_penalty") Float frequencyPenalty; + /** + * The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. + */ + private @JsonProperty("max_tokens") Integer maxTokens; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + private @JsonProperty("presence_penalty") Float presencePenalty; + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + */ + private @JsonProperty("response_format") QianFanApi.ChatCompletionRequest.ResponseFormat responseFormat; + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + @NestedConfigurationProperty + private @JsonProperty("stop") List stop; + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + */ + private @JsonProperty("temperature") Float temperature; + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. + */ + private @JsonProperty("top_p") Float topP; + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected QianFanChatOptions options; + + public Builder() { + this.options = new QianFanChatOptions(); + } + + public Builder(QianFanChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withResponseFormat(QianFanApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.options.responseFormat = responseFormat; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Float temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + this.options.topP = topP; + return this; + } + + public QianFanChatOptions build() { + return this.options; + } + + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public Float getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Float getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public QianFanApi.ChatCompletionRequest.ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(QianFanApi.ChatCompletionRequest.ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return this.topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + @JsonIgnore + public Integer getTopK() { + throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); + result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); + result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); + result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + QianFanChatOptions other = (QianFanChatOptions) obj; + if (this.model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (this.frequencyPenalty == null) { + if (other.frequencyPenalty != null) + return false; + } + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) + return false; + if (this.maxTokens == null) { + if (other.maxTokens != null) + return false; + } + else if (!this.maxTokens.equals(other.maxTokens)) + return false; + if (this.presencePenalty == null) { + if (other.presencePenalty != null) + return false; + } + else if (!this.presencePenalty.equals(other.presencePenalty)) + return false; + if (this.responseFormat == null) { + if (other.responseFormat != null) + return false; + } + else if (!this.responseFormat.equals(other.responseFormat)) + return false; + if (this.stop == null) { + if (other.stop != null) + return false; + } + else if (!stop.equals(other.stop)) + return false; + if (this.temperature == null) { + if (other.temperature != null) + return false; + } + else if (!this.temperature.equals(other.temperature)) + return false; + if (this.topP == null) { + if (other.topP != null) + return false; + } + else if (!topP.equals(other.topP)) + return false; + return true; + } + +} diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingClient.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingClient.java new file mode 100644 index 00000000000..abeb331587e --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingClient.java @@ -0,0 +1,156 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.AbstractEmbeddingClient; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.qianfan.api.QianFanApi; +import org.springframework.ai.qianfan.api.QianFanApi.EmbeddingList; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import java.util.List; + +/** + * QianFan Embedding Client implementation. + * + * @author Geng Rong + */ +public class QianFanEmbeddingClient extends AbstractEmbeddingClient { + + private static final Logger logger = LoggerFactory.getLogger(QianFanEmbeddingClient.class); + + private final QianFanEmbeddingOptions defaultOptions; + + private final RetryTemplate retryTemplate; + + private final QianFanApi qianFanApi; + + private final MetadataMode metadataMode; + + /** + * Constructor for the QianFanEmbeddingClient class. + * @param qianFanApi The QianFanApi instance to use for making API requests. + */ + public QianFanEmbeddingClient(QianFanApi qianFanApi) { + this(qianFanApi, MetadataMode.EMBED); + } + + /** + * Initializes a new instance of the QianFanEmbeddingClient class. + * @param qianFanApi The QianFanApi instance to use for making API requests. + * @param metadataMode The mode for generating metadata. + */ + public QianFanEmbeddingClient(QianFanApi qianFanApi, MetadataMode metadataMode) { + this(qianFanApi, metadataMode, + QianFanEmbeddingOptions.builder().withModel(QianFanApi.DEFAULT_EMBEDDING_MODEL).build(), + RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the QianFanEmbeddingClient class. + * @param qianFanApi The QianFanApi instance to use for making API requests. + * @param metadataMode The mode for generating metadata. + * @param qianFanEmbeddingOptions The options for QianFan embedding. + */ + public QianFanEmbeddingClient(QianFanApi qianFanApi, MetadataMode metadataMode, + QianFanEmbeddingOptions qianFanEmbeddingOptions) { + this(qianFanApi, metadataMode, qianFanEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the QianFanEmbeddingClient class. + * @param qianFanApi - The QianFanApi instance to use for making API requests. + * @param metadataMode - The mode for generating metadata. + * @param options - The options for QianFan embedding. + * @param retryTemplate - The RetryTemplate for retrying failed API requests. + */ + public QianFanEmbeddingClient(QianFanApi qianFanApi, MetadataMode metadataMode, QianFanEmbeddingOptions options, + RetryTemplate retryTemplate) { + Assert.notNull(qianFanApi, "QianFanApi must not be null"); + Assert.notNull(metadataMode, "metadataMode must not be null"); + Assert.notNull(options, "options must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + + this.qianFanApi = qianFanApi; + this.metadataMode = metadataMode; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public List embed(Document document) { + Assert.notNull(document, "Document must not be null"); + return this.embed(document.getFormattedContent(this.metadataMode)); + } + + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + + return this.retryTemplate.execute(ctx -> { + QianFanApi.EmbeddingRequest apiRequest = (this.defaultOptions != null) + ? new QianFanApi.EmbeddingRequest(request.getInstructions(), this.defaultOptions.getModel(), + this.defaultOptions.getUser()) + : new QianFanApi.EmbeddingRequest(request.getInstructions()); + + if (request.getOptions() != null && !EmbeddingOptions.EMPTY.equals(request.getOptions())) { + apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest, + QianFanApi.EmbeddingRequest.class); + } + + EmbeddingList apiEmbeddingResponse = this.qianFanApi.embeddings(apiRequest).getBody(); + + if (apiEmbeddingResponse == null) { + logger.warn("No embeddings returned for request: {}", request); + return new EmbeddingResponse(List.of()); + } + + if (apiEmbeddingResponse.errorNsg() != null) { + logger.error("Error message returned for request: {}", apiEmbeddingResponse.errorNsg()); + throw new RuntimeException("Embedding failed: error code:" + apiEmbeddingResponse.errorCode() + + ", message:" + apiEmbeddingResponse.errorNsg()); + } + + var metadata = generateResponseMetadata(apiEmbeddingResponse.model(), apiEmbeddingResponse.usage()); + + List embeddings = apiEmbeddingResponse.data() + .stream() + .map(e -> new Embedding(e.embedding(), e.index())) + .toList(); + + return new EmbeddingResponse(embeddings, metadata); + }); + } + + private EmbeddingResponseMetadata generateResponseMetadata(String model, QianFanApi.Usage usage) { + EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); + metadata.put("model", model); + metadata.put("prompt-tokens", usage.promptTokens()); + metadata.put("total-tokens", usage.totalTokens()); + return metadata; + } + +} diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java new file mode 100644 index 00000000000..707af674bd5 --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; + +/** + * This class represents the options for QianFan embedding. + * + * @author Geng Rong + */ +@JsonInclude(Include.NON_NULL) +public class QianFanEmbeddingOptions implements EmbeddingOptions { + + // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + + /** + * A unique identifier representing your end-user, which can help MoonshotAi to + * monitor and detect abuse. + */ + private @JsonProperty("user") String user; + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected QianFanEmbeddingOptions options; + + public Builder() { + this.options = new QianFanEmbeddingOptions(); + } + + public Builder withModel(String model) { + this.options.setModel(model); + return this; + } + + public Builder withUser(String user) { + this.options.setUser(user); + return this; + } + + public QianFanEmbeddingOptions build() { + return this.options; + } + + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + +} diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java new file mode 100644 index 00000000000..762734404f1 --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan.aot; + +import org.springframework.ai.qianfan.api.QianFanApi; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +/** + * The QianFanRuntimeHints class is responsible for registering runtime hints for QianFan + * API classes. + * + * @author Geng Rong + */ +public class QianFanRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClassesInPackage(QianFanApi.class)) + hints.reflection().registerType(tr, mcs); + } + +} diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/ApiUtils.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/ApiUtils.java new file mode 100644 index 00000000000..5e8c5e47744 --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/ApiUtils.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan.api; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; + +import java.util.function.Consumer; + +/** + * The ApiUtils class provides utility methods for working with API requests and + * responses. + * + * @author Geng Rong + */ +public class ApiUtils { + + public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com/rpc/2.0/ai_custom"; + + public static Consumer getJsonContentHeaders() { + return (headers) -> headers.setContentType(MediaType.APPLICATION_JSON); + } + +} diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java new file mode 100644 index 00000000000..00d001afebc --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java @@ -0,0 +1,550 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.qianfan.api.auth.QianFanAccessToken; +import org.springframework.ai.qianfan.api.auth.QianFanAuthenticator; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.List; +import java.util.function.Predicate; + +// @formatter:off +/** + * Single class implementation of the QianFan Chat Completion API and Embedding API. + * QianFan Docs + * + * @author Geng Rong + */ +public class QianFanApi { + + public static final String DEFAULT_CHAT_MODEL = ChatModel.ERNIE_Speed_8K.getValue(); + public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.BGE_LARGE_ZH.getValue(); + private static final Predicate SSE_DONE_PREDICATE = ChatCompletionChunk::end; + + private final QianFanAuthenticator authenticator; + + private final RestClient restClient; + + private final WebClient webClient; + + private QianFanAccessToken token; + + /** + * Create a new chat completion api with default base URL. + * + * @param apiKey QianFan api key. + * @param secretKey QianFan secret key. + */ + public QianFanApi(String apiKey, String secretKey) { + this(ApiUtils.DEFAULT_BASE_URL, apiKey, secretKey); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param apiKey QianFan api key. + * @param secretKey QianFan secret key. + */ + public QianFanApi(String baseUrl, String apiKey, String secretKey) { + this(baseUrl, apiKey, secretKey, RestClient.builder()); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param apiKey QianFan api key. + * @param secretKey QianFan secret key. + * @param restClientBuilder RestClient builder. + */ + public QianFanApi(String baseUrl, String apiKey, String secretKey, RestClient.Builder restClientBuilder) { + this(baseUrl, apiKey, secretKey, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param apiKey QianFan api key. + * @param secretKey QianFan secret key. + * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. + */ + public QianFanApi(String baseUrl, String apiKey, String secretKey, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + + this.restClient = restClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders()) + .defaultStatusHandler(responseErrorHandler) + .build(); + + this.webClient = WebClient.builder() + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders()) + .build(); + + this.authenticator = QianFanAuthenticator.builder() + .apiKey(apiKey) + .secretKey(secretKey) + .build(); + } + + /** + * QianFan Chat Completion Models: + * QianFan Model. + */ + public enum ChatModel { + ERNIE_4_0_8K("completions_pro"), + ERNIE_4_0_8K_Preview("ernie-4.0-8k-preview"), + ERNIE_4_0_8K_Preview_0518("completions_adv_pro"), + ERNIE_4_0_8K_0329("ernie-4.0-8k-0329"), + ERNIE_4_0_8K_0104("ernie-4.0-8k-0104"), + ERNIE_3_5_8K("completions"), + ERNIE_3_5_128K("ernie-3.5-128k"), + ERNIE_3_5_8K_Preview("ernie-3.5-8k-preview"), + ERNIE_3_5_8K_0205("ernie-3.5-8k-0205"), + ERNIE_3_5_8K_0329("ernie-3.5-8k-0329"), + ERNIE_3_5_8K_1222("ernie-3.5-8k-1222"), + ERNIE_3_5_4K_0205("ernie-3.5-4k-0205"), + + ERNIE_Lite_8K_0922("eb-instant"), + ERNIE_Lite_8K_0308("ernie-lite-8k"), + ERNIE_Speed_8K("ernie_speed"), + ERNIE_Speed_128K("ernie-speed-128k"), + ERNIE_Tiny_8K("ernie-tiny-8k"), + ERNIE_FUNC_8K("ernie-func-8k"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + /** + * Creates a model response for the given chat conversation. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param frequencyPenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + * @param maxTokens The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. + * appear in the text so far, increasing the model's likelihood to talk about new topics. + * @param responseFormat An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + * @param stop Up to 4 sequences where the API will stop generating further tokens. + * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events as + * they become available, with the stream terminated by a data: [DONE] message. + * @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + * @param topP An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionRequest ( + @JsonProperty("messages") List messages, + @JsonProperty("system") String system, + @JsonProperty("model") String model, + @JsonProperty("frequency_penalty") Float frequencyPenalty, + @JsonProperty("max_output_tokens") Integer maxTokens, + @JsonProperty("presence_penalty") Float presencePenalty, + @JsonProperty("response_format") ResponseFormat responseFormat, + @JsonProperty("stop") List stop, + @JsonProperty("stream") Boolean stream, + @JsonProperty("temperature") Float temperature, + @JsonProperty("top_p") Float topP) { + + /** + * Shortcut constructor for a chat completion request with the given messages and model. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param temperature What sampling temperature to use, between 0 and 1. + */ + public ChatCompletionRequest(List messages, String system, String model, Float temperature) { + this(messages, system, model, null,null, + null, null, null, false, temperature, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, model and control for streaming. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. + * @param temperature What sampling temperature to use, between 0 and 1. + * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events + * as they become available, with the stream terminated by a data: [DONE] message. + */ + public ChatCompletionRequest(List messages, String system, String model, Float temperature, boolean stream) { + this(messages, system, model, null,null, + null, null, null, stream, temperature, null); + } + + + /** + * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice. + * Streaming is set to false, temperature to 0.8 and all other parameters are null. + * + * @param messages A list of messages comprising the conversation so far. + * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events + * as they become available, with the stream terminated by a data: [DONE] message. + */ + public ChatCompletionRequest(List messages, String system, Boolean stream) { + this(messages, system, DEFAULT_CHAT_MODEL, null,null, + null, null, null, stream, 0.8F, null); + } + + /** + * An object specifying the format that the model must output. + * @param type Must be one of 'text' or 'json_object'. + */ + @JsonInclude(Include.NON_NULL) + public record ResponseFormat( + @JsonProperty("type") String type) { + } + } + + /** + * Message comprising the conversation. + * + * @param rawContent The contents of the message. Can be a {@link String}. + * The response message content is always a {@link String}. + * @param role The role of the messages author. Could be one of the {@link Role} types. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionMessage( + @JsonProperty("content") Object rawContent, + @JsonProperty("role") Role role) { + + /** + * Get message content as String. + */ + public String content() { + if (this.rawContent == null) { + return null; + } + if (this.rawContent instanceof String text) { + return text; + } + throw new IllegalStateException("The content is not a string!"); + } + + /** + * The role of the author of this message. + */ + public enum Role { + /** + * System message. + */ + @JsonProperty("system") SYSTEM, + /** + * User message. + */ + @JsonProperty("user") USER, + /** + * Assistant message. + */ + @JsonProperty("assistant") ASSISTANT + } + } + + /** + * Represents a chat completion response returned by model, based on the provided input. + * + * @param id A unique identifier for the chat completion. + * @param result Result of chat completion message. + * @param created The Unix timestamp (in seconds) of when the chat completion was created. + * used in conjunction with the seed request parameter to understand when backend changes have been made that might + * impact determinism. + * @param object The object type, which is always chat.completion. + * @param usage Usage statistics for the completion request. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletion( + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("result") String result, + @JsonProperty("usage") Usage usage) { + } + + /** + * Usage statistics for the completion request. + * + * @param promptTokens Number of tokens in the prompt. + * @param totalTokens Total number of tokens used in the request (prompt + completion). + */ + @JsonInclude(Include.NON_NULL) + public record Usage( + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("total_tokens") Integer totalTokens) { + + } + + /** + * Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + * + * @param id A unique identifier for the chat completion. Each chunk has the same ID. + * @param object The object type, which is always 'chat.completion.chunk'. + * @param created The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same + * timestamp. + * @param result Result of chat completion message. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionChunk( + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("result") String result, + @JsonProperty("is_end") Boolean end, + + @JsonProperty("usage") Usage usage + ) { + } + + private String getAccessToken() { + if(this.token == null || this.token.needsRefresh()) { + this.token = this.authenticator.requestToken(); + } + return this.token.getAccessToken(); + } + + /** + * Creates a model response for the given chat conversation. + * + * @param chatRequest The chat completion request. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!chatRequest.stream(), "Request must set the steam property to false."); + + return this.restClient.post() + .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * @param chatRequest The chat completion request. Must have the stream property set + * to true. + * @return Returns a {@link Flux} stream from chat completion chunks. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true."); + + return this.webClient.post() + .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(ChatCompletionChunk.class) + .takeUntil(SSE_DONE_PREDICATE); + } + + /** + * QianFan Embeddings Models: + * Embeddings. + */ + public enum EmbeddingModel { + + /** + * DIMENSION: 384 + */ + EMBEDDING_V1("embedding-v1"), + + /** + * DIMENSION: 1024 + */ + BGE_LARGE_ZH("bge_large_zh"), + + /** + * DIMENSION: 1024 + */ + BGE_LARGE_EN("bge_large_en"), + + /** + * DIMENSION: 1024 + */ + TAO_8K("tao_8k"); + + public final String value; + + EmbeddingModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + /** + * Creates an embedding vector representing the input text. + * + * @param texts Input text to embed, encoded as a string or array of tokens. + * @param user A unique identifier representing your end-user, which can help MoonshotAi to + * monitor and detect abuse. + */ + @JsonInclude(Include.NON_NULL) + public record EmbeddingRequest( + @JsonProperty("input") List texts, + @JsonProperty("model") String model, + @JsonProperty("user") String user + ) { + + + /** + * Create an embedding request with the given input. + * Embedding model is set to 'tao_8k'. + * @param text Input text to embed. + */ + public EmbeddingRequest(String text) { + this(List.of(text), DEFAULT_EMBEDDING_MODEL, null); + } + + + /** + * Create an embedding request with the given input. + * @param text Input text to embed. + * @param model ID of the model to use. + * @param userId A unique identifier representing your end-user, which can help MoonshotAi to + * monitor and detect abuse. + */ + public EmbeddingRequest(String text,String model,String userId) { + this(List.of(text), model, userId); + } + + /** + * Create an embedding request with the given input. + * Embedding model is set to 'tao_8k'. + * @param texts Input text to embed. + */ + public EmbeddingRequest(List texts) { + this(texts, DEFAULT_EMBEDDING_MODEL, null); + } + + /** + * Create an embedding request with the given input. + * @param texts Input text to embed. + * @param model ID of the model to use. + */ + public EmbeddingRequest(List texts, String model) { + this(texts, model, null); + } + } + + /** + * Represents an embedding vector returned by embedding endpoint. + * + * @param index The index of the embedding in the list of embeddings. + * @param embedding The embedding vector, which is a list of floats. The length of + * vector depends on the model. + * @param object The object type, which is always 'embedding'. + */ + @JsonInclude(Include.NON_NULL) + public record Embedding( + // @formatter:off + @JsonProperty("index") Integer index, + @JsonProperty("embedding") List embedding, + @JsonProperty("object") String object) { + // @formatter:on + + /** + * Create an embedding with the given index, embedding and object type set to + * 'embedding'. + * @param index The index of the embedding in the list of embeddings. + * @param embedding The embedding vector, which is a list of floats. The length of + * vector depends on the model. + */ + public Embedding(Integer index, List embedding) { + this(index, embedding, "embedding"); + } + } + + /** + * List of multiple embedding responses. + * + * @param object Must have value "embedding_list". + * @param data List of entities. + * @param model ID of the model to use. + * @param usage Usage statistics for the completion request. + */ + @JsonInclude(Include.NON_NULL) + public record EmbeddingList( + // @formatter:off + @JsonProperty("object") String object, + @JsonProperty("data") List data, + @JsonProperty("model") String model, + @JsonProperty("error_code") String errorCode, + @JsonProperty("error_msg") String errorNsg, + @JsonProperty("usage") Usage usage) { + // @formatter:on + } + + /** + * Creates an embedding vector representing the input text or token array. + * @param embeddingRequest The embedding request. + * @return Returns list of {@link Embedding} wrapped in {@link EmbeddingList}. + */ + public ResponseEntity embeddings(EmbeddingRequest embeddingRequest) { + + Assert.notNull(embeddingRequest, "The request body can not be null."); + + // Input text to embed, encoded as a string or array of tokens. To embed multiple + // inputs in a single + // request, pass an array of strings or array of token arrays. + Assert.notNull(embeddingRequest.texts(), "The input can not be null."); + + // The input must not an empty string, and any array must be 16 dimensions or + // less. + Assert.isTrue(!CollectionUtils.isEmpty(embeddingRequest.texts()), "The input list can not be empty."); + Assert.isTrue(embeddingRequest.texts().size() <= 16, "The list must be 16 dimensions or less"); + + return this.restClient.post() + .uri("/v1/wenxinworkshop/embeddings/{model}?access_token={token}", embeddingRequest.model, getAccessToken()) + .body(embeddingRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference<>() { + }); + } + +} +// @formatter:on diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java new file mode 100644 index 00000000000..af2f9831c06 --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/AccessTokenResponse.java @@ -0,0 +1,12 @@ +package org.springframework.ai.qianfan.api.auth; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * @author Geng Rong + */ +public record AccessTokenResponse(@JsonProperty("access_token") String accessToken, + @JsonProperty("refresh_token") String refreshToken, @JsonProperty("expires_in") Long expiresIn, + @JsonProperty("session_key") String sessionKey, @JsonProperty("session_secret") String sessionSecret, + @JsonProperty("error") String error, @JsonProperty("error_description") String errorDescription, String scope) { +} diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java new file mode 100644 index 00000000000..b1ec74f9273 --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java @@ -0,0 +1,70 @@ +package org.springframework.ai.qianfan.api.auth; + +/** + * @author Geng Rong + */ +public class QianFanAccessToken { + + private static final Double FRACTION_OF_TIME_TO_LIVE = 0.8D; + + private final String accessToken; + + private final String refreshToken; + + private final Long expiresIn; + + private final String sessionKey; + + private final String sessionSecret; + + private final String scope; + + private final Long refreshTime; + + public QianFanAccessToken(AccessTokenResponse accessTokenResponse) { + this.accessToken = accessTokenResponse.accessToken(); + this.refreshToken = accessTokenResponse.refreshToken(); + this.expiresIn = accessTokenResponse.expiresIn(); + this.sessionKey = accessTokenResponse.sessionKey(); + this.sessionSecret = accessTokenResponse.sessionSecret(); + this.scope = accessTokenResponse.scope(); + this.refreshTime = getCurrentTimeInSeconds() + (long) ((double) expiresIn * FRACTION_OF_TIME_TO_LIVE); + } + + public String getAccessToken() { + return accessToken; + } + + public String getRefreshToken() { + return refreshToken; + } + + public Long getExpiresIn() { + return expiresIn; + } + + public String getSessionKey() { + return sessionKey; + } + + public String getSessionSecret() { + return sessionSecret; + } + + public Long getRefreshTime() { + return refreshTime; + } + + public String getScope() { + return scope; + } + + public synchronized boolean needsRefresh() { + return getCurrentTimeInSeconds() >= this.refreshTime; + } + + private long getCurrentTimeInSeconds() { + return System.currentTimeMillis() / 1000L; + } + +} \ No newline at end of file diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java new file mode 100644 index 00000000000..9567dea7036 --- /dev/null +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAuthenticator.java @@ -0,0 +1,71 @@ +package org.springframework.ai.qianfan.api.auth; + +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestClient; + +/** + * @author Geng Rong + */ +public class QianFanAuthenticator { + + private static final String DEFAULT_AUTH_URL = "https://aip.baidubce.com"; + + private static final String OPERATION_PATH = "/oauth/2.0/token?client_id={clientId}&client_secret={clientSecret}&grant_type=client_credentials"; + + private final RestClient restClient; + + private final String apiKey; + + private final String secretKey; + + public QianFanAuthenticator(String authUrl, String apiKey, String secretKey) { + this.apiKey = apiKey; + this.secretKey = secretKey; + this.restClient = RestClient.builder().baseUrl(authUrl).build(); + } + + public QianFanAccessToken requestToken() { + ResponseEntity tokenResponseEntity = this.restClient.get() + .uri(OPERATION_PATH, apiKey, secretKey) + .retrieve() + .toEntity(AccessTokenResponse.class); + AccessTokenResponse tokenResponse = tokenResponseEntity.getBody(); + + if (tokenResponse == null) { + throw new IllegalArgumentException("Failed to get access token, response is null"); + } + + if (tokenResponse.error() != null) { + throw new IllegalArgumentException("Failed to get access token, error: " + tokenResponse.error() + + ", error_description: " + tokenResponse.errorDescription()); + } + return new QianFanAccessToken(tokenResponse); + } + + public static class Builder { + + private String apiKey; + + private String secretKey; + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder secretKey(String secretKey) { + this.secretKey = secretKey; + return this; + } + + public QianFanAuthenticator build() { + return new QianFanAuthenticator(DEFAULT_AUTH_URL, apiKey, secretKey); + } + + } + + public static Builder builder() { + return new Builder(); + } + +} diff --git a/models/spring-ai-qianfan/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-qianfan/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..4db5c1dc49c --- /dev/null +++ b/models/spring-ai-qianfan/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.qianfan.aot.QianFanRuntimeHints \ No newline at end of file diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java new file mode 100644 index 00000000000..7a708c79e68 --- /dev/null +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/ChatCompletionRequestTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.qianfan.api.QianFanApi; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +public class ChatCompletionRequestTests { + + @Test + public void createRequestWithChatOptions() { + + var client = new QianFanChatClient(new QianFanApi("TEST", "TEST"), + QianFanChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build()); + + var request = client.createRequest(new Prompt("Test message content"), false); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + + assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); + assertThat(request.temperature()).isEqualTo(66.6f); + + request = client.createRequest(new Prompt("Test message content", + QianFanChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()), true); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isTrue(); + + assertThat(request.model()).isEqualTo("PROMPT_MODEL"); + assertThat(request.temperature()).isEqualTo(99.9f); + } + +} diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java new file mode 100644 index 00000000000..c272f67643e --- /dev/null +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanTestConfiguration.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan; + +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.qianfan.api.QianFanApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +/** + * @author Geng Rong + */ +@SpringBootConfiguration +public class QianFanTestConfiguration { + + @Bean + public QianFanApi qianFanApi() { + return new QianFanApi(getApiKey(), getSecretKey()); + } + + private String getApiKey() { + String apiKey = System.getenv("QIANFAN_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name QIANFAN_API_KEY"); + } + return apiKey; + } + + private String getSecretKey() { + String apiKey = System.getenv("QIANFAN_SECRET_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide a secret key. Put it in an environment variable under the name QIANFAN_SECRET_KEY"); + } + return apiKey; + } + + @Bean + public QianFanChatClient qianFanChatClient(QianFanApi api) { + return new QianFanChatClient(api); + } + + @Bean + public EmbeddingClient qianFanEmbeddingClient(QianFanApi api) { + return new QianFanEmbeddingClient(api); + } + +} diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java new file mode 100644 index 00000000000..f3d07f554d0 --- /dev/null +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan.api; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.ResourceUtils; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletion; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionChunk; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage.Role; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionRequest; +import org.springframework.ai.qianfan.api.QianFanApi.EmbeddingList; +import org.springframework.http.ResponseEntity; +import org.stringtemplate.v4.ST; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Objects; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), + @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) +public class QianFanApiIT { + + QianFanApi qianFanApi = new QianFanApi(System.getenv("QIANFAN_API_KEY"), System.getenv("QIANFAN_SECRET_KEY")); + + @Test + void chatCompletionEntity() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + ResponseEntity response = qianFanApi.chatCompletionEntity(new ChatCompletionRequest( + List.of(chatCompletionMessage), buildSystemMessage(), "ernie_speed", 0.7f, false)); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void chatCompletionStream() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + Flux response = qianFanApi.chatCompletionStream(new ChatCompletionRequest( + List.of(chatCompletionMessage), buildSystemMessage(), "ernie_speed", 0.7f, true)); + + assertThat(response).isNotNull(); + assertThat(response.collectList().block()).isNotNull(); + } + + @Test + void embeddings() { + ResponseEntity response = qianFanApi.embeddings(new QianFanApi.EmbeddingRequest("Hello world")); + + assertThat(response).isNotNull(); + assertThat(Objects.requireNonNull(response.getBody()).data()).hasSize(1); + assertThat(response.getBody().data().get(0).embedding()).hasSize(1024); + } + + private String buildSystemMessage() { + String systemMessageTemplate = ResourceUtils.getText("classpath:/prompts/system-message.st"); + ST st = new ST(systemMessageTemplate, '{', '}'); + return st.add("name", "QianFan").add("voice", "pirate").render(); + } + +} diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java new file mode 100644 index 00000000000..0fd7372c65c --- /dev/null +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java @@ -0,0 +1,172 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.qianfan.api; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.qianfan.QianFanChatClient; +import org.springframework.ai.qianfan.QianFanChatOptions; +import org.springframework.ai.qianfan.QianFanEmbeddingClient; +import org.springframework.ai.qianfan.QianFanEmbeddingOptions; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletion; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionChunk; +import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionRequest; +import org.springframework.ai.qianfan.api.QianFanApi.EmbeddingList; +import org.springframework.ai.qianfan.api.QianFanApi.EmbeddingRequest; +import org.springframework.ai.qianfan.api.QianFanApi.Usage; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.when; + +/** + * @author Geng Rong + */ +@ExtendWith(MockitoExtension.class) +public class QianFanRetryTests { + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + onErrorRetryCount = context.getRetryCount(); + } + + } + + private TestRetryListener retryListener; + + private @Mock QianFanApi qianFanApi; + + private QianFanChatClient chatClient; + + private QianFanEmbeddingClient embeddingClient; + + @BeforeEach + public void beforeEach() { + RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + retryListener = new TestRetryListener(); + retryTemplate.registerListener(retryListener); + + chatClient = new QianFanChatClient(qianFanApi, QianFanChatOptions.builder().build(), retryTemplate); + embeddingClient = new QianFanEmbeddingClient(qianFanApi, MetadataMode.EMBED, + QianFanEmbeddingOptions.builder().build(), retryTemplate); + } + + @Test + public void qianFanChatTransientError() { + ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 666L, "Response", + new Usage(10, 10)); + + when(qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + + var result = chatClient.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void qianFanChatNonTransientError() { + when(qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text"))); + } + + @Test + public void qianFanChatStreamTransientError() { + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion", 666L, "Response", + true, null); + + when(qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(Flux.just(expectedChatCompletion)); + + var result = chatClient.stream(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(Objects.requireNonNull(result.collectList().block()).get(0).getResult().getOutput().getContent()) + .isSameAs("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void qianFanChatStreamNonTransientError() { + when(qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text"))); + } + + @Test + public void qianFanEmbeddingTransientError() { + QianFanApi.Embedding embedding = new QianFanApi.Embedding(1, List.of(9.9, 8.8)); + EmbeddingList expectedEmbeddings = new EmbeddingList("embedding_list", List.of(embedding), "model", null, null, + new Usage(10, 10)); + + when(qianFanApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + + var result = embeddingClient + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8)); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void qianFanEmbeddingNonTransientError() { + when(qianFanApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> embeddingClient.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); + } + +} diff --git a/models/spring-ai-qianfan/src/test/resources/prompts/system-message.st b/models/spring-ai-qianfan/src/test/resources/prompts/system-message.st new file mode 100644 index 00000000000..579febd8d9b --- /dev/null +++ b/models/spring-ai-qianfan/src/test/resources/prompts/system-message.st @@ -0,0 +1,3 @@ +You are an AI assistant that helps people find information. +Your name is {name}. +You should reply to the user's request with your name and also in the style of a {voice}. \ No newline at end of file diff --git a/pom.xml b/pom.xml index 5f583cbb364..2037dadc8e6 100644 --- a/pom.xml +++ b/pom.xml @@ -62,6 +62,7 @@ models/spring-ai-ollama models/spring-ai-openai models/spring-ai-postgresml + models/spring-ai-qianfan models/spring-ai-stability-ai models/spring-ai-transformers models/spring-ai-vertex-ai-gemini @@ -77,6 +78,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-ollama spring-ai-spring-boot-starters/spring-ai-starter-openai spring-ai-spring-boot-starters/spring-ai-starter-postgresml-embedding + spring-ai-spring-boot-starters/spring-ai-starter-qianfan spring-ai-spring-boot-starters/spring-ai-starter-stability-ai spring-ai-spring-boot-starters/spring-ai-starter-transformers spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index 84615ecc2d4..f6294f49a6d 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -130,6 +130,11 @@ ${project.version} + + org.springframework.ai + spring-ai-qianfan + ${project.version} + @@ -408,6 +413,12 @@ spring-ai-zhipuai-spring-boot-starter ${project.version} + + + org.springframework.ai + spring-ai-qianfan-spring-boot-starter + ${project.version} + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 62d2a59aec7..dfe8eb40058 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -29,6 +29,7 @@ *** xref:api/chat/watsonx-ai-chat.adoc[Watsonx.AI] *** xref:api/chat/minimax-chat.adoc[MiniMax] **** xref:api/chat/functions/minimax-chat-functions.adoc[Function Calling] +*** xref:api/chat/qianfan-chat.adoc[QianFan] ** xref:api/embeddings.adoc[] *** xref:api/embeddings/openai-embeddings.adoc[OpenAI] *** xref:api/embeddings/ollama-embeddings.adoc[Ollama] @@ -42,6 +43,7 @@ *** xref:api/embeddings/mistralai-embeddings.adoc[Mistral AI] *** xref:api/embeddings/minimax-embeddings.adoc[MiniMax] *** xref:api/embeddings/zhipuai-embeddings.adoc[ZhiPu AI] +*** xref:api/embeddings/qianfan-embeddings.adoc[QianFan] ** xref:api/imageclient.adoc[] *** xref:api/image/openai-image.adoc[OpenAI] *** xref:api/image/stabilityai-image.adoc[Stability] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc new file mode 100644 index 00000000000..d0911d2d4c7 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/qianfan-chat.adoc @@ -0,0 +1,254 @@ += QianFan Chat + +Spring AI supports the various AI language models from QianFan. You can interact with QianFan language models and create a multilingual conversational assistant based on QianFan models. + +== Prerequisites + +You will need to create an API with QianFan to access QianFan language models. + +Create an account at https://login.bce.baidu.com/new-reg[QianFan registration page] and generate the token on the https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application[API Keys page]. +The Spring AI project defines a configuration property named `spring.ai.qianfan.api-key` and `spring.ai.qianfan.secret-key`. +you should set to the value of the `API Key` and `Secret Key` obtained from https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application[API Keys page]. +Exporting an environment variable is one way to set that configuration property: + +[source,shell] +---- +export SPRING_AI_QIANFAN_API_KEY= +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. +Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the QianFan Chat Client. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-qianfan-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-qianfan-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Chat Properties + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the QianFan Chat client. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.qianfan` is used as the property prefix that lets you connect to QianFan. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.qianfan.base-url | The URL to connect to | https://api.qianfan.chat +| spring.ai.qianfan.api-key | The API Key | - +| spring.ai.qianfan.secret-key | The Secret Key | - +|==== + +==== Configuration Properties + +The prefix `spring.ai.qianfan.chat` is the property prefix that lets you configure the chat client implementation for QianFan. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.qianfan.chat.enabled | Enable QianFan chat client. | true +| spring.ai.qianfan.chat.base-url | Optional overrides the spring.ai.qianfan.base-url to provide chat specific url | https://api.qianfan.chat +| spring.ai.qianfan.chat.api-key | Optional overrides the spring.ai.qianfan.api-key to provide chat specific api-key | - +| spring.ai.qianfan.chat.secret-key | Optional overrides the spring.ai.qianfan.secret-key to provide chat specific secret-key | - +| spring.ai.qianfan.chat.options.model | This is the QianFan Chat model to use | `abab5.5-chat` (the `abab5.5s-chat`, `abab5.5-chat`, and `abab6-chat` point to the latest model versions) +| spring.ai.qianfan.chat.options.maxTokens | The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. | - +| spring.ai.qianfan.chat.options.temperature | The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make output more random while lower values will make results more focused and deterministic. It is not recommended to modify temperature and top_p for the same completions request as the interaction of these two settings is difficult to predict. | 0.7 +| spring.ai.qianfan.chat.options.topP | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. | 1.0 +| spring.ai.qianfan.chat.options.presencePenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. | 0.0f +| spring.ai.qianfan.chat.options.frequencyPenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | 0.0f +| spring.ai.qianfan.chat.options.stop | The model will stop generating characters specified by stop, and currently only supports a single stop word in the format of ["stop_word1"] | - +|==== + +NOTE: You can override the common `spring.ai.qianfan.base-url`, `spring.ai.qianfan.chat.api-key` and `spring.ai.qianfan.chat.secret-key` for the `ChatClient` implementations. +The `spring.ai.qianfan.chat.base-url`, `spring.ai.qianfan.chat.api-key` and `spring.ai.qianfan.chat.secret-key` properties if set take precedence over the common properties. +This is useful if you want to use different QianFan accounts for different models and different model endpoints. + +TIP: All properties prefixed with `spring.ai.qianfan.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. + +== Runtime Options [[chat-options]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java[QianFanChatOptions.java] provides model configurations, such as the model to use, the temperature, the frequency penalty, etc. + +On start-up, the default options can be configured with the `QianFanChatClient(api, options)` constructor or the `spring.ai.qianfan.chat.options.*` properties. + +At run-time you can override the default options by adding new, request specific, options to the `Prompt` call. +For example to override the default model and temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatClient.call( + new Prompt( + "Generate the names of 5 famous pirates.", + QianFanChatOptions.builder() + .withModel(QianFanApi.ChatModel.ERNIE_Speed_8K.getValue()) + .withTemperature(0.5f) + .build() + )); +---- + +TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java[QianFanChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. + +== Sample Controller + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-qianfan-spring-boot-starter` to your pom (or gradle) dependencies. + +Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the QianFan Chat client: + +[source,application.properties] +---- +spring.ai.qianfan.api-key=YOUR_API_KEY +spring.ai.qianfan.secret-key=YOUR_SECRET_KEY +spring.ai.qianfan.chat.options.model=ernie_speed +spring.ai.qianfan.chat.options.temperature=0.7 +---- + +TIP: replace the `api-key` and `secret-key` with your QianFan credentials. + +This will create a `QianFanChatClient` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the chat client for text generations. + +[source,java] +---- +@RestController +public class ChatController { + + private final QianFanChatClient chatClient; + + @Autowired + public ChatController(QianFanChatClient chatClient) { + this.chatClient = chatClient; + } + + @GetMapping("/ai/generate") + public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + return Map.of("generation", chatClient.call(message)); + } + + @GetMapping("/ai/generateStream") + public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + var prompt = new Prompt(new UserMessage(message)); + return chatClient.stream(prompt); + } +} +---- + +== Manual Configuration + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatClient.java[QianFanChatClient] implements the `ChatClient` and `StreamingChatClient` and uses the <> to connect to the QianFan service. + +Add the `spring-ai-qianfan` dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-qianfan + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-qianfan' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +Next, create a `QianFanChatClient` and use it for text generations: + +[source,java] +---- +var qianFanApi = new QianFanApi(System.getenv("QIANFAN_API_KEY"), System.getenv("QIANFAN_SECRET_KEY")); + +var chatClient = new QianFanChatClient(qianFanApi, QianFanChatOptions.builder() + .withModel(QianFanApi.ChatModel.ERNIE_Speed_8K.getValue()) + .withTemperature(0.4f) + .withMaxTokens(200) + .build()); + +ChatResponse response = chatClient.call( + new Prompt("Generate the names of 5 famous pirates.")); + +// Or with streaming responses +Flux streamResponse = chatClient.stream( + new Prompt("Generate the names of 5 famous pirates.")); +---- + +The `QianFanChatOptions` provides the configuration information for the chat requests. +The `QianFanChatOptions.Builder` is fluent options builder. + +=== Low-level QianFanApi Client [[low-level-api]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java[QianFanApi] provides is lightweight Java client for link:https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2[QianFan API]. + +Here is a simple snippet how to use the api programmatically: + +[source,java] +---- +String systemMessage = "Your name is QianWen"; + +QianFanApi qianFanApi = + new QianFanApi(System.getenv("QIANFAN_API_KEY"), System.getenv("QIANFAN_SECRET_KEY")); + +ChatCompletionMessage chatCompletionMessage = + new ChatCompletionMessage("Hello world", Role.USER); + +// Sync request +ResponseEntity response = qianFanApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(chatCompletionMessage), systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7f, false)); + +// Streaming request +Flux streamResponse = qianFanApi.chatCompletionStream( + new ChatCompletionRequest(List.of(chatCompletionMessage), systemMessage, QianFanApi.ChatModel.ERNIE_Speed_8K.getValue(), 0.7f, true)); +---- + +Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java[QianFanApi.java]'s JavaDoc for further information. + +==== QianFanApi Samples +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java[QianFanApiIT.java] test provides some general examples how to use the lightweight library. \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc new file mode 100644 index 00000000000..e626b061953 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/qianfan-embeddings.adoc @@ -0,0 +1,202 @@ += QianFan Chat + +Spring AI supports the various AI language models from QianFan. You can interact with QianFan language models and create a multilingual conversational assistant based on QianFan models. + +== Prerequisites + +You will need to create an API with QianFan to access QianFan language models. + +Create an account at https://login.bce.baidu.com/new-reg[QianFan registration page] and generate the token on the https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application[API Keys page]. +The Spring AI project defines a configuration property named `spring.ai.qianfan.api-key` and `spring.ai.qianfan.secret-key`. +you should set to the value of the `API Key` and `Secret Key` obtained from https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application[API Keys page]. +Exporting an environment variable is one way to set that configuration property: + +[source,shell] +---- +export SPRING_AI_QIANFAN_API_KEY= +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. +Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the Azure QianFan Embedding Client. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-qianfan-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-qianfan-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Embedding Properties + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the QianFan Embedding client. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.qianfan` is used as the property prefix that lets you connect to QianFan. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.qianfan.base-url | The URL to connect to | https://aip.baidubce.com/rpc/2.0/ai_custom +| spring.ai.qianfan.api-key | The API Key | - +| spring.ai.qianfan.secret-key | The Secret Key | - +|==== + +==== Configuration Properties + +The prefix `spring.ai.qianfan.embedding` is property prefix that configures the `EmbeddingClient` implementation for QianFan. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.qianfan.embedding.enabled | Enable QianFan embedding client. | true +| spring.ai.qianfan.embedding.base-url | Optional overrides the spring.ai.qianfan.base-url to provide embedding specific url | - +| spring.ai.qianfan.embedding.api-key | Optional overrides the spring.ai.qianfan.api-key to provide embedding specific api-key | - +| spring.ai.qianfan.embedding.secret-key | Optional overrides the spring.ai.qianfan.secret-key to provide embedding specific secret-key | - +| spring.ai.qianfan.embedding.options.model | The model to use | bge_large_zh +|==== + +NOTE: You can override the common `spring.ai.qianfan.base-url`, `spring.ai.qianfan.embedding.api-key` and `spring.ai.qianfan.embedding.secret-key` for the `ChatClient` and `EmbeddingClient` implementations. +The `spring.ai.qianfan.embedding.base-url`, `spring.ai.qianfan.embedding.api-key` and `spring.ai.qianfan.embedding.secret-key` properties if set take precedence over the common properties. +Similarly, the `spring.ai.qianfan.embedding.base-url`, `spring.ai.qianfan.embedding.api-key` and `spring.ai.qianfan.embedding.secret-key` properties if set take precedence over the common properties. +This is useful if you want to use different QianFan accounts for different models and different model endpoints. + +TIP: All properties prefixed with `spring.ai.qianfan.embedding.options` can be overridden at runtime by adding a request specific <> to the `EmbeddingRequest` call. + +== Runtime Options [[embedding-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingOptions.java[QianFanEmbeddingOptions.java] provides the QianFan configurations, such as the model to use and etc. + +The default options can be configured using the `spring.ai.qianfan.embedding.options` properties as well. + +At start-time use the `QianFanEmbeddingClient` constructor to set the default options used for all embedding requests. +At run-time you can override the default options, using a `QianFanEmbeddingOptions` instance as part of your `EmbeddingRequest`. + +For example to override the default model name for a specific request: + +[source,java] +---- +EmbeddingResponse embeddingResponse = embeddingClient.call( + new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), + QianFanEmbeddingOptions.builder() + .withModel("Different-Embedding-Model-Deployment-Name") + .build())); +---- + +== Sample Controller + +This will create a `EmbeddingClient` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the `EmbeddingClient` implementation. + +[source,application.properties] +---- +spring.ai.qianfan.api-key=YOUR_API_KEY +spring.ai.qianfan.secret-key=YOUR_SECRET_KEY +spring.ai.qianfan.embedding.options.model=tao_8k +---- + +[source,java] +---- +@RestController +public class EmbeddingController { + + private final EmbeddingClient embeddingClient; + + @Autowired + public EmbeddingController(EmbeddingClient embeddingClient) { + this.embeddingClient = embeddingClient; + } + + @GetMapping("/ai/embedding") + public Map embed(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + EmbeddingResponse embeddingResponse = this.embeddingClient.embedForResponse(List.of(message)); + return Map.of("embedding", embeddingResponse); + } +} +---- + +== Manual Configuration + +If you are not using Spring Boot, you can manually configure the QianFan Embedding Client. +For this add the `spring-ai-qianfan` dependency to your project's Maven `pom.xml` file: +[source, xml] +---- + + org.springframework.ai + spring-ai-qianfan + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-qianfan' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +NOTE: The `spring-ai-qianfan` dependency provides access also to the `QianFanChatClient`. +For more information about the `QianFanChatClient` refer to the link:../chat/qianfan-chat.html[QianFan Chat Client] section. + +Next, create an `QianFanEmbeddingClient` instance and use it to compute the similarity between two input texts: + +[source,java] +---- +var qianFanApi = new QianFanApi(System.getenv("MINIMAX_API_KEY"), System.getenv("QIANFAN_SECRET_KEY")); + +var embeddingClient = new QianFanEmbeddingClient(qianFanApi) + .withDefaultOptions(QianFanChatOptions.build() + .withModel("bge_large_en") + .build()); + +EmbeddingResponse embeddingResponse = embeddingClient + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); +---- + +The `QianFanEmbeddingOptions` provides the configuration information for the embedding requests. +The options class offers a `builder()` for easy options creation. + + diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 84bd54a188a..fcff112bd20 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -281,6 +281,13 @@ true + + org.springframework.ai + spring-ai-qianfan + ${project.parent.version} + true + + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java new file mode 100644 index 00000000000..4bf8f0f6a7b --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfiguration.java @@ -0,0 +1,102 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.qianfan; + +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.qianfan.QianFanChatClient; +import org.springframework.ai.qianfan.QianFanEmbeddingClient; +import org.springframework.ai.qianfan.api.QianFanApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +/** + * @author Geng Rong + */ +@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) +@ConditionalOnClass(QianFanApi.class) +@EnableConfigurationProperties({ QianFanConnectionProperties.class, QianFanChatProperties.class, + QianFanEmbeddingProperties.class }) +public class QianFanAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = QianFanChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public QianFanChatClient qianFanChatClient(QianFanConnectionProperties commonProperties, + QianFanChatProperties chatProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler) { + + var qianFanApi = qianFanApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), + chatProperties.getApiKey(), commonProperties.getApiKey(), chatProperties.getSecretKey(), + commonProperties.getSecretKey(), restClientBuilder, responseErrorHandler); + + return new QianFanChatClient(qianFanApi, chatProperties.getOptions(), retryTemplate); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = QianFanEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public QianFanEmbeddingClient qianFanEmbeddingClient(QianFanConnectionProperties commonProperties, + QianFanEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { + + var qianFanApi = qianFanApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), + embeddingProperties.getApiKey(), commonProperties.getApiKey(), embeddingProperties.getSecretKey(), + commonProperties.getSecretKey(), restClientBuilder, responseErrorHandler); + + return new QianFanEmbeddingClient(qianFanApi, embeddingProperties.getMetadataMode(), + embeddingProperties.getOptions(), retryTemplate); + } + + private QianFanApi qianFanApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey, + String secretKey, String commonSecretKey, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { + + String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; + Assert.hasText(resolvedBaseUrl, "QianFan base URL must be set"); + + String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; + Assert.hasText(resolvedApiKey, "QianFan API key must be set"); + + String resolvedSecretKey = StringUtils.hasText(secretKey) ? secretKey : commonSecretKey; + Assert.hasText(resolvedSecretKey, "QianFan Secret key must be set"); + + return new QianFanApi(resolvedBaseUrl, resolvedApiKey, resolvedSecretKey, restClientBuilder, + responseErrorHandler); + } + + @Bean + @ConditionalOnMissingBean + public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) { + FunctionCallbackContext manager = new FunctionCallbackContext(); + manager.setApplicationContext(context); + return manager; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java new file mode 100644 index 00000000000..31e631a5009 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanChatProperties.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.qianfan; + +import org.springframework.ai.qianfan.QianFanChatOptions; +import org.springframework.ai.qianfan.api.QianFanApi; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * @author Geng Rong + */ +@ConfigurationProperties(QianFanChatProperties.CONFIG_PREFIX) +public class QianFanChatProperties extends QianFanParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.qianfan.chat"; + + public static final String DEFAULT_CHAT_MODEL = QianFanApi.ChatModel.ERNIE_Speed_8K.value; + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + /** + * Enable QianFan chat client. + */ + private boolean enabled = true; + + @NestedConfigurationProperty + private QianFanChatOptions options = QianFanChatOptions.builder() + .withModel(DEFAULT_CHAT_MODEL) + .withTemperature(DEFAULT_TEMPERATURE.floatValue()) + .build(); + + public QianFanChatOptions getOptions() { + return options; + } + + public void setOptions(QianFanChatOptions options) { + this.options = options; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java new file mode 100644 index 00000000000..0a358844839 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanConnectionProperties.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.qianfan; + +import org.springframework.ai.qianfan.api.ApiUtils; +import org.springframework.boot.context.properties.ConfigurationProperties; + +@ConfigurationProperties(QianFanConnectionProperties.CONFIG_PREFIX) +public class QianFanConnectionProperties extends QianFanParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.qianfan"; + + public static final String DEFAULT_BASE_URL = ApiUtils.DEFAULT_BASE_URL; + + public QianFanConnectionProperties() { + super.setBaseUrl(DEFAULT_BASE_URL); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java new file mode 100644 index 00000000000..2a235110097 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanEmbeddingProperties.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.qianfan; + +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.qianfan.QianFanEmbeddingOptions; +import org.springframework.ai.qianfan.api.QianFanApi; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * @author Geng Rong + */ +@ConfigurationProperties(QianFanEmbeddingProperties.CONFIG_PREFIX) +public class QianFanEmbeddingProperties extends QianFanParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.qianfan.embedding"; + + /** + * Enable QianFan embedding client. + */ + private boolean enabled = true; + + private MetadataMode metadataMode = MetadataMode.EMBED; + + @NestedConfigurationProperty + private QianFanEmbeddingOptions options = QianFanEmbeddingOptions.builder() + .withModel(QianFanApi.DEFAULT_EMBEDDING_MODEL) + .build(); + + public QianFanEmbeddingOptions getOptions() { + return this.options; + } + + public void setOptions(QianFanEmbeddingOptions options) { + this.options = options; + } + + public MetadataMode getMetadataMode() { + return this.metadataMode; + } + + public void setMetadataMode(MetadataMode metadataMode) { + this.metadataMode = metadataMode; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java new file mode 100644 index 00000000000..109cc279bdc --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/qianfan/QianFanParentProperties.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.qianfan; + +/** + * @author Geng Rong + */ +class QianFanParentProperties { + + private String apiKey; + + private String secretKey; + + private String baseUrl; + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public String getSecretKey() { + return secretKey; + } + + public void setSecretKey(String secretKey) { + this.secretKey = secretKey; + } + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java new file mode 100644 index 00000000000..607259735f4 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java @@ -0,0 +1,97 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.qianfan; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.qianfan.QianFanChatClient; +import org.springframework.ai.qianfan.QianFanEmbeddingClient; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), + @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) +public class QianFanAutoConfigurationIT { + + private static final Log logger = LogFactory.getLog(QianFanAutoConfigurationIT.class); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.qianfan.apiKey=" + System.getenv("QIANFAN_API_KEY"), + "spring.ai.qianfan.secretKey=" + System.getenv("QIANFAN_SECRET_KEY")) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)); + + @Test + void generate() { + contextRunner.run(context -> { + QianFanChatClient client = context.getBean(QianFanChatClient.class); + String response = client.call("Hello"); + assertThat(response).isNotEmpty(); + logger.info("Response: " + response); + }); + } + + @Test + void generateStreaming() { + contextRunner.run(context -> { + QianFanChatClient client = context.getBean(QianFanChatClient.class); + Flux responseFlux = client.stream(new Prompt(new UserMessage("Hello"))); + String response = Objects.requireNonNull(responseFlux.collectList().block()) + .stream() + .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getContent()) + .collect(Collectors.joining()); + assertThat(response).isNotEmpty(); + logger.info("Response: " + response); + }); + } + + @Test + void embedding() { + contextRunner.run(context -> { + QianFanEmbeddingClient embeddingClient = context.getBean(QianFanEmbeddingClient.class); + + EmbeddingResponse embeddingResponse = embeddingClient + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); + + assertThat(embeddingClient.dimensions()).isEqualTo(1024); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java new file mode 100644 index 00000000000..22fbc660aee --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanPropertiesTests.java @@ -0,0 +1,299 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.qianfan; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.qianfan.QianFanChatClient; +import org.springframework.ai.qianfan.QianFanEmbeddingClient; +import org.springframework.ai.qianfan.api.QianFanApi; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for + * {@link org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties}, + * {@link org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties} and + * {@link org.springframework.ai.autoconfigure.qianfan.QianFanEmbeddingProperties}. + * + * @author Geng Rong + */ +public class QianFanPropertiesTests { + + @Test + public void chatProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.qianfan.base-url=TEST_BASE_URL", + "spring.ai.qianfan.api-key=abc123", + "spring.ai.qianfan.secret-key=def123", + "spring.ai.qianfan.chat.options.model=MODEL_XYZ", + "spring.ai.qianfan.chat.options.temperature=0.55") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(QianFanChatProperties.class); + var connectionProperties = context.getBean(QianFanConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getSecretKey()).isEqualTo("def123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(chatProperties.getApiKey()).isNull(); + assertThat(chatProperties.getBaseUrl()).isNull(); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + }); + } + + @Test + public void chatOverrideConnectionProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.qianfan.base-url=TEST_BASE_URL", + "spring.ai.qianfan.api-key=abc123", + "spring.ai.qianfan.secret-key=def123", + "spring.ai.qianfan.chat.base-url=TEST_BASE_URL2", + "spring.ai.qianfan.chat.api-key=456", + "spring.ai.qianfan.chat.secret-key=def456", + "spring.ai.qianfan.chat.options.model=MODEL_XYZ", + "spring.ai.qianfan.chat.options.temperature=0.55") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(QianFanChatProperties.class); + var connectionProperties = context.getBean(QianFanConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getSecretKey()).isEqualTo("def123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(chatProperties.getApiKey()).isEqualTo("456"); + assertThat(chatProperties.getSecretKey()).isEqualTo("def456"); + assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + }); + } + + @Test + public void embeddingProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.qianfan.base-url=TEST_BASE_URL", + "spring.ai.qianfan.api-key=abc123", + "spring.ai.qianfan.secret-key=def123", + "spring.ai.qianfan.embedding.options.model=MODEL_XYZ") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + var embeddingProperties = context.getBean(QianFanEmbeddingProperties.class); + var connectionProperties = context.getBean(QianFanConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getSecretKey()).isEqualTo("def123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(embeddingProperties.getApiKey()).isNull(); + assertThat(embeddingProperties.getBaseUrl()).isNull(); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + public void embeddingOverrideConnectionProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.qianfan.base-url=TEST_BASE_URL", + "spring.ai.qianfan.api-key=abc123", + "spring.ai.qianfan.secret-key=def123", + "spring.ai.qianfan.embedding.base-url=TEST_BASE_URL2", + "spring.ai.qianfan.embedding.api-key=456", + "spring.ai.qianfan.embedding.secret-key=def456", + "spring.ai.qianfan.embedding.options.model=MODEL_XYZ") + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + var embeddingProperties = context.getBean(QianFanEmbeddingProperties.class); + var connectionProperties = context.getBean(QianFanConnectionProperties.class); + + assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); + assertThat(connectionProperties.getSecretKey()).isEqualTo("def123"); + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + + assertThat(embeddingProperties.getApiKey()).isEqualTo("456"); + assertThat(embeddingProperties.getSecretKey()).isEqualTo("def456"); + assertThat(embeddingProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + public void chatOptionsTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.qianfan.api-key=API_KEY", + "spring.ai.qianfan.secret-key=SECRET_KEY", + "spring.ai.qianfan.base-url=TEST_BASE_URL", + + "spring.ai.qianfan.chat.options.model=MODEL_XYZ", + "spring.ai.qianfan.chat.options.frequencyPenalty=-1.5", + "spring.ai.qianfan.chat.options.logitBias.myTokenId=-5", + "spring.ai.qianfan.chat.options.maxTokens=123", + "spring.ai.qianfan.chat.options.presencePenalty=0", + "spring.ai.qianfan.chat.options.responseFormat.type=json", + "spring.ai.qianfan.chat.options.stop=boza,koza", + "spring.ai.qianfan.chat.options.temperature=0.55", + "spring.ai.qianfan.chat.options.topP=0.56" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + var chatProperties = context.getBean(QianFanChatProperties.class); + var connectionProperties = context.getBean(QianFanConnectionProperties.class); + var embeddingProperties = context.getBean(QianFanEmbeddingProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + assertThat(connectionProperties.getSecretKey()).isEqualTo("SECRET_KEY"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("bge_large_zh"); + + assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + assertThat(chatProperties.getOptions().getFrequencyPenalty()).isEqualTo(-1.5f); + assertThat(chatProperties.getOptions().getMaxTokens()).isEqualTo(123); + assertThat(chatProperties.getOptions().getPresencePenalty()).isEqualTo(0); + assertThat(chatProperties.getOptions().getResponseFormat()) + .isEqualTo(new QianFanApi.ChatCompletionRequest.ResponseFormat("json")); + assertThat(chatProperties.getOptions().getStop()).contains("boza", "koza"); + assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56f); + }); + } + + @Test + public void embeddingOptionsTest() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.qianfan.api-key=API_KEY", + "spring.ai.qianfan.secret-key=SECRET_KEY", + "spring.ai.qianfan.base-url=TEST_BASE_URL", + + "spring.ai.qianfan.embedding.options.model=MODEL_XYZ", + "spring.ai.qianfan.embedding.options.encodingFormat=MyEncodingFormat" + ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + var connectionProperties = context.getBean(QianFanConnectionProperties.class); + var embeddingProperties = context.getBean(QianFanEmbeddingProperties.class); + + assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + assertThat(connectionProperties.getSecretKey()).isEqualTo("SECRET_KEY"); + + assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); + }); + } + + @Test + void embeddingActivation() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.qianfan.api-key=API_KEY", "spring.ai.qianfan.secret-key=SECRET_KEY", + "spring.ai.qianfan.base-url=TEST_BASE_URL", "spring.ai.qianfan.embedding.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(QianFanEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(QianFanEmbeddingClient.class)).isEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.qianfan.api-key=API_KEY", "spring.ai.qianfan.secret-key=SECRET_KEY", + "spring.ai.qianfan.base-url=TEST_BASE_URL") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(QianFanEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(QianFanEmbeddingClient.class)).isNotEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.qianfan.api-key=API_KEY", "spring.ai.qianfan.secret-key=SECRET_KEY", + "spring.ai.qianfan.base-url=TEST_BASE_URL", "spring.ai.qianfan.embedding.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(QianFanEmbeddingProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(QianFanEmbeddingClient.class)).isNotEmpty(); + }); + } + + @Test + void chatActivation() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.qianfan.api-key=API_KEY", "spring.ai.qianfan.secret-key=SECRET_KEY", + "spring.ai.qianfan.base-url=TEST_BASE_URL", "spring.ai.qianfan.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(QianFanChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(QianFanChatClient.class)).isEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.qianfan.api-key=API_KEY", "spring.ai.qianfan.secret-key=SECRET_KEY", + "spring.ai.qianfan.base-url=TEST_BASE_URL") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(QianFanChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(QianFanChatClient.class)).isNotEmpty(); + }); + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.qianfan.api-key=API_KEY", "spring.ai.qianfan.secret-key=SECRET_KEY", + "spring.ai.qianfan.base-url=TEST_BASE_URL", "spring.ai.qianfan.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, QianFanAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(QianFanChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(QianFanChatClient.class)).isNotEmpty(); + }); + + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml new file mode 100644 index 00000000000..e8a3124671c --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-qianfan/pom.xml @@ -0,0 +1,42 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-qianfan-spring-boot-starter + jar + Spring AI Starter - QianFan + Spring AI QianFan Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-qianfan + ${project.parent.version} + + + +