Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to the MiniMax model client #744

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.api.MiniMaxApi;
import org.springframework.ai.minimax.api.common.MiniMaxApiException;
import org.springframework.ai.minimax.api.MiniMaxApi.*;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
Expand Down Expand Up @@ -57,15 +59,15 @@
* @since 1.0.0 M1
*/
public class MiniMaxChatClient extends
AbstractFunctionCallSupport<MiniMaxApi.ChatCompletionMessage, MiniMaxApi.ChatCompletionRequest, ResponseEntity<MiniMaxApi.ChatCompletion>>
AbstractFunctionCallSupport<ChatCompletionMessage, ChatCompletionRequest, ResponseEntity<ChatCompletion>>
implements ChatClient, StreamingChatClient {

private static final Logger logger = LoggerFactory.getLogger(MiniMaxChatClient.class);

/**
* The default options used for the chat completion requests.
*/
private MiniMaxChatOptions defaultOptions;
private final MiniMaxChatOptions defaultOptions;

/**
* The retry template used to retry the MiniMax API calls.
Expand Down Expand Up @@ -120,11 +122,11 @@ public MiniMaxChatClient(MiniMaxApi miniMaxApi, MiniMaxChatOptions options,
@Override
public ChatResponse call(Prompt prompt) {

MiniMaxApi.ChatCompletionRequest request = createRequest(prompt, false);
ChatCompletionRequest request = createRequest(prompt, false);

return this.retryTemplate.execute(ctx -> {

ResponseEntity<MiniMaxApi.ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);

var chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
Expand All @@ -133,7 +135,7 @@ public ChatResponse call(Prompt prompt) {
}

if (chatCompletion.baseResponse() != null && chatCompletion.baseResponse().statusCode() != 0) {
throw new MiniMaxApiException(chatCompletion.baseResponse().message());
throw new RuntimeException(chatCompletion.baseResponse().message());
}

List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
Expand All @@ -145,7 +147,7 @@ public ChatResponse call(Prompt prompt) {
});
}

private Map<String, Object> toMap(String id, MiniMaxApi.ChatCompletion.Choice choice) {
private Map<String, Object> toMap(String id, ChatCompletion.Choice choice) {
Map<String, Object> map = new HashMap<>();

var message = choice.message();
Expand All @@ -162,19 +164,19 @@ private Map<String, Object> toMap(String id, MiniMaxApi.ChatCompletion.Choice ch
@Override
public Flux<ChatResponse> stream(Prompt prompt) {

MiniMaxApi.ChatCompletionRequest request = createRequest(prompt, true);
ChatCompletionRequest request = createRequest(prompt, true);

return this.retryTemplate.execute(ctx -> {

Flux<MiniMaxApi.ChatCompletionChunk> completionChunks = this.miniMaxApi.chatCompletionStream(request);
Flux<ChatCompletionChunk> completionChunks = this.miniMaxApi.chatCompletionStream(request);

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
// the function call handling logic.
return completionChunks.map(chunk -> chunkToChatCompletion(chunk)).map(chatCompletion -> {
return completionChunks.map(this::chunkToChatCompletion).map(chatCompletion -> {
try {
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)))
.getBody();
Expand Down Expand Up @@ -212,23 +214,23 @@ public Flux<ChatResponse> stream(Prompt prompt) {
* @param chunk the ChatCompletionChunk to convert
* @return the ChatCompletion
*/
private MiniMaxApi.ChatCompletion chunkToChatCompletion(MiniMaxApi.ChatCompletionChunk chunk) {
List<MiniMaxApi.ChatCompletion.Choice> choices = chunk.choices().stream().map(cc -> {
MiniMaxApi.ChatCompletionMessage delta = cc.delta();
private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
List<ChatCompletion.Choice> choices = chunk.choices().stream().map(cc -> {
ChatCompletionMessage delta = cc.delta();
if (delta == null) {
delta = new MiniMaxApi.ChatCompletionMessage("", MiniMaxApi.ChatCompletionMessage.Role.ASSISTANT);
delta = new ChatCompletionMessage("", Role.ASSISTANT);
}
return new MiniMaxApi.ChatCompletion.Choice(cc.finishReason(), cc.index(), delta, cc.logprobs());
return new ChatCompletion.Choice(cc.finishReason(), cc.index(), delta, cc.logprobs());
}).toList();

return new MiniMaxApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
chunk.systemFingerprint(), "chat.completion", null, null);
return new ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.systemFingerprint(),
"chat.completion", null, null);
}

/**
* Accessible for testing.
*/
MiniMaxApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

Set<String> functionsForThisRequest = new HashSet<>();

Expand All @@ -238,7 +240,7 @@ MiniMaxApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
MiniMaxApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
.toList();

MiniMaxApi.ChatCompletionRequest request = new MiniMaxApi.ChatCompletionRequest(chatCompletionMessages, stream);
ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
Expand All @@ -249,8 +251,7 @@ MiniMaxApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
IS_RUNTIME_CALL);
functionsForThisRequest.addAll(promptEnabledFunctions);

request = ModelOptionsUtils.merge(updatedRuntimeOptions, request,
MiniMaxApi.ChatCompletionRequest.class);
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
Expand All @@ -265,52 +266,35 @@ MiniMaxApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

functionsForThisRequest.addAll(defaultEnabledFunctions);

request = ModelOptionsUtils.merge(request, this.defaultOptions, MiniMaxApi.ChatCompletionRequest.class);
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
}

// Add the enabled functions definitions to the request's tools parameter.
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {

request = ModelOptionsUtils.merge(
MiniMaxChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(),
request, MiniMaxApi.ChatCompletionRequest.class);
request, ChatCompletionRequest.class);
}

return request;
}

/**
 * Converts media content into the string form sent to the API.
 * @param mimeType the MIME type used as the data-URL prefix when the content is raw
 * bytes
 * @param mediaContentData either a {@code byte[]} or a {@code String}
 * @return a {@code data:<mimeType>;base64,<payload>} string for byte content, or the
 * given text unchanged for string content
 * @throws IllegalArgumentException if the content is {@code null} or of any other type
 */
private String fromMediaData(MimeType mimeType, Object mediaContentData) {
	if (mediaContentData instanceof byte[] bytes) {
		// Assume the bytes are an image, so encode them as a base64 data URL using
		// the "data:<mimeType>;base64," prefix pattern.
		return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
	}
	else if (mediaContentData instanceof String text) {
		// Assume the text is a URL or a base64-encoded image already prefixed by the
		// user.
		return text;
	}
	else {
		// A null value matches neither instanceof branch; report it explicitly
		// rather than failing with a NullPointerException on getClass().
		String typeName = (mediaContentData == null) ? "null" : mediaContentData.getClass().getSimpleName();
		throw new IllegalArgumentException("Unsupported media data type: " + typeName);
	}
}

private List<MiniMaxApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
var function = new MiniMaxApi.FunctionTool.Function(functionCallback.getDescription(),
functionCallback.getName(), functionCallback.getInputTypeSchema());
return new MiniMaxApi.FunctionTool(function);
var function = new FunctionTool.Function(functionCallback.getDescription(), functionCallback.getName(),
functionCallback.getInputTypeSchema());
return new FunctionTool(function);
}).toList();
}

@Override
protected MiniMaxApi.ChatCompletionRequest doCreateToolResponseRequest(
MiniMaxApi.ChatCompletionRequest previousRequest, MiniMaxApi.ChatCompletionMessage responseMessage,
List<MiniMaxApi.ChatCompletionMessage> conversationHistory) {
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {

// Every tool-call item requires a separate function call and a response (TOOL)
// message.
for (MiniMaxApi.ChatCompletionMessage.ToolCall toolCall : responseMessage.toolCalls()) {
for (ToolCall toolCall : responseMessage.toolCalls()) {

var functionName = toolCall.function().name();
String functionArguments = toolCall.function().arguments();
Expand All @@ -322,42 +306,43 @@ protected MiniMaxApi.ChatCompletionRequest doCreateToolResponseRequest(
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);

// Add the function response to the conversation.
conversationHistory.add(new MiniMaxApi.ChatCompletionMessage(functionResponse,
MiniMaxApi.ChatCompletionMessage.Role.TOOL, functionName, toolCall.id(), null));
conversationHistory
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
}

// Recursively call chatCompletionWithTools until the model doesn't call a
// functions anymore.
MiniMaxApi.ChatCompletionRequest newRequest = new MiniMaxApi.ChatCompletionRequest(conversationHistory, false);
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, MiniMaxApi.ChatCompletionRequest.class);
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);

return newRequest;
}

@Override
protected List<MiniMaxApi.ChatCompletionMessage> doGetUserMessages(MiniMaxApi.ChatCompletionRequest request) {
protected List<ChatCompletionMessage> doGetUserMessages(ChatCompletionRequest request) {
return request.messages();
}

@Override
protected MiniMaxApi.ChatCompletionMessage doGetToolResponseMessage(
ResponseEntity<MiniMaxApi.ChatCompletion> chatCompletion) {
protected ChatCompletionMessage doGetToolResponseMessage(ResponseEntity<ChatCompletion> chatCompletion) {
return chatCompletion.getBody().choices().iterator().next().message();
}

@Override
protected ResponseEntity<MiniMaxApi.ChatCompletion> doChatCompletion(MiniMaxApi.ChatCompletionRequest request) {
protected ResponseEntity<ChatCompletion> doChatCompletion(ChatCompletionRequest request) {
return this.miniMaxApi.chatCompletionEntity(request);
}

@Override
protected Flux<ResponseEntity<MiniMaxApi.ChatCompletion>> doChatCompletionStream(
MiniMaxApi.ChatCompletionRequest request) {
throw new RuntimeException("Streaming Function calling is not supported");
protected Flux<ResponseEntity<ChatCompletion>> doChatCompletionStream(ChatCompletionRequest request) {
return this.miniMaxApi.chatCompletionStream(request)
.map(this::chunkToChatCompletion)
.map(Optional::ofNullable)
.map(ResponseEntity::of);
}

@Override
protected boolean isToolFunctionCall(ResponseEntity<MiniMaxApi.ChatCompletion> chatCompletion) {
protected boolean isToolFunctionCall(ResponseEntity<ChatCompletion> chatCompletion) {
var body = chatCompletion.getBody();
if (body == null) {
return false;
Expand All @@ -371,7 +356,7 @@ protected boolean isToolFunctionCall(ResponseEntity<MiniMaxApi.ChatCompletion> c
var choice = choices.get(0);
var message = choice.message();
return message != null && !CollectionUtils.isEmpty(choice.message().toolCalls())
&& choice.finishReason() == MiniMaxApi.ChatCompletionFinishReason.TOOL_CALLS;
&& choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ public static class ToolChoiceBuilder {
/**
* Specifying a particular function forces the model to call that function.
*/
public static Object FUNCTION(String functionName) {
public static Object function(String functionName) {
return Map.of("type", "function", "function", Map.of("name", functionName));
}
}
Expand Down Expand Up @@ -863,8 +863,6 @@ public record EmbeddingList(
*
* @param embeddingRequest The embedding request.
* @return Returns {@link EmbeddingList}.
*
* <pre>{@code List.of("text1", "text2", "text3") or List.of(List.of(1, 2, 3), List.of(3, 4, 5))} </pre>
*/
public ResponseEntity<EmbeddingList> embeddings(EmbeddingRequest embeddingRequest) {

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public void chatOptionsTest() {
"spring.ai.minimax.chat.options.topP=0.56",

// "spring.ai.minimax.chat.options.toolChoice.functionName=toolChoiceFunctionName",
"spring.ai.minimax.chat.options.toolChoice=" + ModelOptionsUtils.toJsonString(MiniMaxApi.ChatCompletionRequest.ToolChoiceBuilder.FUNCTION("toolChoiceFunctionName")),
"spring.ai.minimax.chat.options.toolChoice=" + ModelOptionsUtils.toJsonString(MiniMaxApi.ChatCompletionRequest.ToolChoiceBuilder.function("toolChoiceFunctionName")),

"spring.ai.minimax.chat.options.tools[0].function.name=myFunction1",
"spring.ai.minimax.chat.options.tools[0].function.description=function description",
Expand Down