From c34da1012cfddd45743338139693c97f39623b52 Mon Sep 17 00:00:00 2001 From: GR Date: Fri, 10 May 2024 09:07:37 +0800 Subject: [PATCH] feat: add DeepSeek model client --- README.md | 9 +- models/spring-ai-deepseek/README.md | 1 + models/spring-ai-deepseek/pom.xml | 58 +++ .../ai/deepseek/DeepSeekChatModel.java | 227 +++++++++ .../ai/deepseek/DeepSeekChatOptions.java | 335 ++++++++++++ .../ai/deepseek/aot/DeepSeekRuntimeHints.java | 42 ++ .../ai/deepseek/api/DeepSeekApi.java | 475 ++++++++++++++++++ .../ai/deepseek/api/DeepSeekApiConstants.java | 25 + .../DeepSeekChatResponseMetadata.java | 84 ++++ .../deepseek/metadata/DeepSeekRateLimit.java | 88 ++++ .../ai/deepseek/metadata/DeepSeekUsage.java | 62 +++ .../resources/META-INF/spring/aot.factories | 2 + .../DeepSeekChatCompletionRequestTests.java | 53 ++ .../deepseek/DeepSeekTestConfiguration.java | 48 ++ .../aot/DeepSeekRuntimeHintsTests.java | 46 ++ .../ai/deepseek/api/DeepSeekApiIT.java | 57 +++ .../ai/deepseek/chat/ActorsFilms.java | 53 ++ .../ai/deepseek/chat/DeepSeekChatModelIT.java | 192 +++++++ .../ai/deepseek/chat/DeepSeekRetryTests.java | 143 ++++++ .../test/resources/prompts/system-message.st | 4 + pom.xml | 2 + spring-ai-bom/pom.xml | 12 + .../src/main/antora/modules/ROOT/nav.adoc | 1 + .../ROOT/pages/api/chat/deepseek-chat.adoc | 249 +++++++++ spring-ai-spring-boot-autoconfigure/pom.xml | 8 + .../deepseek/DeepSeekAutoConfiguration.java | 68 +++ .../deepseek/DeepSeekChatProperties.java | 61 +++ .../DeepSeekConnectionProperties.java | 34 ++ .../deepseek/DeepSeekParentProperties.java | 45 ++ .../deepseek/DeepSeekAutoConfigurationIT.java | 73 +++ .../deepseek/DeepSeekPropertiesTests.java | 159 ++++++ .../spring-ai-starter-deepseek/pom.xml | 42 ++ 32 files changed, 2754 insertions(+), 4 deletions(-) create mode 100644 models/spring-ai-deepseek/README.md create mode 100644 models/spring-ai-deepseek/pom.xml create mode 100644 
models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHints.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApiConstants.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekChatResponseMetadata.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekRateLimit.java create mode 100644 models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekUsage.java create mode 100644 models/spring-ai-deepseek/src/main/resources/META-INF/spring/aot.factories create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatCompletionRequestTests.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekTestConfiguration.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHintsTests.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/ActorsFilms.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelIT.java create mode 100644 models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekRetryTests.java create mode 100644 models/spring-ai-deepseek/src/test/resources/prompts/system-message.st create mode 100644 
spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/deepseek-chat.adoc create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekChatProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekConnectionProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekParentProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfigurationIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekPropertiesTests.java create mode 100644 spring-ai-spring-boot-starters/spring-ai-starter-deepseek/pom.xml diff --git a/README.md b/README.md index cfe1218382..5e053adafe 100644 --- a/README.md +++ b/README.md @@ -22,16 +22,16 @@ For further information go to our [Spring AI reference documentation](https://do ## Educational Resources - Follow the [Workshop material for Azure OpenAI](https://github.com/Azure-Samples/spring-ai-azure-workshop) - - The workshop contains step-by-step examples from 'hello world' to 'retrieval augmented generation' + - The workshop contains step-by-step examples from 'hello world' to 'retrieval augmented generation' Some selected videos. Search YouTube! for more. - Spring Tips: Spring AI -
[![Watch Spring Tips video](https://img.youtube.com/vi/aNKDoiOUo9M/default.jpg)](https://www.youtube.com/watch?v=aNKDoiOUo9M) +
[![Watch Spring Tips video](https://img.youtube.com/vi/aNKDoiOUo9M/default.jpg)](https://www.youtube.com/watch?v=aNKDoiOUo9M) * Overview of Spring AI @ Devoxx 2023 -
[![Watch the Devoxx 2023 video](https://img.youtube.com/vi/7OY9fKVxAFQ/default.jpg)](https://www.youtube.com/watch?v=7OY9fKVxAFQ) +
[![Watch the Devoxx 2023 video](https://img.youtube.com/vi/7OY9fKVxAFQ/default.jpg)](https://www.youtube.com/watch?v=7OY9fKVxAFQ) * Introducing Spring AI - Add Generative AI to your Spring Applications -
[![Watch the video](https://img.youtube.com/vi/1g_wuincUdU/default.jpg)](https://www.youtube.com/watch?v=1g_wuincUdU) +
[![Watch the video](https://img.youtube.com/vi/1g_wuincUdU/default.jpg)](https://www.youtube.com/watch?v=1g_wuincUdU) ## Getting Started @@ -98,6 +98,7 @@ Spring AI supports many AI models. For an overview see here. Specific models c * Transformers (ONNX) * Anthropic Claude3 * MiniMax +* DeepSeek **Prompts:** Central to AI model interaction is the Prompt, which provides specific instructions for the AI to act upon. diff --git a/models/spring-ai-deepseek/README.md b/models/spring-ai-deepseek/README.md new file mode 100644 index 0000000000..2a08452511 --- /dev/null +++ b/models/spring-ai-deepseek/README.md @@ -0,0 +1 @@ +[DeepSeek Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/deepseek-chat.html) \ No newline at end of file diff --git a/models/spring-ai-deepseek/pom.xml b/models/spring-ai-deepseek/pom.xml new file mode 100644 index 0000000000..8b1c4d5479 --- /dev/null +++ b/models/spring-ai-deepseek/pom.xml @@ -0,0 +1,58 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-deepseek + jar + Spring AI DeepSeek + DeepSeek support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java new file mode 100644 index 0000000000..a48d1b9743 --- /dev/null +++ 
b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -0,0 +1,227 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion.Choice; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest; +import org.springframework.ai.deepseek.metadata.DeepSeekChatResponseMetadata; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; 
+import org.springframework.util.Assert; +import reactor.core.publisher.Flux; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * @author Geng Rong + */ +public class DeepSeekChatModel implements ChatModel, StreamingChatModel { + + private static final Logger logger = LoggerFactory.getLogger(DeepSeekChatModel.class); + + /** + * The default options used for the chat completion requests. + */ + private final DeepSeekChatOptions defaultOptions; + + /** + * The retry template used to retry the DeepSeek API calls. + */ + public final RetryTemplate retryTemplate; + + /** + * Low-level access to the DeepSeek API. + */ + private final DeepSeekApi deepSeekApi; + + /** + * Creates an instance of the DeepSeekChatModel. + * @param deepSeekApi The DeepSeekApi instance to be used for interacting with the + * DeepSeek Chat API. + * @throws IllegalArgumentException if deepSeekApi is null + */ + public DeepSeekChatModel(DeepSeekApi deepSeekApi) { + this(deepSeekApi, + DeepSeekChatOptions.builder().withModel(DeepSeekApi.DEFAULT_CHAT_MODEL).withTemperature(1F).build()); + } + + /** + * Initializes an instance of the DeepSeekChatModel. + * @param deepSeekApi The DeepSeekApi instance to be used for interacting with the + * DeepSeek Chat API. + * @param options The DeepSeekChatOptions to configure the chat client. + */ + public DeepSeekChatModel(DeepSeekApi deepSeekApi, DeepSeekChatOptions options) { + this(deepSeekApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + /** + * Initializes a new instance of the DeepSeekChatModel. + * @param deepSeekApi The DeepSeekApi instance to be used for interacting with the + * DeepSeek Chat API. + * @param options The DeepSeekChatOptions to configure the chat client. + * @param retryTemplate The retry template. 
+ */ + public DeepSeekChatModel(DeepSeekApi deepSeekApi, DeepSeekChatOptions options, RetryTemplate retryTemplate) { + Assert.notNull(deepSeekApi, "DeepSeekApi must not be null"); + Assert.notNull(options, "Options must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + this.deepSeekApi = deepSeekApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public ChatResponse call(Prompt prompt) { + + ChatCompletionRequest request = createRequest(prompt, false); + + return this.retryTemplate.execute(ctx -> { + + ResponseEntity completionEntity = this.doChatCompletion(request); + + var chatCompletion = completionEntity.getBody(); + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + List generations = chatCompletion.choices() + .stream() + .map(choice -> new Generation(choice.message().content(), toMap(chatCompletion.id(), choice)) + .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null))) + .toList(); + + return new ChatResponse(generations, DeepSeekChatResponseMetadata.from(completionEntity.getBody())); + }); + } + + @Override + public ChatOptions getDefaultOptions() { + return DeepSeekChatOptions.fromOptions(this.defaultOptions); + } + + private Map toMap(String id, ChatCompletion.Choice choice) { + Map map = new HashMap<>(); + + var message = choice.message(); + if (message.role() != null) { + map.put("role", message.role().name()); + } + if (choice.finishReason() != null) { + map.put("finishReason", choice.finishReason().name()); + } + map.put("id", id); + return map; + } + + @Override + public Flux stream(Prompt prompt) { + + ChatCompletionRequest request = createRequest(prompt, true); + return retryTemplate.execute(ctx -> { + var completionChunks = this.deepSeekApi.chatCompletionStream(request); + ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + + return 
completionChunks.map(this::chunkToChatCompletion).map(chatCompletion -> { + String id = chatCompletion.id(); + + List generations = chatCompletion.choices().stream().map(choice -> { + if (choice.message().role() != null) { + roleMap.putIfAbsent(id, choice.message().role().name()); + } + String finish = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generation = new Generation(choice.message().content(), + Map.of("id", id, "role", roleMap.get(id), "finishReason", finish)); + if (choice.finishReason() != null) { + generation = generation + .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)); + } + return generation; + }).toList(); + return new ChatResponse(generations); + }); + }); + } + + /** + * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. + * @param chunk the ChatCompletionChunk to convert + * @return the ChatCompletion + */ + private DeepSeekApi.ChatCompletion chunkToChatCompletion(DeepSeekApi.ChatCompletionChunk chunk) { + List choices = chunk.choices() + .stream() + .map(cc -> new Choice(cc.finishReason(), cc.index(), cc.delta(), cc.logprobs())) + .toList(); + + return new DeepSeekApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), + chunk.systemFingerprint(), "chat.completion", null); + } + + protected ResponseEntity doChatCompletion(ChatCompletionRequest request) { + return this.deepSeekApi.chatCompletionEntity(request); + } + + /** + * Accessible for testing. 
+ */ + ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + List chatCompletionMessages = prompt.getInstructions() + .stream() + .map(m -> new ChatCompletionMessage(m.getContent(), Role.valueOf(m.getMessageType().name()))) + .toList(); + + ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); + + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { + DeepSeekChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, + ChatOptions.class, DeepSeekChatOptions.class); + + request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class); + } + else { + throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + + prompt.getOptions().getClass().getSimpleName()); + } + } + + if (this.defaultOptions != null) { + request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); + } + return request; + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java new file mode 100644 index 0000000000..1a540aa11d --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java @@ -0,0 +1,335 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +import java.util.List; + +/** + * Chat completions options for the DeepSeek chat API. + * DeepSeek + * chat completion + * + * @author Geng Rong + */ +@JsonInclude(Include.NON_NULL) +public class DeepSeekChatOptions implements ChatOptions { + + // @formatter:off + /** + * ID of the model to use. You can use either usedeepseek-coder or deepseek-chat. + */ + private @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + private @JsonProperty("frequency_penalty") Float frequencyPenalty; + /** + * The maximum number of tokens that can be generated in the chat completion. + * The total length of input tokens and generated tokens is limited by the model's context length. + */ + private @JsonProperty("max_tokens") Integer maxTokens; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + private @JsonProperty("presence_penalty") Float presencePenalty; + /** + * A string or a list containing up to 4 strings, upon encountering these words, the API will cease generating more tokens. + */ + @NestedConfigurationProperty + private @JsonProperty("stop") List stop; + /** + * What sampling temperature to use, between 0 and 2. 
+ * Higher values like 0.8 will make the output more random, + * while lower values like 0.2 will make it more focused and deterministic. + * We generally recommend altering this or top_p but not both. + */ + private @JsonProperty("temperature") Float temperature; + /** + * An alternative to sampling with temperature, called nucleus sampling, + * where the model considers the results of the tokens with top_p probability mass. + * So 0.1 means only the tokens comprising the top 10% probability mass are considered. + * We generally recommend altering this or temperature but not both. + */ + private @JsonProperty("top_p") Float topP; + /** + * Whether to return log probabilities of the output tokens or not. + * If true, returns the log probabilities of each output token returned in the content of message. + */ + private @JsonProperty("logprobs") Boolean logprobs; + /** + * An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, + * each with an associated log probability. logprobs must be set to true if this parameter is used. 
+ */ + private @JsonProperty("top_logprobs") Integer topLogprobs; + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + protected DeepSeekChatOptions options; + + public Builder() { + this.options = new DeepSeekChatOptions(); + } + + public Builder(DeepSeekChatOptions options) { + this.options = options; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withFrequencyPenalty(Float frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder withLogprobs(Boolean logprobs) { + this.options.logprobs = logprobs; + return this; + } + + public Builder withTopLogprobs(Integer topLogprobs) { + this.options.topLogprobs = topLogprobs; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder withPresencePenalty(Float presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder withStop(List stop) { + this.options.stop = stop; + return this; + } + + public Builder withTemperature(Float temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + this.options.topP = topP; + return this; + } + + public DeepSeekChatOptions build() { + return this.options; + } + + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public Float getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Boolean getLogprobs() { + return this.logprobs; + } + + public void setLogprobs(Boolean logprobs) { + this.logprobs = logprobs; + } + + public Integer getTopLogprobs() { + return this.topLogprobs; + } + + public void 
setTopLogprobs(Integer topLogprobs) { + this.topLogprobs = topLogprobs; + } + + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Float getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return this.topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); + result = prime * result + ((logprobs == null) ? 0 : logprobs.hashCode()); + result = prime * result + ((topLogprobs == null) ? 0 : topLogprobs.hashCode()); + result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); + result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); + result = prime * result + ((topP == null) ? 
0 : topP.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + DeepSeekChatOptions other = (DeepSeekChatOptions) obj; + if (this.model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (this.frequencyPenalty == null) { + if (other.frequencyPenalty != null) + return false; + } + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) + return false; + if (this.logprobs == null) { + if (other.logprobs != null) + return false; + } + else if (!this.logprobs.equals(other.logprobs)) + return false; + if (this.topLogprobs == null) { + if (other.topLogprobs != null) + return false; + } + else if (!this.topLogprobs.equals(other.topLogprobs)) + return false; + if (this.maxTokens == null) { + if (other.maxTokens != null) + return false; + } + else if (!this.maxTokens.equals(other.maxTokens)) + return false; + if (this.presencePenalty == null) { + if (other.presencePenalty != null) + return false; + } + else if (!this.presencePenalty.equals(other.presencePenalty)) + return false; + if (this.stop == null) { + if (other.stop != null) + return false; + } + else if (!stop.equals(other.stop)) + return false; + if (this.temperature == null) { + if (other.temperature != null) + return false; + } + else if (!this.temperature.equals(other.temperature)) + return false; + if (this.topP == null) { + if (other.topP != null) + return false; + } + else if (!topP.equals(other.topP)) + return false; + return true; + } + + @Override + @JsonIgnore + public Integer getTopK() { + throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + } + + @JsonIgnore + public void setTopK(Integer topK) { + throw new UnsupportedOperationException("Unimplemented method 'setTopK'"); + } + + public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) 
{ + return builder().withModel(fromOptions.getModel()) + .withFrequencyPenalty(fromOptions.getFrequencyPenalty()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withPresencePenalty(fromOptions.getPresencePenalty()) + .withStop(fromOptions.getStop()) + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withLogprobs(fromOptions.getLogprobs()) + .withTopLogprobs(fromOptions.getTopLogprobs()) + .build(); + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHints.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHints.java new file mode 100644 index 0000000000..22d9ce8b56 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHints.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek.aot; + +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; + +/** + * The DeepSeekRuntimeHints class is responsible for registering runtime hints for + * DeepSeek API classes. + * + * @author Geng Rong + */ +public class DeepSeekRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { + var mcs = MemberCategory.values(); + for (var tr : findJsonAnnotatedClassesInPackage(DeepSeekApi.class)) + hints.reflection().registerType(tr, mcs); + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java new file mode 100644 index 0000000000..9bd318f1a7 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java @@ -0,0 +1,475 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.util.api.ApiUtils; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.List; +import java.util.function.Predicate; + +import static org.springframework.ai.deepseek.api.DeepSeekApiConstants.DEFAULT_BASE_URL; + +// @formatter:off +/** + * Single class implementation of the DeepSeek Chat Completion API: https://platform.deepseek.com/api-docs/api/create-chat-completion + * + * @author Geng Rong + */ +public class DeepSeekApi { + + public static final String DEFAULT_CHAT_MODEL = ChatModel.DEEPSEEK_CHAT.getValue(); + + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + + private final RestClient restClient; + + private final WebClient webClient; + + /** + * Create an new chat completion api with base URL set to https://api.deepseek.com + * + * @param deepseekToken DeepSeek apiKey. + */ + public DeepSeekApi(String deepseekToken) { + this(DEFAULT_BASE_URL, deepseekToken); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param deepseekToken DeepSeek apiKey. + */ + public DeepSeekApi(String baseUrl, String deepseekToken) { + this(baseUrl, deepseekToken, RestClient.builder()); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param deepseekToken DeepSeek apiKey. + * @param restClientBuilder RestClient builder. 
+ */ + public DeepSeekApi(String baseUrl, String deepseekToken, RestClient.Builder restClientBuilder) { + this(baseUrl, deepseekToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param deepseekToken DeepSeek apiKey. + * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. + */ + public DeepSeekApi(String baseUrl, String deepseekToken, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + + this.restClient = restClientBuilder + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders(deepseekToken)) + .defaultStatusHandler(responseErrorHandler) + .build(); + + this.webClient = WebClient.builder() + .baseUrl(baseUrl) + .defaultHeaders(ApiUtils.getJsonContentHeaders(deepseekToken)) + .build(); + } + + /** + * DeepSeek Chat Completion Models + */ + public enum ChatModel { + /** + * The backend model of deepseek-chat has been updated to DeepSeek-V2, + * you can access DeepSeek-V2 without modification to the model name. + * The open-source DeepSeek-V2 model supports 128K context window, + * and DeepSeek-V2 on API/Web supports 32K context window. + * Context window: 32k tokens + */ + DEEPSEEK_CHAT("deepseek-chat"), + + /** + * DeepSeek Coder is composed of a series of code language models, good at coding tasks. + * Context window: 16K tokens + */ + DEEPSEEK_CODER("deepseek-coder"); + + public final String value; + + ChatModel(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + /** + * Creates a model response for the given chat conversation. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. You can use either usedeepseek-coder or deepseek-chat. + * @param frequencyPenalty Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + * @param maxTokens The maximum number of tokens that can be generated in the chat completion. + * The total length of input tokens and generated tokens is limited by the model's context length. + * @param presencePenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + * @param stop A string or a list containing up to 4 strings, upon encountering these words, + * the API will cease generating more tokens. + * @param stream If set, partial message deltas will be sent. + * Tokens will be sent as data-only server-sent events (SSE) as they become available, + * with the stream terminated by a data: [DONE] message. + * @param temperature What sampling temperature to use, between 0 and 2. + * Higher values like 0.8 will make the output more random, + * while lower values like 0.2 will make it more focused and deterministic. + * We generally recommend altering this or top_p but not both. + * @param topP An alternative to sampling with temperature, called nucleus sampling, + * where the model considers the results of the tokens with top_p probability mass. + * So 0.1 means only the tokens comprising the top 10% probability mass are considered. + * We generally recommend altering this or temperature but not both. + * @param logprobs Whether to return log probabilities of the output tokens or not. + * If true, returns the log probabilities of each output token returned in the content of message. + * @param topLogprobs An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, + * each with an associated log probability. logprobs must be set to true if this parameter is used. 
+ */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionRequest ( + @JsonProperty("messages") List messages, + @JsonProperty("model") String model, + @JsonProperty("frequency_penalty") Float frequencyPenalty, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("presence_penalty") Float presencePenalty, + @JsonProperty("stop") List stop, + @JsonProperty("stream") Boolean stream, + @JsonProperty("temperature") Float temperature, + @JsonProperty("top_p") Float topP, + @JsonProperty("logprobs") Boolean logprobs, + @JsonProperty("top_logprobs") Integer topLogprobs) { + + /** + * Shortcut constructor for a chat completion request with the given messages, model and temperature. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. You can use either deepseek-coder or deepseek-chat. + * @param temperature What sampling temperature to use, between 0 and 2. + */ + public ChatCompletionRequest(List messages, String model, Float temperature) { + this(messages, model, null, null, null, null, false, temperature, null, + null, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages, model and control for streaming. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. You can use either deepseek-coder or deepseek-chat. + * @param temperature What sampling temperature to use, between 0 and 2. + * @param stream If set, partial message deltas will be sent. + * Tokens will be sent as data-only server-sent events (SSE) as they become available, + * with the stream terminated by a data: [DONE] message. 
+ */ + public ChatCompletionRequest(List messages, String model, Float temperature, boolean stream) { + this(messages, model, null, null, null, null, stream, temperature, null, + null, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages and model. + * Streaming is set to false, temperature to 1 and all other parameters are null. + * + * @param messages A list of messages comprising the conversation so far. + * @param model ID of the model to use. You can use either deepseek-coder or deepseek-chat. + */ + public ChatCompletionRequest(List messages, String model) { + this(messages, model, null, null, null, null, false, 1F, null, + null, null); + } + + /** + * Shortcut constructor for a chat completion request with the given messages and stream flag. + * The model, temperature and all other parameters are null. + * + * @param messages A list of messages comprising the conversation so far. + * @param stream If set, partial message deltas will be sent. Tokens will be sent as data-only server-sent events + * as they become available, with the stream terminated by a data: [DONE] message. + */ + public ChatCompletionRequest(List messages, Boolean stream) { + this(messages, null, null, null, null, null, stream, null, null, + null, null); + } + } + + /** + * Message comprising the conversation. + * + * @param rawContent The contents of the message. Expected to be a {@link String}. + * The response message content is always a {@link String}. + * @param role The role of the message's author. Could be one of the {@link Role} types. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionMessage( + @JsonProperty("content") Object rawContent, + @JsonProperty("role") Role role) { + + /** + * Get message content as String. 
+ */ + public String content() { + if (this.rawContent == null) { + return null; + } + if (this.rawContent instanceof String text) { + return text; + } + throw new IllegalStateException("The content is not a string!"); + } + + /** + * The role of the author of this message. + */ + public enum Role { + /** + * System message. + */ + @JsonProperty("system") SYSTEM, + /** + * User message. + */ + @JsonProperty("user") USER, + /** + * Assistant message. + */ + @JsonProperty("assistant") ASSISTANT + } + } + + /** + * The reason the model stopped generating tokens. + */ + public enum ChatCompletionFinishReason { + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") STOP, + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") LENGTH, + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("content_filter") CONTENT_FILTER + } + + /** + * Represents a chat completion response returned by model, based on the provided input. + * + * @param id A unique identifier for the chat completion. + * @param choices A list of chat completion choices. Can be more than one if n is greater than 1. + * @param created The Unix timestamp (in seconds) of when the chat completion was created. + * @param model The model used for the chat completion. + * @param systemFingerprint This fingerprint represents the backend configuration that the model runs with. Can be + * used in conjunction with the seed request parameter to understand when backend changes have been made that might + * impact determinism. + * @param object The object type, which is always chat.completion. + * @param usage Usage statistics for the completion request. 
+ */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletion( + @JsonProperty("id") String id, + @JsonProperty("choices") List choices, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("system_fingerprint") String systemFingerprint, + @JsonProperty("object") String object, + @JsonProperty("usage") Usage usage) { + + /** + * Chat completion choice. + * + * @param finishReason The reason the model stopped generating tokens. + * @param index The index of the choice in the list of choices. + * @param message A chat completion message generated by the model. + * @param logprobs Log probability information for the choice. + */ + @JsonInclude(Include.NON_NULL) + public record Choice( + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("index") Integer index, + @JsonProperty("message") ChatCompletionMessage message, + @JsonProperty("logprobs") LogProbs logprobs) { + + } + } + + /** + * Log probability information for the choice. + * + * @param content A list of message content tokens with log probability information. + */ + @JsonInclude(Include.NON_NULL) + public record LogProbs( + @JsonProperty("content") List content) { + + /** + * Message content tokens with log probability information. + * + * @param token The token. + * @param logprob The log probability of the token. + * @param probBytes A list of integers representing the UTF-8 bytes representation of the token. + * Useful in instances where characters are represented by multiple tokens and their byte + * representations must be combined to generate the correct text representation. + * Can be null if there is no bytes representation for the token. + * @param topLogprobs List of the most likely tokens and their log probability, + * at this token position. In rare cases, there may be fewer than the number of + * requested top_logprobs returned. 
+ */ + @JsonInclude(Include.NON_NULL) + public record Content( + @JsonProperty("token") String token, + @JsonProperty("logprob") Float logprob, + @JsonProperty("bytes") List probBytes, + @JsonProperty("top_logprobs") List topLogprobs) { + + /** + * The most likely tokens and their log probability, at this token position. + * + * @param token The token. + * @param logprob The log probability of this token, + * if it is within the top 20 most likely tokens. Otherwise, + * the value -9999.0 is used to signify that the token is very unlikely. + * @param probBytes A list of integers representing the UTF-8 bytes representation of the token. + * Useful in instances where characters are represented by multiple tokens and their byte + * representations must be combined to generate the correct text representation. + * Can be null if there is no bytes representation for the token. + */ + @JsonInclude(Include.NON_NULL) + public record TopLogProbs( + @JsonProperty("token") String token, + @JsonProperty("logprob") Float logprob, + @JsonProperty("bytes") List probBytes) { + } + } + } + + /** + * Usage statistics for the completion request. + * + * @param completionTokens Number of tokens in the generated completion. Only applicable for completion requests. + * @param promptTokens Number of tokens in the prompt. + * @param totalTokens Total number of tokens used in the request (prompt + completion). + */ + @JsonInclude(Include.NON_NULL) + public record Usage( + @JsonProperty("completion_tokens") Integer completionTokens, + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("total_tokens") Integer totalTokens) { + + } + + /** + * Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + * + * @param id A unique identifier for the chat completion. Each chunk has the same ID. + * @param choices A list of chat completion choices. Can be more than one if n is greater than 1. 
+ * @param created The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same + * timestamp. + * @param model The model used for the chat completion. + * @param object The object type, which is always 'chat.completion.chunk'. + */ + @JsonInclude(Include.NON_NULL) + public record ChatCompletionChunk( + @JsonProperty("id") String id, + @JsonProperty("choices") List choices, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("system_fingerprint") String systemFingerprint, + @JsonProperty("object") String object) { + + /** + * Chat completion choice. + * + * @param finishReason The reason the model stopped generating tokens. + * @param index The index of the choice in the list of choices. + * @param delta A chat completion delta generated by streamed model responses. + * @param logprobs Log probability information for the choice. + */ + @JsonInclude(Include.NON_NULL) + public record ChunkChoice( + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("index") Integer index, + @JsonProperty("delta") ChatCompletionMessage delta, + @JsonProperty("logprobs") LogProbs logprobs) { + } + } + + /** + * Creates a model response for the given chat conversation (blocking, non-streaming call). + * + * @param chatRequest The chat completion request. Must not be null; its stream property must be false or unset (null), otherwise an IllegalArgumentException is thrown. + * @return Entity response with {@link ChatCompletion} as a body and HTTP status code and headers. + */ + public ResponseEntity chatCompletionEntity(ChatCompletionRequest chatRequest) { + + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(!Boolean.TRUE.equals(chatRequest.stream()), "Request must set the stream property to false."); + + return this.restClient.post() + .uri("/chat/completions") + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletion.class); + } + + /** + * Creates a streaming chat response for the given chat conversation. + * + * @param chatRequest The chat completion request. Must have the stream property set to true. 
+ * @return A {@link Flux} stream of chat completion chunks; the terminating "[DONE]" event is consumed and not emitted. + */ + public Flux chatCompletionStream(ChatCompletionRequest chatRequest) { + Assert.notNull(chatRequest, "The request body can not be null."); + Assert.isTrue(Boolean.TRUE.equals(chatRequest.stream()), "Request must set the stream property to true."); + + return this.webClient.post() + .uri("/chat/completions") + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + // cancels the flux stream after the "[DONE]" is received. + .takeUntil(SSE_DONE_PREDICATE) + // filters out the "[DONE]" message. + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)); + } +} +// @formatter:on diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApiConstants.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApiConstants.java new file mode 100644 index 0000000000..73f0115694 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApiConstants.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek.api; + +/** + * @author Geng Rong + */ +public class DeepSeekApiConstants { + + public static final String DEFAULT_BASE_URL = "https://api.deepseek.com"; + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekChatResponseMetadata.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekChatResponseMetadata.java new file mode 100644 index 0000000000..19258f3337 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekChatResponseMetadata.java @@ -0,0 +1,84 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.metadata; + +import org.springframework.ai.chat.metadata.*; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.util.HashMap; + +/** + * {@link ChatResponseMetadata} implementation for {@literal DeepSeek}. 
+ * + * @author Geng Rong + */ +public class DeepSeekChatResponseMetadata extends HashMap implements ChatResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }"; + + public static DeepSeekChatResponseMetadata from(DeepSeekApi.ChatCompletion result) { + Assert.notNull(result, "DeepSeek ChatCompletionResult must not be null"); + DeepSeekUsage usage = DeepSeekUsage.from(result.usage()); + return new DeepSeekChatResponseMetadata(result.id(), usage); + } + + private final String id; + + @Nullable + private RateLimit rateLimit; + + private final Usage usage; + + protected DeepSeekChatResponseMetadata(String id, DeepSeekUsage usage) { + this(id, usage, null); + } + + protected DeepSeekChatResponseMetadata(String id, DeepSeekUsage usage, @Nullable DeepSeekRateLimit rateLimit) { + this.id = id; + this.usage = usage; + this.rateLimit = rateLimit; + } + + public String getId() { + return this.id; + } + + @Override + @Nullable + public RateLimit getRateLimit() { + RateLimit rateLimit = this.rateLimit; + return rateLimit != null ? rateLimit : new EmptyRateLimit(); + } + + @Override + public Usage getUsage() { + Usage usage = this.usage; + return usage != null ? 
usage : new EmptyUsage(); + } + + public DeepSeekChatResponseMetadata withRateLimit(RateLimit rateLimit) { + this.rateLimit = rateLimit; + return this; + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getUsage(), getRateLimit()); + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekRateLimit.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekRateLimit.java new file mode 100644 index 0000000000..7a520f2106 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekRateLimit.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek.metadata; + +import org.springframework.ai.chat.metadata.RateLimit; + +import java.time.Duration; + +/** + * @author Geng Rong + */ +public class DeepSeekRateLimit implements RateLimit { + + private static final String RATE_LIMIT_STRING = "{ @type: %1$s, requestsLimit: %2$s, requestsRemaining: %3$s, requestsReset: %4$s, tokensLimit: %5$s; tokensRemaining: %6$s; tokensReset: %7$s }"; + + private final Long requestsLimit; + + private final Long requestsRemaining; + + private final Long tokensLimit; + + private final Long tokensRemaining; + + private final Duration requestsReset; + + private final Duration tokensReset; + + public DeepSeekRateLimit(Long requestsLimit, Long requestsRemaining, Duration requestsReset, Long tokensLimit, + Long tokensRemaining, Duration tokensReset) { + + this.requestsLimit = requestsLimit; + this.requestsRemaining = requestsRemaining; + this.requestsReset = requestsReset; + this.tokensLimit = tokensLimit; + this.tokensRemaining = tokensRemaining; + this.tokensReset = tokensReset; + } + + @Override + public Long getRequestsLimit() { + return this.requestsLimit; + } + + @Override + public Long getTokensLimit() { + return this.tokensLimit; + } + + @Override + public Long getRequestsRemaining() { + return this.requestsRemaining; + } + + @Override + public Long getTokensRemaining() { + return this.tokensRemaining; + } + + @Override + public Duration getRequestsReset() { + return this.requestsReset; + } + + @Override + public Duration getTokensReset() { + return this.tokensReset; + } + + @Override + public String toString() { + return RATE_LIMIT_STRING.formatted(getClass().getName(), getRequestsLimit(), getRequestsRemaining(), + getRequestsReset(), getTokensLimit(), getTokensRemaining(), getTokensReset()); + } + +} diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekUsage.java 
b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekUsage.java new file mode 100644 index 0000000000..d6a3793b75 --- /dev/null +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/metadata/DeepSeekUsage.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.metadata; + +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.util.Assert; + +/** + * @author Geng Rong + */ +public class DeepSeekUsage implements Usage { + + public static DeepSeekUsage from(DeepSeekApi.Usage usage) { + return new DeepSeekUsage(usage); + } + + private final DeepSeekApi.Usage usage; + + protected DeepSeekUsage(DeepSeekApi.Usage usage) { + Assert.notNull(usage, "DeepSeek Usage must not be null"); + this.usage = usage; + } + + protected DeepSeekApi.Usage getUsage() { + return this.usage; + } + + @Override + public Long getPromptTokens() { + return getUsage().promptTokens().longValue(); + } + + @Override + public Long getGenerationTokens() { + return getUsage().completionTokens().longValue(); + } + + @Override + public Long getTotalTokens() { + return getUsage().totalTokens().longValue(); + } + + @Override + public String toString() { + return getUsage().toString(); + } + +} diff --git 
a/models/spring-ai-deepseek/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-deepseek/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 0000000000..112c3a5eeb --- /dev/null +++ b/models/spring-ai-deepseek/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.deepseek.aot.DeepSeekRuntimeHints \ No newline at end of file diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatCompletionRequestTests.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatCompletionRequestTests.java new file mode 100644 index 0000000000..d1edd44884 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatCompletionRequestTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.deepseek.api.DeepSeekApi; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +public class DeepSeekChatCompletionRequestTests { + + @Test + public void createRequestWithChatOptions() { + + var client = new DeepSeekChatModel(new DeepSeekApi("TEST"), + DeepSeekChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build()); + + var request = client.createRequest(new Prompt("Test message content"), false); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isFalse(); + + assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); + assertThat(request.temperature()).isEqualTo(66.6f); + + request = client.createRequest(new Prompt("Test message content", + DeepSeekChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()), true); + + assertThat(request.messages()).hasSize(1); + assertThat(request.stream()).isTrue(); + + assertThat(request.model()).isEqualTo("PROMPT_MODEL"); + assertThat(request.temperature()).isEqualTo(99.9f); + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekTestConfiguration.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekTestConfiguration.java new file mode 100644 index 0000000000..63d264edba --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekTestConfiguration.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek; + +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +/** + * @author Geng Rong + */ +@SpringBootConfiguration +public class DeepSeekTestConfiguration { + + @Bean + public DeepSeekApi deepSeekApi() { + return new DeepSeekApi(getApiKey()); + } + + private String getApiKey() { + String apiKey = System.getenv("DEEPSEEK_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name DEEPSEEK_API_KEY"); + } + return apiKey; + } + + @Bean + public DeepSeekChatModel deepSeekChatModel(DeepSeekApi api) { + return new DeepSeekChatModel(api); + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHintsTests.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHintsTests.java new file mode 100644 index 0000000000..089db11712 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/aot/DeepSeekRuntimeHintsTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.aot; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.TypeReference; + +import java.util.Set; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; +import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; + +/** + * @author Geng Rong + */ +class DeepSeekRuntimeHintsTests { + + @Test + void registerHints() { + RuntimeHints runtimeHints = new RuntimeHints(); + DeepSeekRuntimeHints deepSeekRuntimeHints = new DeepSeekRuntimeHints(); + deepSeekRuntimeHints.registerHints(runtimeHints, null); + + Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(DeepSeekApi.class); + for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { + assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); + } + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java new file mode 100644 index 0000000000..9b18dc409a --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java @@ -0,0 +1,57 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.api; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.deepseek.api.DeepSeekApi.*; +import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role; +import org.springframework.http.ResponseEntity; +import reactor.core.publisher.Flux; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+") +public class DeepSeekApiIT { + + DeepSeekApi DeepSeekApi = new DeepSeekApi(System.getenv("DEEPSEEK_API_KEY")); + + @Test + void chatCompletionEntity() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + ResponseEntity response = DeepSeekApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(chatCompletionMessage), ChatModel.DEEPSEEK_CHAT.value, 1F, false)); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + } + + @Test + void chatCompletionStream() { + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); + Flux response = DeepSeekApi.chatCompletionStream( + new ChatCompletionRequest(List.of(chatCompletionMessage), ChatModel.DEEPSEEK_CHAT.value, 1F, true)); + + assertThat(response).isNotNull(); + 
assertThat(response.collectList().block()).isNotNull(); + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/ActorsFilms.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/ActorsFilms.java new file mode 100644 index 0000000000..53f529ef3e --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/ActorsFilms.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek.chat; + +import java.util.List; + +/** + * @author Geng Rong + */ +public class ActorsFilms { + + private String actor; + + private List movies; + + public ActorsFilms() { + } + + public String getActor() { + return actor; + } + + public void setActor(String actor) { + this.actor = actor; + } + + public List getMovies() { + return movies; + } + + public void setMovies(List movies) { + this.movies = movies; + } + + @Override + public String toString() { + return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}'; + } + +} diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelIT.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelIT.java new file mode 100644 index 0000000000..8a74d7c576 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekChatModelIT.java @@ -0,0 +1,192 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.deepseek.chat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.deepseek.DeepSeekTestConfiguration; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Geng Rong + */ +@SpringBootTest(classes = DeepSeekTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+") +class DeepSeekChatModelIT { + + @Autowired + protected ChatModel chatModel; + + @Autowired + protected StreamingChatModel streamingChatModel; + + private static final Logger logger = LoggerFactory.getLogger(DeepSeekChatModelIT.class); + + @Value("classpath:/prompts/system-message.st") + private 
Resource systemResource; + + @Test + void roleTest() { + UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = chatModel.call(prompt); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); + // needs fine tuning... evaluateQuestionAndAnswer(request, response, false); + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getContent()); + assertThat(list).hasSize(5); + + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Please provide the JSON response without any code block markers such as ```json```. 
+ Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getContent()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + @Test + void beanOutputConverter() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography for a random actor. + Please provide the JSON response without any code block markers such as ```json```. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + Please provide the JSON response without any code block markers such as ```json```. 
+ {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + logger.info("" + actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + Please provide the JSON response without any code block markers such as ```json```. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = streamingChatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + logger.info("" + actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + +} \ No newline at end of file diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekRetryTests.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekRetryTests.java new file mode 100644 index 0000000000..78922e88c6 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/chat/DeepSeekRetryTests.java @@ -0,0 +1,143 @@ +/* + * Copyright 2023 - 2024 the original author or 
authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.deepseek.chat; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.deepseek.DeepSeekChatModel; +import org.springframework.ai.deepseek.DeepSeekChatOptions; +import org.springframework.ai.deepseek.api.DeepSeekApi; +import org.springframework.ai.deepseek.api.DeepSeekApi.*; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.when; + +/** + * @author Geng Rong + */ +@ExtendWith(MockitoExtension.class) +public class DeepSeekRetryTests { + + private static class TestRetryListener implements RetryListener { + + int 
onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + onErrorRetryCount = context.getRetryCount(); + } + + } + + private TestRetryListener retryListener; + + private @Mock DeepSeekApi deepSeekApi; + + private DeepSeekChatModel chatModel; + + @BeforeEach + public void beforeEach() { + RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + retryListener = new TestRetryListener(); + retryTemplate.registerListener(retryListener); + + chatModel = new DeepSeekChatModel(deepSeekApi, DeepSeekChatOptions.builder().build(), retryTemplate); + } + + @Test + public void deepSeekChatTransientError() { + + var choice = new ChatCompletion.Choice(ChatCompletionFinishReason.STOP, 0, + new ChatCompletionMessage("Response", ChatCompletionMessage.Role.ASSISTANT), null); + ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666L, "model", null, null, + new DeepSeekApi.Usage(10, 10, 10)); + + when(deepSeekApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + + var result = chatModel.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void deepSeekChatNonTransientError() { + when(deepSeekApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatModel.call(new 
Prompt("text"))); + } + + @Test + public void deepSeekChatStreamTransientError() { + + var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0, + new ChatCompletionMessage("Response", ChatCompletionMessage.Role.ASSISTANT), null); + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666L, "model", null, + null); + + when(deepSeekApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(Flux.just(expectedChatCompletion)); + + var result = chatModel.stream(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(Objects.requireNonNull(result.collectList().block()).get(0).getResult().getOutput().getContent()) + .isSameAs("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void deepSeekChatStreamNonTransientError() { + when(deepSeekApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text"))); + } + +} diff --git a/models/spring-ai-deepseek/src/test/resources/prompts/system-message.st b/models/spring-ai-deepseek/src/test/resources/prompts/system-message.st new file mode 100644 index 0000000000..dc2cf2dcd8 --- /dev/null +++ b/models/spring-ai-deepseek/src/test/resources/prompts/system-message.st @@ -0,0 +1,4 @@ +"You are a helpful AI assistant. Your name is {name}. +You are an AI assistant that helps people find information. +Your name is {name} +You should reply to the user's request with your name and also in the style of a {voice}. 
\ No newline at end of file diff --git a/pom.xml b/pom.xml index 601e92da34..e60517f75b 100644 --- a/pom.xml +++ b/pom.xml @@ -74,6 +74,7 @@ models/spring-ai-vertex-ai-palm2 models/spring-ai-watsonx-ai models/spring-ai-zhipuai + models/spring-ai-deepseek spring-ai-spring-boot-starters/spring-ai-starter-anthropic spring-ai-spring-boot-starters/spring-ai-starter-azure-openai spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai @@ -89,6 +90,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-palm2 spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai spring-ai-spring-boot-starters/spring-ai-starter-zhipuai + spring-ai-spring-boot-starters/spring-ai-starter-deepseek vector-stores/spring-ai-opensearch-store spring-ai-spring-boot-starters/spring-ai-starter-opensearch-store diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index b653d1e7e0..dd86a0a482 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -344,6 +344,12 @@ ${project.version} + + org.springframework.ai + spring-ai-deepseek + ${project.version} + + org.springframework.ai spring-ai-typesense-store @@ -451,6 +457,12 @@ spring-ai-zhipuai-spring-boot-starter ${project.version} + + + org.springframework.ai + spring-ai-deepseek-spring-boot-starter + ${project.version} + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index b554b8959c..86711402b6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -30,6 +30,7 @@ *** xref:api/chat/watsonx-ai-chat.adoc[Watsonx.AI] *** xref:api/chat/minimax-chat.adoc[MiniMax] **** xref:api/chat/functions/minimax-chat-functions.adoc[Function Calling] +*** xref:api/chat/deepseek-chat.adoc[DeepSeek] ** xref:api/embeddings.adoc[] *** xref:api/embeddings/openai-embeddings.adoc[OpenAI] *** xref:api/embeddings/ollama-embeddings.adoc[Ollama] diff --git 
a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/deepseek-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/deepseek-chat.adoc new file mode 100644 index 0000000000..6e2dc6efcb --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/deepseek-chat.adoc @@ -0,0 +1,249 @@ += DeepSeek Chat + +Spring AI supports the various AI language models from DeepSeek. You can interact with DeepSeek language models and create a multilingual conversational assistant based on DeepSeek models. + +== Prerequisites + +You will need to create an API key with DeepSeek to access DeepSeek language models. +Create an account at https://platform.deepseek.com/sign_up[DeepSeek registration page] and generate the token on the https://platform.deepseek.com/api_keys[API Keys page]. +The Spring AI project defines a configuration property named `spring.ai.deepseek.api-key` that you should set to the value of the `API Key` obtained from https://platform.deepseek.com/api_keys[API Keys page]. +Exporting an environment variable is one way to set that configuration property: + +[source,shell] +---- +export SPRING_AI_DEEPSEEK_API_KEY= +---- + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. +Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the DeepSeek Chat Model.
+To enable it, add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-deepseek-spring-boot-starter</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-deepseek-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Chat Properties + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the DeepSeek Chat model. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.deepseek` is used as the property prefix that lets you connect to DeepSeek.
+ +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.deepseek.base-url | The URL to connect to | https://api.deepseek.com +| spring.ai.deepseek.api-key | The API Key | - +|==== + +==== Configuration Properties + +The prefix `spring.ai.deepseek.chat` is the property prefix that lets you configure the chat model implementation for DeepSeek. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.deepseek.chat.enabled | Enable DeepSeek chat model. | true +| spring.ai.deepseek.chat.base-url | Optionally overrides the spring.ai.deepseek.base-url to provide a chat-specific URL | - +| spring.ai.deepseek.chat.api-key | Optionally overrides the spring.ai.deepseek.api-key to provide a chat-specific api-key | - +| spring.ai.deepseek.chat.options.model | ID of the model to use. You can use either deepseek-coder or deepseek-chat. | deepseek-chat +| spring.ai.deepseek.chat.options.frequencyPenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | 0.0f +| spring.ai.deepseek.chat.options.maxTokens | The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. | - +| spring.ai.deepseek.chat.options.presencePenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. | 0.0f +| spring.ai.deepseek.chat.options.stop | Up to 4 sequences where the API will stop generating further tokens. | - +| spring.ai.deepseek.chat.options.temperature | What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.
| 1.0F +| spring.ai.deepseek.chat.options.topP | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. | 1.0F +| spring.ai.deepseek.chat.options.logprobs | Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message. | - +| spring.ai.deepseek.chat.options.topLogprobs | An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. | - +|==== + +NOTE: You can override the common `spring.ai.deepseek.base-url` and `spring.ai.deepseek.api-key` for the `ChatModel` implementations. +The `spring.ai.deepseek.chat.base-url` and `spring.ai.deepseek.chat.api-key` properties, if set, take precedence over the common properties. +This is useful if you want to use different DeepSeek accounts for different models and different model endpoints. + +TIP: All properties prefixed with `spring.ai.deepseek.chat.options` can be overridden at runtime by adding a request specific <<chat-options>> to the `Prompt` call. + +== Runtime Options [[chat-options]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java[DeepSeekChatOptions.java] provides model configurations, such as the model to use, the temperature, the frequency penalty, etc. + +On start-up, the default options can be configured with the `DeepSeekChatModel(api, options)` constructor or the `spring.ai.deepseek.chat.options.*` properties. + +At run-time you can override the default options by adding new, request specific, options to the `Prompt` call.
+For example to override the default model and temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "Generate the names of 5 famous pirates. Please provide the JSON response without any code block markers such as ```json```.", + DeepSeekChatOptions.builder() + .withModel(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue()) + .withTemperature(0.8f) + .build() + )); +---- + +TIP: In addition to the model specific link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java[DeepSeekChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. + +== Sample Controller (Auto-configuration) + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-deepseek-spring-boot-starter` to your pom (or gradle) dependencies. + +Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the DeepSeek Chat model: + +[source,application.properties] +---- +spring.ai.deepseek.api-key=YOUR_API_KEY +spring.ai.deepseek.chat.options.model=deepseek-chat +spring.ai.deepseek.chat.options.temperature=0.8 +---- + +TIP: replace the `api-key` with your DeepSeek credentials. + +This will create a `DeepSeekChatModel` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the chat model for text generations. 
+ +[source,java] +---- +@RestController +public class ChatController { + + private final DeepSeekChatModel chatModel; + + @Autowired + public ChatController(DeepSeekChatModel chatModel) { + this.chatModel = chatModel; + } + + @GetMapping("/ai/generate") + public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + return Map.of("generation", chatModel.call(message)); + } + + @GetMapping("/ai/generateStream") + public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + var prompt = new Prompt(new UserMessage(message)); + return chatModel.stream(prompt); + } +} +---- + +== Manual Configuration + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java[DeepSeekChatModel] implements the `ChatModel` and `StreamingChatModel` and uses the <<low-level-api>> to connect to the DeepSeek service. + +Add the `spring-ai-deepseek` dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-deepseek</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-deepseek' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+ +Next, create a `DeepSeekChatModel` and use it for text generations: + +[source,java] +---- +var deepSeekApi = new DeepSeekApi(System.getenv("DEEPSEEK_API_KEY")); + +var chatModel = new DeepSeekChatModel(deepSeekApi, DeepSeekChatOptions.builder() + .withModel(DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue()) + .withTemperature(0.4f) + .withMaxTokens(200) + .build()); + +ChatResponse response = chatModel.call( + new Prompt("Generate the names of 5 famous pirates.")); + +// Or with streaming responses +Flux streamResponse = chatModel.stream( + new Prompt("Generate the names of 5 famous pirates.")); +---- + +The `DeepSeekChatOptions` provides the configuration information for the chat requests. +The `DeepSeekChatOptions.Builder` is a fluent options builder. + +=== Low-level DeepSeekApi Client [[low-level-api]] + +The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java[DeepSeekApi] provides a lightweight Java client for link:https://platform.deepseek.com/api-docs/[DeepSeek API]. + +Here is a simple snippet showing how to use the API programmatically: + +[source,java] +---- +DeepSeekApi deepSeekApi = + new DeepSeekApi(System.getenv("DEEPSEEK_API_KEY")); + +ChatCompletionMessage chatCompletionMessage = + new ChatCompletionMessage("Hello world", Role.USER); + +// Sync request +ResponseEntity response = deepSeekApi.chatCompletionEntity( + new ChatCompletionRequest(List.of(chatCompletionMessage), DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue(), 0.7f, false)); + +// Streaming request +Flux streamResponse = deepSeekApi.chatCompletionStream( + new ChatCompletionRequest(List.of(chatCompletionMessage), DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getValue(), 0.7f, true)); +---- + +Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/api/DeepSeekApi.java[DeepSeekApi.java]'s JavaDoc for further information.
+ +==== DeepSeekApi Samples +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/api/DeepSeekApiIT.java[DeepSeekApiIT.java] test provides some general examples how to use the lightweight library. diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 291129335c..3c1fb13b8e 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -254,6 +254,14 @@ true + + + org.springframework.ai + spring-ai-deepseek + ${project.parent.version} + true + + org.springframework.boot spring-boot-configuration-processor diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfiguration.java new file mode 100644 index 0000000000..21fd11a644 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfiguration.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.springframework.ai.autoconfigure.deepseek;
+
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
+import org.springframework.ai.deepseek.DeepSeekChatModel;
+import org.springframework.ai.deepseek.api.DeepSeekApi;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+import org.springframework.retry.support.RetryTemplate;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClient;
+
+/**
+ * Auto-configuration that contributes a {@link DeepSeekChatModel} bean built from
+ * {@link DeepSeekConnectionProperties} and {@link DeepSeekChatProperties}. It is
+ * ordered after {@link RestClientAutoConfiguration} and
+ * {@link SpringAiRetryAutoConfiguration} and is only considered when
+ * {@link DeepSeekApi} is on the classpath.
+ *
+ * @author Geng Rong
+ */
+@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class })
+@ConditionalOnClass(DeepSeekApi.class)
+@EnableConfigurationProperties({ DeepSeekConnectionProperties.class, DeepSeekChatProperties.class })
+public class DeepSeekAutoConfiguration {
+
+    /**
+     * Creates the chat model unless one is already defined or the
+     * {@code spring.ai.deepseek.chat.enabled} property is set to something other
+     * than {@code true} (a missing property counts as enabled).
+     * <p>
+     * The base URL and API key set on the chat properties take precedence over the
+     * connection-level values; whichever value is selected must contain text.
+     * @param commonProperties connection-level API key and base URL
+     * @param chatProperties chat-level overrides and default chat options
+     * @param restClientBuilder builder handed to the {@link DeepSeekApi}
+     * @param retryTemplate retry template handed to the chat model
+     * @param responseErrorHandler error handler handed to the {@link DeepSeekApi}
+     * @return the configured chat model
+     * @throws IllegalArgumentException if no base URL or no API key is available
+     */
+    @Bean
+    @ConditionalOnMissingBean
+    @ConditionalOnProperty(prefix = DeepSeekChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
+            matchIfMissing = true)
+    public DeepSeekChatModel deepSeekChatModel(DeepSeekConnectionProperties commonProperties,
+            DeepSeekChatProperties chatProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate,
+            ResponseErrorHandler responseErrorHandler) {
+
+        // The base URL is validated before the API key, so a configuration that
+        // lacks both reports the base URL first.
+        String effectiveBaseUrl = overrideOrDefault(chatProperties.getBaseUrl(), commonProperties.getBaseUrl());
+        Assert.hasText(effectiveBaseUrl, "DeepSeek base URL must be set");
+
+        String effectiveApiKey = overrideOrDefault(chatProperties.getApiKey(), commonProperties.getApiKey());
+        Assert.hasText(effectiveApiKey, "DeepSeek API key must be set");
+
+        DeepSeekApi api = new DeepSeekApi(effectiveBaseUrl, effectiveApiKey, restClientBuilder,
+                responseErrorHandler);
+        return new DeepSeekChatModel(api, chatProperties.getOptions(), retryTemplate);
+    }
+
+    /**
+     * Returns {@code override} when it contains text, otherwise {@code fallback}.
+     */
+    private static String overrideOrDefault(String override, String fallback) {
+        if (StringUtils.hasText(override)) {
+            return override;
+        }
+        return fallback;
+    }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekChatProperties.java
new file mode 100644
index 0000000000..fed7c880a2
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekChatProperties.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.deepseek;
+
+import org.springframework.ai.deepseek.DeepSeekChatOptions;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+/**
+ * Chat-specific DeepSeek properties bound under the {@code spring.ai.deepseek.chat}
+ * prefix. In addition to the API key and base URL inherited from
+ * {@link DeepSeekParentProperties}, it carries the enable switch and the default
+ * {@link DeepSeekChatOptions}.
+ *
+ * @author Geng Rong
+ */
+@ConfigurationProperties(DeepSeekChatProperties.CONFIG_PREFIX)
+public class DeepSeekChatProperties extends DeepSeekParentProperties {
+
+    public static final String CONFIG_PREFIX = "spring.ai.deepseek.chat";
+
+    public static final String DEFAULT_CHAT_MODEL = "deepseek-chat";
+
+    private static final Double DEFAULT_TEMPERATURE = 1D;
+
+    /**
+     * Enable DeepSeek chat model.
+     */
+    private boolean enabled = true;
+
+    @NestedConfigurationProperty
+    private DeepSeekChatOptions options = defaultOptions();
+
+    /**
+     * Builds the options used when nothing is configured: the default chat model
+     * name and the default temperature.
+     */
+    private static DeepSeekChatOptions defaultOptions() {
+        return DeepSeekChatOptions.builder()
+            .withModel(DEFAULT_CHAT_MODEL)
+            .withTemperature(DEFAULT_TEMPERATURE.floatValue())
+            .build();
+    }
+
+    /**
+     * @return the chat options currently held by these properties
+     */
+    public DeepSeekChatOptions getOptions() {
+        return this.options;
+    }
+
+    /**
+     * @param options the chat options to hold, replacing the current ones
+     */
+    public void setOptions(DeepSeekChatOptions options) {
+        this.options = options;
+    }
+
+    /**
+     * @return the value of the enable switch ({@code true} unless changed)
+     */
+    public boolean isEnabled() {
+        return this.enabled;
+    }
+
+    /**
+     * @param enabled new value of the enable switch
+     */
+    public void setEnabled(boolean enabled) {
+        this.enabled = enabled;
+    }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekConnectionProperties.java
new file mode 100644
index 0000000000..02f576d060
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekConnectionProperties.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.deepseek;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+
+/**
+ * Connection-level DeepSeek properties bound under the {@code spring.ai.deepseek}
+ * prefix. The API key and base URL accessors are inherited from
+ * {@link DeepSeekParentProperties}; this class only pre-populates the base URL
+ * with {@link #DEFAULT_BASE_URL}.
+ *
+ * @author Geng Rong
+ */
+@ConfigurationProperties(DeepSeekConnectionProperties.CONFIG_PREFIX)
+public class DeepSeekConnectionProperties extends DeepSeekParentProperties {
+
+    public static final String CONFIG_PREFIX = "spring.ai.deepseek";
+
+    public static final String DEFAULT_BASE_URL = "https://api.deepseek.com";
+
+    /**
+     * Creates the properties with the base URL preset to {@link #DEFAULT_BASE_URL}.
+     */
+    public DeepSeekConnectionProperties() {
+        // Called through "super" so the parent's setter is invoked directly rather
+        // than any override a subclass might declare.
+        super.setBaseUrl(DEFAULT_BASE_URL);
+    }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekParentProperties.java
new file mode 100644
index 0000000000..6002fc2da9
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekParentProperties.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.deepseek;
+
+/**
+ * Internal parent properties for the DeepSeek properties. Holds the two settings
+ * shared by its subclasses: the API key and the base URL. Neither has a default
+ * value here.
+ *
+ * @author Geng Rong
+ */
+class DeepSeekParentProperties {
+
+    private String apiKey;
+
+    private String baseUrl;
+
+    /**
+     * @return the API key, or {@code null} if none has been set
+     */
+    public String getApiKey() {
+        return this.apiKey;
+    }
+
+    /**
+     * @param apiKey the API key to store
+     */
+    public void setApiKey(String apiKey) {
+        this.apiKey = apiKey;
+    }
+
+    /**
+     * @return the base URL, or {@code null} if none has been set
+     */
+    public String getBaseUrl() {
+        return this.baseUrl;
+    }
+
+    /**
+     * @param baseUrl the base URL to store
+     */
+    public void setBaseUrl(String baseUrl) {
+        this.baseUrl = baseUrl;
+    }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfigurationIT.java
new file mode 100644
index 0000000000..8f81129e59
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekAutoConfigurationIT.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.deepseek;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.deepseek.DeepSeekChatModel;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import reactor.core.publisher.Flux;
+
+import java.util.stream.Collectors;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration test that boots {@link DeepSeekAutoConfiguration} with the API key
+ * taken from the {@code DEEPSEEK_API_KEY} environment variable and issues one
+ * blocking and one streaming request through the resulting
+ * {@link DeepSeekChatModel}.
+ * <p>
+ * The test is skipped unless the variable holds at least one character: an empty
+ * key would otherwise fail the auto-configuration's "API key must be set" check
+ * instead of skipping.
+ *
+ * @author Geng Rong
+ */
+@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+")
+public class DeepSeekAutoConfigurationIT {
+
+    private static final Log logger = LogFactory.getLog(DeepSeekAutoConfigurationIT.class);
+
+    private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
+        .withPropertyValues("spring.ai.deepseek.apiKey=" + System.getenv("DEEPSEEK_API_KEY"))
+        .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+                RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class));
+
+    /**
+     * Sends a single prompt through the blocking API and expects a non-empty reply.
+     */
+    @Test
+    void generate() {
+        contextRunner.run(context -> {
+            DeepSeekChatModel client = context.getBean(DeepSeekChatModel.class);
+            String response = client.call("Hello");
+            assertThat(response).isNotEmpty();
+            logger.info("Response: " + response);
+        });
+    }
+
+    /**
+     * Sends a single prompt through the streaming API, concatenates the content of
+     * the first result of every streamed response, and expects a non-empty text.
+     */
+    @Test
+    void generateStreaming() {
+        contextRunner.run(context -> {
+            DeepSeekChatModel client = context.getBean(DeepSeekChatModel.class);
+            // The element type must be spelled out: with a raw Flux, block()
+            // yields Object and the stream() call below does not compile.
+            Flux<ChatResponse> responseFlux = client.stream(new Prompt(new UserMessage("Hello")));
+            String response = responseFlux.collectList()
+                .block()
+                .stream()
+                .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getContent())
+                .collect(Collectors.joining());
+
+            assertThat(response).isNotEmpty();
+            logger.info("Response: " + response);
+        });
+    }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekPropertiesTests.java
new file mode 100644
index 0000000000..f53d7a372f
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/deepseek/DeepSeekPropertiesTests.java
@@ -0,0 +1,159 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.deepseek;
+
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
+import org.springframework.ai.deepseek.DeepSeekChatModel;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit Tests for {@link DeepSeekConnectionProperties}, {@link DeepSeekChatProperties}.
+ * Every test starts a fresh application context with the retry, rest-client and
+ * DeepSeek auto-configurations and inspects the bound property beans.
+ *
+ * @author Geng Rong
+ */
+public class DeepSeekPropertiesTests {
+
+    /**
+     * Builds a fresh context runner carrying the given property values and the
+     * three auto-configurations every test in this class relies on.
+     */
+    private static ApplicationContextRunner runnerWith(String... propertyValues) {
+        return new ApplicationContextRunner().withPropertyValues(propertyValues)
+            .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+                    RestClientAutoConfiguration.class, DeepSeekAutoConfiguration.class));
+    }
+
+    /**
+     * Connection-level key and URL are bound on the connection properties only;
+     * the chat properties keep {@code null} for both while still binding options.
+     */
+    @Test
+    public void chatProperties() {
+
+        runnerWith(
+        // @formatter:off
+                "spring.ai.deepseek.base-url=TEST_BASE_URL",
+                "spring.ai.deepseek.api-key=abc123",
+                "spring.ai.deepseek.chat.options.model=MODEL_XYZ",
+                "spring.ai.deepseek.chat.options.temperature=0.55")
+                // @formatter:on
+            .run(context -> {
+                var chat = context.getBean(DeepSeekChatProperties.class);
+                var connection = context.getBean(DeepSeekConnectionProperties.class);
+
+                assertThat(connection.getApiKey()).isEqualTo("abc123");
+                assertThat(connection.getBaseUrl()).isEqualTo("TEST_BASE_URL");
+
+                assertThat(chat.getApiKey()).isNull();
+                assertThat(chat.getBaseUrl()).isNull();
+
+                assertThat(chat.getOptions().getModel()).isEqualTo("MODEL_XYZ");
+                assertThat(chat.getOptions().getTemperature()).isEqualTo(0.55f);
+            });
+    }
+
+    /**
+     * Chat-level key and URL are bound independently of the connection-level ones.
+     */
+    @Test
+    public void chatOverrideConnectionProperties() {
+
+        runnerWith(
+        // @formatter:off
+                "spring.ai.deepseek.base-url=TEST_BASE_URL",
+                "spring.ai.deepseek.api-key=abc123",
+                "spring.ai.deepseek.chat.base-url=TEST_BASE_URL2",
+                "spring.ai.deepseek.chat.api-key=456",
+                "spring.ai.deepseek.chat.options.model=MODEL_XYZ",
+                "spring.ai.deepseek.chat.options.temperature=0.55")
+                // @formatter:on
+            .run(context -> {
+                var chat = context.getBean(DeepSeekChatProperties.class);
+                var connection = context.getBean(DeepSeekConnectionProperties.class);
+
+                assertThat(connection.getApiKey()).isEqualTo("abc123");
+                assertThat(connection.getBaseUrl()).isEqualTo("TEST_BASE_URL");
+
+                assertThat(chat.getApiKey()).isEqualTo("456");
+                assertThat(chat.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
+
+                assertThat(chat.getOptions().getModel()).isEqualTo("MODEL_XYZ");
+                assertThat(chat.getOptions().getTemperature()).isEqualTo(0.55f);
+            });
+    }
+
+    /**
+     * Each supported chat option is bound from its property.
+     */
+    @Test
+    public void chatOptionsTest() {
+
+        runnerWith(
+        // @formatter:off
+                "spring.ai.deepseek.api-key=API_KEY",
+                "spring.ai.deepseek.base-url=TEST_BASE_URL",
+
+                "spring.ai.deepseek.chat.options.model=MODEL_XYZ",
+                "spring.ai.deepseek.chat.options.frequencyPenalty=-1.5",
+                "spring.ai.deepseek.chat.options.maxTokens=123",
+                "spring.ai.deepseek.chat.options.presencePenalty=0",
+                "spring.ai.deepseek.chat.options.stop=boza,koza",
+                "spring.ai.deepseek.chat.options.temperature=0.55",
+                "spring.ai.deepseek.chat.options.topP=0.56"
+                )
+                // @formatter:on
+            .run(context -> {
+                var chat = context.getBean(DeepSeekChatProperties.class);
+                var connection = context.getBean(DeepSeekConnectionProperties.class);
+
+                assertThat(connection.getBaseUrl()).isEqualTo("TEST_BASE_URL");
+                assertThat(connection.getApiKey()).isEqualTo("API_KEY");
+
+                assertThat(chat.getOptions().getModel()).isEqualTo("MODEL_XYZ");
+                assertThat(chat.getOptions().getFrequencyPenalty()).isEqualTo(-1.5f);
+                assertThat(chat.getOptions().getMaxTokens()).isEqualTo(123);
+                assertThat(chat.getOptions().getPresencePenalty()).isEqualTo(0);
+                assertThat(chat.getOptions().getStop()).contains("boza", "koza");
+                assertThat(chat.getOptions().getTemperature()).isEqualTo(0.55f);
+                assertThat(chat.getOptions().getTopP()).isEqualTo(0.56f);
+            });
+    }
+
+    /**
+     * The chat model bean is absent when {@code spring.ai.deepseek.chat.enabled}
+     * is {@code false} and present when the property is missing or {@code true};
+     * the chat properties bean is present in all three cases.
+     */
+    @Test
+    void chatActivation() {
+        // Explicitly disabled.
+        runnerWith("spring.ai.deepseek.api-key=API_KEY", "spring.ai.deepseek.base-url=TEST_BASE_URL",
+                "spring.ai.deepseek.chat.enabled=false")
+            .run(context -> {
+                assertThat(context.getBeansOfType(DeepSeekChatProperties.class)).isNotEmpty();
+                assertThat(context.getBeansOfType(DeepSeekChatModel.class)).isEmpty();
+            });
+
+        // Property not set.
+        runnerWith("spring.ai.deepseek.api-key=API_KEY", "spring.ai.deepseek.base-url=TEST_BASE_URL")
+            .run(context -> {
+                assertThat(context.getBeansOfType(DeepSeekChatProperties.class)).isNotEmpty();
+                assertThat(context.getBeansOfType(DeepSeekChatModel.class)).isNotEmpty();
+            });
+
+        // Explicitly enabled.
+        runnerWith("spring.ai.deepseek.api-key=API_KEY", "spring.ai.deepseek.base-url=TEST_BASE_URL",
+                "spring.ai.deepseek.chat.enabled=true")
+            .run(context -> {
+                assertThat(context.getBeansOfType(DeepSeekChatProperties.class)).isNotEmpty();
+                assertThat(context.getBeansOfType(DeepSeekChatModel.class)).isNotEmpty();
+            });
+
+    }
+
+}
diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-deepseek/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-deepseek/pom.xml new file mode 100644 index 0000000000..b01fe5c49e --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-deepseek/pom.xml @@ -0,0 +1,42 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-deepseek-spring-boot-starter + jar + Spring AI Starter - DeepSeek + Spring AI DeepSeek Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-deepseek + ${project.parent.version} + + + +