Skip to content

Commit

Permalink
Add refusal field to ChatCompletionMessage and related classes
Browse files Browse the repository at this point in the history
- Updated OpenAiChatModel, OpenAiApi, and OpenAiStreamFunctionCallingHelper to include the `refusal` field in metadata.
- Adjusted constructors and methods to handle the new `refusal` attribute.
- Modified related tests to account for the new `refusal` field.
- Add the refusal field value to the Spring AI AssistantMessage metadata

Resolves #1178
  • Loading branch information
TarasVovk669 authored and tzolov committed Aug 8, 2024
1 parent e2c5208 commit 866b262
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,13 @@ public ChatResponse call(Prompt prompt) {

List<Generation> generations = choices.stream().map(choice -> {
// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id() != null ? chatCompletion.id() : "",
"role", choice.message().role() != null ? choice.message().role().name() : "",
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id() != null ? chatCompletion.id() : "",
"role", choice.message().role() != null ? choice.message().role().name() : "",
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

Expand Down Expand Up @@ -313,7 +314,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
"id", chatCompletion2.id(),
"role", roleMap.getOrDefault(id, ""),
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");

return buildGeneration(choice, metadata);
}).toList();
Expand Down Expand Up @@ -453,7 +455,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
}).toList();
}
return List.of(new ChatCompletionMessage(assistantMessage.getContent(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls));
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
Expand All @@ -466,7 +468,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
return toolMessage.getResponses()
.stream()
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
tr.id(), null))
tr.id(), null, null))
.toList();
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,8 @@ public record ChatCompletionMessage(// @formatter:off
@JsonProperty("role") Role role,
@JsonProperty("name") String name,
@JsonProperty("tool_call_id") String toolCallId,
@JsonProperty("tool_calls") List<ToolCall> toolCalls) {// @formatter:on
@JsonProperty("tool_calls") List<ToolCall> toolCalls,
@JsonProperty("refusal") String refusal) {// @formatter:on

/**
* Get message content as String.
Expand All @@ -582,7 +583,7 @@ public String content() {
* @param role The role of the author of this message.
*/
public ChatCompletionMessage(Object content, Role role) {
this(content, role, null, null, null);
this(content, role, null, null, null, null);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null
String name = (current.name() != null ? current.name() : previous.name());
String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId());
String refusal = (current.refusal() != null ? current.refusal() : previous.refusal());

List<ToolCall> toolCalls = new ArrayList<>();
ToolCall lastPreviousTooCall = null;
Expand Down Expand Up @@ -120,7 +121,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
toolCalls.add(lastPreviousTooCall);
}
}
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls);
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal);
}

private ToolCall merge(ToolCall previous, ToolCall current) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public void toolFunctionCall() {

// extend conversation with function response.
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(),
Role.TOOL, functionName, toolCall.id(), null));
Role.TOOL, functionName, toolCall.id(), null, null));
}
}

Expand Down

0 comments on commit 866b262

Please sign in to comment.