Skip to content

Commit

Permalink
feat: add DeepSeek model client
Browse files Browse the repository at this point in the history
  • Loading branch information
mxsl-gr committed May 10, 2024
1 parent f955fd7 commit cba666c
Show file tree
Hide file tree
Showing 32 changed files with 2,745 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ Spring AI supports many AI models. For an overview see here. Specific models c
* PostgresML
* Transformers (ONNX)
* Anthropic Claude3
* DeepSeek


**Prompts:** Central to AI model interaction is the Prompt, which provides specific instructions for the AI to act upon.
Expand Down
1 change: 1 addition & 0 deletions models/spring-ai-deepseek/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[DeepSeek Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/deepseek-chat.html)
58 changes: 58 additions & 0 deletions models/spring-ai-deepseek/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai</artifactId>
<version>1.0.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<artifactId>spring-ai-deepseek</artifactId>
<packaging>jar</packaging>
<name>Spring AI DeepSeek</name>
<description>DeepSeek support</description>
<url>https://github.com/spring-projects/spring-ai</url>

<scm>
<url>https://github.com/spring-projects/spring-ai</url>
<connection>git://github.com/spring-projects/spring-ai.git</connection>
<developerConnection>git@github.com:spring-projects/spring-ai.git</developerConnection>
</scm>

<dependencies>
<!-- production dependencies -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-core</artifactId>
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-retry</artifactId>
<version>${project.parent.version}</version>
</dependency>

<!-- Spring Framework -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context-support</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-logging</artifactId>
</dependency>

<!-- test dependencies -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-test</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion.Choice;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest;
import org.springframework.ai.deepseek.metadata.DeepSeekChatResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* @author Geng Rong
*/
public class DeepSeekChatClient implements ChatClient, StreamingChatClient {

private static final Logger logger = LoggerFactory.getLogger(DeepSeekChatClient.class);

/**
* The default options used for the chat completion requests.
*/
private final DeepSeekChatOptions defaultOptions;

/**
* The retry template used to retry the DeepSeek API calls.
*/
public final RetryTemplate retryTemplate;

/**
* Low-level access to the DeepSeek API.
*/
private final DeepSeekApi deepSeekApi;

/**
* Creates an instance of the DeepSeekChatClient.
* @param deepSeekApi The DeepSeekApi instance to be used for interacting with the
* DeepSeek Chat API.
* @throws IllegalArgumentException if deepSeekApi is null
*/
public DeepSeekChatClient(DeepSeekApi deepSeekApi) {
this(deepSeekApi,
DeepSeekChatOptions.builder().withModel(DeepSeekApi.DEFAULT_CHAT_MODEL).withTemperature(1F).build());
}

/**
* Initializes an instance of the DeepSeekChatClient.
* @param deepSeekApi The DeepSeekApi instance to be used for interacting with the
* DeepSeek Chat API.
* @param options The DeepSeekChatOptions to configure the chat client.
*/
public DeepSeekChatClient(DeepSeekApi deepSeekApi, DeepSeekChatOptions options) {
this(deepSeekApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

/**
* Initializes a new instance of the DeepSeekChatClient.
* @param deepSeekApi The DeepSeekApi instance to be used for interacting with the
* DeepSeek Chat API.
* @param options The DeepSeekChatOptions to configure the chat client.
* @param retryTemplate The retry template.
*/
public DeepSeekChatClient(DeepSeekApi deepSeekApi, DeepSeekChatOptions options, RetryTemplate retryTemplate) {
Assert.notNull(deepSeekApi, "DeepSeekApi must not be null");
Assert.notNull(options, "Options must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
this.deepSeekApi = deepSeekApi;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}

@Override
public ChatResponse call(Prompt prompt) {

ChatCompletionRequest request = createRequest(prompt, false);

return this.retryTemplate.execute(ctx -> {

ResponseEntity<ChatCompletion> completionEntity = this.doChatCompletion(request);

var chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
logger.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}

List<Generation> generations = chatCompletion.choices()
.stream()
.map(choice -> new Generation(choice.message().content(), toMap(chatCompletion.id(), choice))
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
.toList();

return new ChatResponse(generations, DeepSeekChatResponseMetadata.from(completionEntity.getBody()));
});
}

private Map<String, Object> toMap(String id, ChatCompletion.Choice choice) {
Map<String, Object> map = new HashMap<>();

var message = choice.message();
if (message.role() != null) {
map.put("role", message.role().name());
}
if (choice.finishReason() != null) {
map.put("finishReason", choice.finishReason().name());
}
map.put("id", id);
return map;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {

ChatCompletionRequest request = createRequest(prompt, true);
return retryTemplate.execute(ctx -> {
var completionChunks = this.deepSeekApi.chatCompletionStream(request);
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

return completionChunks.map(this::chunkToChatCompletion).map(chatCompletion -> {
String id = chatCompletion.id();

List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
if (choice.message().role() != null) {
roleMap.putIfAbsent(id, choice.message().role().name());
}
String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generation = new Generation(choice.message().content(),
Map.of("id", id, "role", roleMap.get(id), "finishReason", finish));
if (choice.finishReason() != null) {
generation = generation
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
}
return generation;
}).toList();
return new ChatResponse(generations);
});
});
}

/**
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
* @param chunk the ChatCompletionChunk to convert
* @return the ChatCompletion
*/
private DeepSeekApi.ChatCompletion chunkToChatCompletion(DeepSeekApi.ChatCompletionChunk chunk) {
List<Choice> choices = chunk.choices()
.stream()
.map(cc -> new Choice(cc.finishReason(), cc.index(), cc.delta(), cc.logprobs()))
.toList();

return new DeepSeekApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
chunk.systemFingerprint(), "chat.completion", null);
}

protected ResponseEntity<ChatCompletion> doChatCompletion(ChatCompletionRequest request) {
return this.deepSeekApi.chatCompletionEntity(request);
}

/**
* Accessible for testing.
*/
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
.stream()
.map(m -> new ChatCompletionMessage(m.getContent(), Role.valueOf(m.getMessageType().name())))
.toList();

ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
DeepSeekChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, DeepSeekChatOptions.class);

request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
+ prompt.getOptions().getClass().getSimpleName());
}
}

if (this.defaultOptions != null) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
}
return request;
}

}
Loading

0 comments on commit cba666c

Please sign in to comment.