Commit

Fix the ChatClient call().content() output
tzolov committed Jun 5, 2024
1 parent 73b445b commit 64378df
Showing 6 changed files with 261 additions and 63 deletions.
@@ -26,6 +26,7 @@
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
@@ -139,6 +140,11 @@ public Builder withModel(String model) {
return this;
}

public Builder withModel(MistralAiApi.ChatModel chatModel) {
this.options.setModel(chatModel.getModelName());
return this;
}

public Builder withMaxTokens(Integer maxTokens) {
this.options.setMaxTokens(maxTokens);
return this;
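The new overload lets callers pass the MistralAiApi.ChatModel enum instead of a raw model-name string. A minimal usage sketch, assuming the MistralAiChatOptions builder shown in the hunk above (illustrative, not part of the diff):

var options = MistralAiChatOptions.builder()
        .withModel(MistralAiApi.ChatModel.SMALL) // new enum-based overload
        .withMaxTokens(256)                      // existing builder method; value chosen for illustration
        .build();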
@@ -0,0 +1,252 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.mistralai;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.io.Resource;

import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(classes = MistralAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
class MistralAiChatClientIT {

private static final Logger logger = LoggerFactory.getLogger(MistralAiChatClientIT.class);

@Autowired
MistralAiChatModel chatModel;

@Value("classpath:/prompts/system-message.st")
private Resource systemTextResource;

record ActorsFilms(String actor, List<String> movies) {
}

@Test
void call() {
// @formatter:off
ChatResponse response = ChatClient.create(chatModel).prompt()
.system(s -> s.text(systemTextResource)
.param("name", "Bob")
.param("voice", "pirate"))
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
.call()
.chatResponse();
// @formatter:on

logger.info("" + response);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
}

@Test
void listOutputConverterString() {
// @formatter:off
List<String> collection = ChatClient.create(chatModel).prompt()
.user(u -> u.text("List five {subject}")
.param("subject", "ice cream flavors"))
.call()
.entity(new ParameterizedTypeReference<List<String>>() {});
// @formatter:on

logger.info(collection.toString());
assertThat(collection).hasSize(5);
}

@Test
void listOutputConverterBean() {

// @formatter:off
List<ActorsFilms> actorsFilms = ChatClient.create(chatModel).prompt()
.user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.")
.call()
.entity(new ParameterizedTypeReference<List<ActorsFilms>>() {
});
// @formatter:on

logger.info("" + actorsFilms);
assertThat(actorsFilms).hasSize(2);
}

@Test
void customOutputConverter() {

var toStringListConverter = new ListOutputConverter(new DefaultConversionService());

// @formatter:off
List<String> flavors = ChatClient.create(chatModel).prompt()
.user(u -> u.text("List 10 {subject}")
.param("subject", "ice cream flavors"))
.call()
.entity(toStringListConverter);
// @formatter:on

logger.info("ice cream flavors" + flavors);
assertThat(flavors).hasSize(10);
assertThat(flavors).containsAnyOf("Vanilla", "vanilla");
}

@Test
void mapOutputConverter() {
// @formatter:off
Map<String, Object> result = ChatClient.create(chatModel).prompt()
.user(u -> u.text("Provide me a List of {subject}")
.param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'"))
.call()
.entity(new ParameterizedTypeReference<Map<String, Object>>() {
});
// @formatter:on

assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
}

@Test
void beanOutputConverter() {

// @formatter:off
ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt()
.user("Generate the filmography for a random actor.")
.call()
.entity(ActorsFilms.class);
// @formatter:on

logger.info("" + actorsFilms);
assertThat(actorsFilms.actor()).isNotBlank();
}

@Test
void beanOutputConverterRecords() {

// @formatter:off
ActorsFilms actorsFilms = ChatClient.create(chatModel).prompt()
.user("Generate the filmography of 5 movies for Tom Hanks.")
.call()
.entity(ActorsFilms.class);
// @formatter:on

logger.info("" + actorsFilms);
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);
}

@Test
void beanStreamOutputConverterRecords() {

BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);

// @formatter:off
Flux<String> chatResponse = ChatClient.create(chatModel)
.prompt()
.user(u -> u
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
+ "{format}")
.param("format", outputConverter.getFormat()))
.stream()
.content();

String generationTextFromStream = chatResponse.collectList()
.block()
.stream()
.collect(Collectors.joining());
// @formatter:on

ActorsFilms actorsFilms = outputConverter.convert(generationTextFromStream);

logger.info("" + actorsFilms);
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);
}

@Test
void functionCallTest() {

// @formatter:off
String response = ChatClient.create(chatModel).prompt()
.options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).withToolChoice(ToolChoice.AUTO).build())
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.call()
.content();
// @formatter:on

logger.info("Response: {}", response);

assertThat(response).containsAnyOf("30.0", "30");
assertThat(response).containsAnyOf("10.0", "10");
assertThat(response).containsAnyOf("15.0", "15");
}

@Test
void defaultFunctionCallTest() {

// @formatter:off
String response = ChatClient.builder(chatModel)
.defaultOptions(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build())
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
.build()
.prompt().call().content();
// @formatter:on

logger.info("Response: {}", response);

assertThat(response).containsAnyOf("30.0", "30");
assertThat(response).containsAnyOf("10.0", "10");
assertThat(response).containsAnyOf("15.0", "15");
}

@Test
void streamFunctionCallTest() {

// @formatter:off
Flux<String> response = ChatClient.create(chatModel).prompt()
.options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build())
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.stream()
.content();
// @formatter:on

String content = response.collectList().block().stream().collect(Collectors.joining());
logger.info("Response: {}", content);

assertThat(content).containsAnyOf("30.0", "30");
assertThat(content).containsAnyOf("10.0", "10");
assertThat(content).containsAnyOf("15.0", "15");
}

}
(Diff for one of the changed files is not displayed.)
@@ -280,7 +280,6 @@ void multiModalityImageUrl(String modelName) throws IOException {
// @formatter:off
String response = ChatClient.create(chatModel).prompt()
// TODO consider adding model(...) method to ChatClient as a shortcut to
// OpenAiChatOptions.builder().withModel(modelName).build()
.options(OpenAiChatOptions.builder().withModel(modelName).build())
.user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url))
.call()
@@ -86,8 +86,6 @@ interface PromptUserSpec {

PromptUserSpec media(MimeType mimeType, Resource resource);

List<Media> media();

}

interface PromptSystemSpec {
@@ -228,8 +226,6 @@ interface Builder {

Builder defaultAdvisors(List<RequestResponseAdvisor> advisors);

ChatClient build();

Builder defaultOptions(ChatOptions chatOptions);

Builder defaultUser(String text);
@@ -252,51 +248,8 @@ interface Builder {

Builder defaultFunctions(String... functionNames);

}

/**
* Calls the underlying chat model with a prompt message and returns the output
* content of the first generation.
* @param message The message to be used as a prompt for the chat model.
* @return The output content of the first generation.
* @deprecated This method is deprecated as of version 1.0.0 M1 and will be removed in
* a future release. Use the method
* builder(chatModel).build().prompt().user(message).call().content() instead
*
*/
@Deprecated(since = "1.0.0 M1", forRemoval = true)
default String call(String message) {
var prompt = new Prompt(new UserMessage(message));
var generation = call(prompt).getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
ChatClient build();

/**
* Calls the underlying chat model with a prompt message and returns the output
* content of the first generation.
* @param messages The messages to be used as a prompt for the chat model.
* @return The output content of the first generation.
* @deprecated This method is deprecated as of version 1.0.0 M1 and will be removed in
* a future release. Use the method
* builder(chatModel).build().prompt().messages(messages).call().content() instead.
*/
@Deprecated(since = "1.0.0 M1", forRemoval = true)
default String call(Message... messages) {
var prompt = new Prompt(Arrays.asList(messages));
var generation = call(prompt).getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}

/**
* Calls the underlying chat model with a prompt and returns the corresponding chat
* response.
* @param prompt The prompt to be used for the chat model.
* @return The chat response containing the generated messages.
* @deprecated This method is deprecated as of version 1.0.0 M1 and will be removed in
* a future release. Use the method builder(chatModel).build().prompt(prompt).call()
* instead.
*/
@Deprecated(since = "1.0.0 M1", forRemoval = true)
ChatResponse call(Prompt prompt);

}
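The deprecation notes above all point to the fluent DSL as the replacement. A migration sketch based on those Javadoc hints (illustrative; chatModel stands in for any configured ChatModel):

// Deprecated since 1.0.0 M1:
// String answer = chatClient.call("Tell me a joke");

// Fluent replacement suggested by the Javadoc:
String answer = ChatClient.builder(chatModel)
        .build()
        .prompt()
        .user("Tell me a joke")
        .call()
        .content();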
@@ -145,8 +145,7 @@ protected Map<String, Object> params() {
return this.params;
}

@Override
public List<Media> media() {
protected List<Media> media() {
return this.media;
}

@@ -447,7 +446,7 @@ public Flux<String> content() {
return "";
}
return r.getResult().getOutput().getContent();
}).filter(v -> StringUtils.hasText(v));
}).filter(v -> StringUtils.hasLength(v));
}

}
@@ -824,15 +823,4 @@ public StreamPromptResponseSpec stream() {

}

/**
* use the new fluid DSL starting in {@link #prompt()}
* @param prompt the {@link Prompt prompt} object
* @return a {@link ChatResponse chat response}
*/
@Deprecated(forRemoval = true, since = "1.0.0 M1")
@Override
public ChatResponse call(Prompt prompt) {
return this.chatModel.call(prompt);
}

}
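One behavioral change in this file is in the streaming Flux<String> content() path: the filter now uses StringUtils.hasLength(v) instead of StringUtils.hasText(v), so whitespace-only chunks are kept and only null or empty values are dropped. A quick illustration of the difference, using standard Spring StringUtils semantics (not code from the commit):

// import org.springframework.util.StringUtils;
StringUtils.hasText(" ");   // false -- a whitespace-only chunk would previously have been filtered out
StringUtils.hasLength(" "); // true  -- such chunks are now kept in the streamed output
StringUtils.hasLength("");  // false -- empty chunks are still dropped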
