Skip to content

Commit

Permalink
Add ChatClient support for returning ResponseEntity<ChatResponse, T>
Browse files Browse the repository at this point in the history
 ChatClient already provides the .chatResponse() method to return the entire ChatResponse instance.
 It also provides a set of overloaded .entity(Type) methods to provide Type-converted responses.
 The new .responseEntity(Type) method returns a ResponseEntity<ChatResponse, T> instance, encapsulating
 both the ChatResponse and the requested Type-converted response entity.

 This change allows for more flexibility when handling different response types and facilitates
 easier integration with other components that expect ResponseEntity instances.
  • Loading branch information
tzolov committed Jun 4, 2024
1 parent 6ad36b7 commit d6a0dff
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ interface CallResponseSpec {

String content();

<T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type);

<T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type);

<T> ResponseEntity<ChatResponse, T> responseEntity(StructuredOutputConverter<T> structuredOutputConverter);

}

interface StreamResponseSpec {
Expand Down Expand Up @@ -205,9 +211,6 @@ <I, O> ChatClientRequestSpec function(String name, String description,

ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);

// ChatClientRequestSpec adviseOnRequest(ChatClientRequestSpec inputRequest,
// Map<String, Object> context);

CallResponseSpec call();

StreamResponseSpec stream();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,28 @@ public DefaultCallResponseSpec(ChatModel chatModel, DefaultChatClientRequestSpec
this.request = request;
}

public <T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type) {
Assert.notNull(type, "the class must be non-null");
return doResponseEntity(new BeanOutputConverter<T>(type));
}

public <T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type) {
return doResponseEntity(new BeanOutputConverter<T>(type));
}

public <T> ResponseEntity<ChatResponse, T> responseEntity(
StructuredOutputConverter<T> structuredOutputConverter) {
return doResponseEntity(structuredOutputConverter);
}

protected <T> ResponseEntity<ChatResponse, T> doResponseEntity(StructuredOutputConverter<T> boc) {
var chatResponse = doGetChatResponse(this.request, boc.getFormat());
var responseContent = chatResponse.getResult().getOutput().getContent();
T entity = boc.convert(responseContent);

return new ResponseEntity<>(chatResponse, entity);
}

public <T> T entity(ParameterizedTypeReference<T> type) {
return doSingleWithBeanOutputConverter(new BeanOutputConverter<T>(type));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.client;

/**
* Represents a {@link org.springframework.ai.model.Model} response that includes the
* entire response along withe specified response entity type.
*
* @param <R> the entire response type.
* @param <E> the converted entity type.
* @author Christian Tzolov
* @since 1.0.0
*/
public record ResponseEntity<R, E>(R response, E entity) {

public R getResponse() {
return this.response;
}

public E getEntity() {
return this.entity;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.client;

import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata.DefaultChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.core.ParameterizedTypeReference;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.when;

/**
* @author Christian Tzolov
*/
@ExtendWith(MockitoExtension.class)
public class ChatClientResponseEntityTests {

@Mock
ChatModel chatModel;

@Captor
ArgumentCaptor<Prompt> promptCaptor;

record MyBean(String name, int age) {
}

@Test
public void responseEntityTest() {

ChatResponseMetadata metadata = new DefaultChatResponseMetadata();
metadata.put("key1", "value1");

var chatResponse = new ChatResponse(List.of(new Generation("""
{"name":"John", "age":30}
""")), metadata);

when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse);

ResponseEntity<ChatResponse, MyBean> responseEntity = ChatClient.builder(chatModel)
.build()
.prompt()
.user("Tell me about John")
.call()
.responseEntity(MyBean.class);

assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
assertThat(responseEntity.getResponse().getMetadata().get("key1")).isEqualTo("value1");

assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30));

Message userMessage = promptCaptor.getValue().getInstructions().get(0);
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
assertThat(userMessage.getContent()).contains("Tell me about John");
}

@Test
public void parametrizedResponseEntityTest() {

var chatResponse = new ChatResponse(List.of(new Generation("""
[
{"name":"Max", "age":10},
{"name":"Adi", "age":13}
]
""")));

when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse);

ResponseEntity<ChatResponse, List<MyBean>> responseEntity = ChatClient.builder(chatModel)
.build()
.prompt()
.user("Tell me about them")
.call()
.responseEntity(new ParameterizedTypeReference<List<MyBean>>() {
});

assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
assertThat(responseEntity.getEntity().get(0)).isEqualTo(new MyBean("Max", 10));
assertThat(responseEntity.getEntity().get(1)).isEqualTo(new MyBean("Adi", 13));

Message userMessage = promptCaptor.getValue().getInstructions().get(0);
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
assertThat(userMessage.getContent()).contains("Tell me about them");
}

@Test
public void customSoCResponseEntityTest() {

var chatResponse = new ChatResponse(List.of(new Generation("""
{"name":"Max", "age":10},
""")));

when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse);

ResponseEntity<ChatResponse, Map<String, Object>> responseEntity = ChatClient.builder(chatModel)
.build()
.prompt()
.user("Tell me about Max")
.call()
.responseEntity(new MapOutputConverter());

assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
assertThat(responseEntity.getEntity().get("name")).isEqualTo("Max");
assertThat(responseEntity.getEntity().get("age")).isEqualTo(10);

Message userMessage = promptCaptor.getValue().getInstructions().get(0);
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
assertThat(userMessage.getContent()).contains("Tell me about Max");
}

}

0 comments on commit d6a0dff

Please sign in to comment.