Skip to content

Commit

Permalink
Enhance MiniMax chat model compatibility and add tests
Browse files Browse the repository at this point in the history
- Add web search mode response in choice.message for enhanced
compatibility
- Implement web search mode for stream mode
- Add comprehensive unit tests for new features

Related to #1292

feat: enhance the compatibility of the minimax model and tests, related issue #1292
  • Loading branch information
mxsl-gr authored and markpollack committed Sep 9, 2024
1 parent 897a411 commit b38cbe6
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
Expand Down Expand Up @@ -57,9 +56,12 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import static org.springframework.ai.minimax.api.MiniMaxApiConstants.TOOL_CALL_FUNCTION_TYPE;

/**
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal MiniMax}
* backed by {@link MiniMaxApi}.
Expand Down Expand Up @@ -169,12 +171,21 @@ public ChatResponse call(Prompt prompt) {

List<Generation> generations = choices.stream().map(choice -> {
// @formatter:off
// if the choice is a web search tool call, return last message of choice.messages
ChatCompletionMessage message = null;
if(choice.message() != null) {
message = choice.message();
} else if(!CollectionUtils.isEmpty(choice.messages())){
// the MiniMax web search messages result is ['user message','assistant tool call', 'tool call', 'assistant message']
// so the last message is the assistant message
message = choice.messages().get(choice.messages().size() - 1);
}
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id(),
"role", choice.message().role() != null ? choice.message().role().name() : "",
"role", message != null && message.role() != null ? message.role().name() : "",
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
return buildGeneration(choice, metadata);
return buildGeneration(message, choice.finishReason(), metadata);
}).toList();

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
Expand Down Expand Up @@ -224,7 +235,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
"role", roleMap.getOrDefault(id, ""),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
return buildGeneration(choice, metadata);
}).toList();
}).filter(Objects::nonNull).toList();
// @formatter:on

if (chatCompletion2.usage() != null) {
Expand All @@ -250,12 +261,28 @@ public Flux<ChatResponse> stream(Prompt prompt) {
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
else {
return Flux.just(response);
}
return Flux.just(response);
});
}

/**
* The MimiMax web search function tool type is 'web_search', so we need to filter out
* the tool calls whose type is not 'function'
* @param generation the generation to check
* @param toolCallFinishReasons the tool call finish reasons
* @return true if the generation is a tool call
*/
@Override
protected boolean isToolCall(Generation generation, Set<String> toolCallFinishReasons) {
if (!super.isToolCall(generation, toolCallFinishReasons)) {
return false;
}
return generation.getOutput()
.getToolCalls()
.stream()
.anyMatch(toolCall -> TOOL_CALL_FUNCTION_TYPE.equals(toolCall.type()));
}

private ChatResponseMetadata from(ChatCompletion result, RateLimit rateLimit) {
Assert.notNull(result, "MiniMax ChatCompletionResult must not be null");
return ChatResponseMetadata.builder()
Expand All @@ -277,21 +304,28 @@ private ChatResponseMetadata from(ChatCompletion result) {
.build();
}

private static Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
: choice.message()
.toolCalls()
private Generation buildGeneration(ChatCompletionMessage message, ChatCompletionFinishReason completionFinishReason,
Map<String, Object> metadata) {
if (message == null || message.role() == Role.TOOL) {
return null;
}
List<AssistantMessage.ToolCall> toolCalls = message.toolCalls() == null ? List.of()
: message.toolCalls()
.stream()
.map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function",
.map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), toolCall.type(),
toolCall.function().name(), toolCall.function().arguments()))
.toList();

var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var assistantMessage = new AssistantMessage(message.content(), metadata, toolCalls);
String finishReason = (completionFinishReason != null ? completionFinishReason.name() : "");
var generationMetadata = ChatGenerationMetadata.from(finishReason, null);
return new Generation(assistantMessage, generationMetadata);
}

private Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
return buildGeneration(choice.message(), choice.finishReason(), metadata);
}

/**
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
* @param chunk the ChatCompletionChunk to convert
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ public final class MiniMaxApiConstants {

public static final String DEFAULT_BASE_URL = "https://api.minimax.chat";

public static final String TOOL_CALL_FUNCTION_TYPE = "function";

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,35 @@

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.minimax.api.MiniMaxApi;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat;

/**
* @author Geng Rong
*/
@EnabledIfEnvironmentVariable(named = "MINIMAX_API_KEY", matches = ".+")
public class MiniMaxChatOptionsTests {

private static final Logger logger = LoggerFactory.getLogger(MiniMaxChatOptionsTests.class);

private final MiniMaxChatModel chatModel = new MiniMaxChatModel(new MiniMaxApi(System.getenv("MINIMAX_API_KEY")));

@Test
Expand All @@ -46,4 +56,72 @@ void testMarkSensitiveInfo() {
assertThat(unmaskResponseContent).contains("133-12345678");
}

/**
* There is a certain probability of failure, because it needs to be searched through
* the network, which may cause the test to fail due to different search results. And
* the search results are related to time. For example, after the start of the Paris
* Paralympic Games, searching for the number of gold medals in the Paris Olympics may
* be affected by the search results of the number of gold medals in the Paris
* Paralympic Games with higher priority by the search engine. Even if the input is an
* English question, there may be get Chinese content, because the main training
* content of MiniMax and search engine are Chinese
*/
@Test
void testWebSearch() {
UserMessage userMessage = new UserMessage(
"How many gold medals has the United States won in total at the 2024 Olympics?");

List<Message> messages = new ArrayList<>(List.of(userMessage));

List<MiniMaxApi.FunctionTool> functionTool = List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool());

MiniMaxChatOptions options = MiniMaxChatOptions.builder()
.withModel(ABAB_6_5_S_Chat.value)
.withTools(functionTool)
.build();

ChatResponse response = chatModel.call(new Prompt(messages, options));
String responseContent = response.getResult().getOutput().getContent();

assertThat(responseContent).contains("40");
}

/**
* There is a certain probability of failure, because it needs to be searched through
* the network, which may cause the test to fail due to different search results. And
* the search results are related to time. For example, after the start of the Paris
* Paralympic Games, searching for the number of gold medals in the Paris Olympics may
* be affected by the search results of the number of gold medals in the Paris
* Paralympic Games with higher priority by the search engine. Even if the input is an
* English question, there may be get Chinese content, because the main training
* content of MiniMax and search engine of MiniMax are Chinese
*/
@Test
void testWebSearchStream() {
UserMessage userMessage = new UserMessage(
"How many gold medals has the United States won in total at the 2024 Olympics?");

List<Message> messages = new ArrayList<>(List.of(userMessage));

List<MiniMaxApi.FunctionTool> functionTool = List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool());

MiniMaxChatOptions options = MiniMaxChatOptions.builder()
.withModel(ABAB_6_5_S_Chat.value)
.withTools(functionTool)
.build();

Flux<ChatResponse> response = chatModel.stream(new Prompt(messages, options));
String content = Objects.requireNonNull(response.collectList().block())
.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.filter(Objects::nonNull)
.collect(Collectors.joining());
logger.info("Response: {}", content);

assertThat(content).contains("40");
}

}

0 comments on commit b38cbe6

Please sign in to comment.