-
Notifications
You must be signed in to change notification settings - Fork 843
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ChatClient support for returning ResponseEntity<ChatResponse, T>
ChatClient already provides the .chatResponse() method to return the entire ChatResponse instance. It also provides a set of overloaded .entity(Type) methods to provide Type-converted responses. The new .responseEntity(Type) method returns a ResponseEntity<ChatResponse, T> instance, encapsulating both the ChatResponse and the requested Type-converted response entity. This change allows for more flexibility when handling different response types and facilitates easier integration with other components that expect ResponseEntity instances.
- Loading branch information
Showing
4 changed files
with
206 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Copyright 2024-2024 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.springframework.ai.chat.client; | ||
|
||
/** | ||
* Represents a {@link org.springframework.ai.model.Model} response that includes the | ||
* entire response along withe specified response entity type. | ||
* | ||
* @param <R> the entire response type. | ||
* @param <E> the converted entity type. | ||
* @author Christian Tzolov | ||
* @since 1.0.0 | ||
*/ | ||
public record ResponseEntity<R, E>(R response, E entity) { | ||
|
||
public R getResponse() { | ||
return this.response; | ||
} | ||
|
||
public E getEntity() { | ||
return this.entity; | ||
} | ||
} |
141 changes: 141 additions & 0 deletions
141
...-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
/* | ||
* Copyright 2024-2024 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.springframework.ai.chat.client; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.ExtendWith; | ||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.Captor; | ||
import org.mockito.Mock; | ||
import org.mockito.junit.jupiter.MockitoExtension; | ||
|
||
import org.springframework.ai.chat.messages.Message; | ||
import org.springframework.ai.chat.messages.MessageType; | ||
import org.springframework.ai.chat.metadata.ChatResponseMetadata; | ||
import org.springframework.ai.chat.metadata.ChatResponseMetadata.DefaultChatResponseMetadata; | ||
import org.springframework.ai.chat.model.ChatModel; | ||
import org.springframework.ai.chat.model.ChatResponse; | ||
import org.springframework.ai.chat.model.Generation; | ||
import org.springframework.ai.chat.prompt.Prompt; | ||
import org.springframework.ai.converter.MapOutputConverter; | ||
import org.springframework.core.ParameterizedTypeReference; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.mockito.Mockito.when; | ||
|
||
/** | ||
* @author Christian Tzolov | ||
*/ | ||
@ExtendWith(MockitoExtension.class) | ||
public class ChatClientResponseEntityTests { | ||
|
||
@Mock | ||
ChatModel chatModel; | ||
|
||
@Captor | ||
ArgumentCaptor<Prompt> promptCaptor; | ||
|
||
record MyBean(String name, int age) { | ||
} | ||
|
||
@Test | ||
public void responseEntityTest() { | ||
|
||
ChatResponseMetadata metadata = new DefaultChatResponseMetadata(); | ||
metadata.put("key1", "value1"); | ||
|
||
var chatResponse = new ChatResponse(List.of(new Generation(""" | ||
{"name":"John", "age":30} | ||
""")), metadata); | ||
|
||
when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); | ||
|
||
ResponseEntity<ChatResponse, MyBean> responseEntity = ChatClient.builder(chatModel) | ||
.build() | ||
.prompt() | ||
.user("Tell me about John") | ||
.call() | ||
.responseEntity(MyBean.class); | ||
|
||
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse); | ||
assertThat(responseEntity.getResponse().getMetadata().get("key1")).isEqualTo("value1"); | ||
|
||
assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30)); | ||
|
||
Message userMessage = promptCaptor.getValue().getInstructions().get(0); | ||
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); | ||
assertThat(userMessage.getContent()).contains("Tell me about John"); | ||
} | ||
|
||
@Test | ||
public void parametrizedResponseEntityTest() { | ||
|
||
var chatResponse = new ChatResponse(List.of(new Generation(""" | ||
[ | ||
{"name":"Max", "age":10}, | ||
{"name":"Adi", "age":13} | ||
] | ||
"""))); | ||
|
||
when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); | ||
|
||
ResponseEntity<ChatResponse, List<MyBean>> responseEntity = ChatClient.builder(chatModel) | ||
.build() | ||
.prompt() | ||
.user("Tell me about them") | ||
.call() | ||
.responseEntity(new ParameterizedTypeReference<List<MyBean>>() { | ||
}); | ||
|
||
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse); | ||
assertThat(responseEntity.getEntity().get(0)).isEqualTo(new MyBean("Max", 10)); | ||
assertThat(responseEntity.getEntity().get(1)).isEqualTo(new MyBean("Adi", 13)); | ||
|
||
Message userMessage = promptCaptor.getValue().getInstructions().get(0); | ||
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); | ||
assertThat(userMessage.getContent()).contains("Tell me about them"); | ||
} | ||
|
||
@Test | ||
public void customSoCResponseEntityTest() { | ||
|
||
var chatResponse = new ChatResponse(List.of(new Generation(""" | ||
{"name":"Max", "age":10}, | ||
"""))); | ||
|
||
when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); | ||
|
||
ResponseEntity<ChatResponse, Map<String, Object>> responseEntity = ChatClient.builder(chatModel) | ||
.build() | ||
.prompt() | ||
.user("Tell me about Max") | ||
.call() | ||
.responseEntity(new MapOutputConverter()); | ||
|
||
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse); | ||
assertThat(responseEntity.getEntity().get("name")).isEqualTo("Max"); | ||
assertThat(responseEntity.getEntity().get("age")).isEqualTo(10); | ||
|
||
Message userMessage = promptCaptor.getValue().getInstructions().get(0); | ||
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); | ||
assertThat(userMessage.getContent()).contains("Tell me about Max"); | ||
} | ||
|
||
} |