Skip to content

Commit

Permalink
modify message type get value
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeCP17 authored and markpollack committed Sep 27, 2023
1 parent cffa790 commit 5ff2c3b
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public AiResponse generate(Prompt prompt) {
List<Message> messages = prompt.getMessages();
List<ChatMessage> azureMessages = new ArrayList<>();
for (Message message : messages) {
String messageType = message.getMessageType().getValue();
String messageType = message.getMessageTypeValue();
ChatRole chatRole = ChatRole.fromString(messageType);
azureMessages.add(new ChatMessage(chatRole, message.getContent()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,9 @@ public MessageType getMessageType() {
return this.messageType;
}

@Override
public String getMessageTypeValue() {
return this.messageType.getValue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@ public interface Message {

MessageType getMessageType();

String getMessageTypeValue();

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.ai.client.Generation;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.prompt.messages.MessageType;
import org.springframework.util.Assert;

import java.util.ArrayList;
Expand Down Expand Up @@ -79,7 +80,7 @@ public AiResponse generate(Prompt prompt) {
List<Message> messages = prompt.getMessages();
List<ChatMessage> theoMessages = new ArrayList<>();
for (Message message : messages) {
String messageType = message.getMessageType().getValue();
String messageType = message.getMessageTypeValue();
theoMessages.add(new ChatMessage(messageType, message.getContent()));
}
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
Expand Down Expand Up @@ -148,14 +149,14 @@ private List<ChatMessage> convertToChatMessages(List<Message> messages) {
for (Message promptMessage : messages) {
switch (promptMessage.getMessageType()) {
case USER:
chatMessages.add(new ChatMessage("user", promptMessage.getContent()));
chatMessages.add(new ChatMessage(MessageType.USER.getValue(), promptMessage.getContent()));
break;
case ASSISTANT:
// TODO - valid?
chatMessages.add(new ChatMessage("assistant", promptMessage.getContent()));
chatMessages.add(new ChatMessage(MessageType.ASSISTANT.getValue(), promptMessage.getContent()));
break;
case SYSTEM:
chatMessages.add(new ChatMessage("system", promptMessage.getContent()));
chatMessages.add(new ChatMessage(MessageType.SYSTEM.getValue(), promptMessage.getContent()));
break;
case FUNCTION:
logger.error(
Expand Down

0 comments on commit 5ff2c3b

Please sign in to comment.