Skip to content

Commit

Permalink
Add doc warning about Gemini Pro model function calling degradation
Browse files Browse the repository at this point in the history
  • Loading branch information
tzolov committed Apr 30, 2024
1 parent 9cd01c5 commit 9f32a3c
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ private Unit(String text) {
/**
* Weather Function response.
*/
public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity,
Unit unit) {
public record Response(double temp, Unit unit) {
}

@Override
Expand All @@ -89,7 +88,7 @@ else if (request.location().contains("San Francisco")) {
}

logger.info("Request is {}, response temperature is {}", request, temperature);
return new Response(temperature, 15, 20, 2, 53, 45, Unit.C);
return new Response(temperature, Unit.C);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@
*/
package org.springframework.ai.vertexai.gemini.function;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.AssistantMessage;
Expand All @@ -36,15 +44,8 @@
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;

@SpringBootTest
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
Expand All @@ -67,10 +68,14 @@ public void afterEach() {
}

@Test
@Disabled("Google Vertex AI degraded support for parallel function calls")
public void functionCallExplicitOpenApiSchema() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations.");
"What's the weather like in San Francisco, in Paris and in Tokyo, Japan?"
+ " Use Celsius units. Answer for all requested locations.");
// " Use Celsius units. Use Multi-turn function calling. Provide answer for all
// requested locations.");

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand All @@ -95,7 +100,7 @@ public void functionCallExplicitOpenApiSchema() {
var promptOptions = VertexAiGeminiChatOptions.builder()
.withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withName("get_current_weather")
.withDescription("Get the current weather in a given location")
.withInputTypeSchema(openApiSchema)
.build()))
Expand All @@ -115,39 +120,48 @@ public void functionCallExplicitOpenApiSchema() {
@Test
public void functionCallTestInferredOpenApiSchema() {

// UserMessage userMessage = new UserMessage("What's the weather like in San
// Francisco, Paris and Tokyo?");
UserMessage userMessage = new UserMessage("What's the weather like in Paris?");
UserMessage userMessage = new UserMessage("What's the weather like in Paris? Use Celsius units.");

List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = VertexAiGeminiChatOptions.builder()
.withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withSchemaType(SchemaType.OPEN_API_SCHEMA)
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.build()))
.withFunctionCallbacks(List.of(
FunctionCallbackWrapper.builder(new MockWeatherService())
.withSchemaType(SchemaType.OPEN_API_SCHEMA)
.withName("get_current_weather")
.withDescription("Get the current weather in a given location.")
.build(),
FunctionCallbackWrapper.builder(new PaymentStatus())
.withSchemaType(SchemaType.OPEN_API_SCHEMA)
.withName("get_payment_status")
.withDescription(
"Retrieves the payment status for transaction. For example what is the payment status for transaction 700?")
.build()))
.build();

ChatResponse response = vertexGeminiClient.call(new Prompt(messages, promptOptions));

logger.info("Response: {}", response);

// System.out.println(response.getResult().getOutput().getContent());
// assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0",
// "30");
// assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0",
// "10");
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15");

ChatResponse response2 = vertexGeminiClient
.call(new Prompt("What is the payment status for transaction 696?", promptOptions));

logger.info("Response: {}", response2);

assertThat(response2.getResult().getOutput().getContent()).containsIgnoringCase("transaction 696 is PAYED");

}

@Test
public void functionCallTestInferredOpenApiSchemaStream() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations.");
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco in Celsius units?");
// UserMessage userMessage = new UserMessage(
// "What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use
// Multi-turn function calling. Provide answer for all requested locations.");

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand All @@ -173,63 +187,23 @@ public void functionCallTestInferredOpenApiSchemaStream() {

logger.info("Response: {}", responseString);

assertThat(responseString).containsAnyOf("15.0", "15");
// assertThat(responseString).containsAnyOf("15.0", "15");
assertThat(responseString).containsAnyOf("30.0", "30");
assertThat(responseString).containsAnyOf("10.0", "10");
// assertThat(responseString).containsAnyOf("10.0", "10");

}

// Gemini wants single tool with multiple function, instead multiple tools with single
// function
@Test
public void canDeclareMultipleFunctions() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations.");

List<Message> messages = new ArrayList<>(List.of(userMessage));

final var weatherFunction = FunctionCallbackWrapper.builder(new MockWeatherService())
.withSchemaType(SchemaType.OPEN_API_SCHEMA)
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.build();
final var theAnswer = FunctionCallbackWrapper.builder(new TheAnswerMock())
.withSchemaType(SchemaType.OPEN_API_SCHEMA)
.withName("theAnswerToTheUniverse")
.withDescription("the answer to the ultimate question of life, the universe, and everything")
.build();
var promptOptions = VertexAiGeminiChatOptions.builder()
.withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
.withFunctionCallbacks(List.of(weatherFunction))
.build();
// var promptOptions = VertexAiGeminiChatOptions.builder()
// .withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
// .withFunctionCallbacks(List.of(weatherFunction, theAnswer))
// .build();

ChatResponse response = vertexGeminiClient.call(new Prompt(messages, promptOptions));

String responseString = response.getResult().getOutput().getContent();

logger.info("Response: {}", responseString);
assertNotNull(responseString);

response = vertexGeminiClient
.call(new Prompt("What is the answer of the ultimate question in life?", promptOptions));

responseString = response.getResult().getOutput().getContent();

logger.info("Response: {}", responseString);
assertNotNull(responseString);
public record PaymentInfoRequest(String id) {
}

public record TransactionStatus(String status) {
}

public static class TheAnswerMock implements Function<String, Integer> {
public static class PaymentStatus implements Function<PaymentInfoRequest, TransactionStatus> {

@Override
public Integer apply(String s) {
return 42;
public TransactionStatus apply(PaymentInfoRequest paymentInfoRequest) {
return new TransactionStatus("Transaction " + paymentInfoRequest.id() + " is PAYED");
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
= Gemini Function Calling

WARNING: As of the 30th of April 2024, the Vertex AI `Gemini Pro` model has significantly degraded its support for function calling! While the feature is still available, it is not recommended for production use.
Apparently, Gemini Pro can no longer handle function names correctly.
Parallel function calling is gone as well.

Function calling lets developers create a description of a function in their code, then pass that description to a language model in a request. The response from the model includes the name of a function that matches the description and the arguments to call it with.

You can register custom Java functions with the `VertexAiGeminiChatClient` and have the Gemini Pro model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ TIP: In addition to the model specific `VertexAiChatPaLm2Options` you can use a

== Function Calling

WARNING: As of the 30th of April 2024, the Vertex AI `Gemini Pro` model has significantly degraded its support for function calling! While the feature is still available, it is not recommended for production use.
Apparently, Gemini Pro can no longer handle function names correctly.
Parallel function calling is gone as well.

You can register custom Java functions with the VertexAiGeminiChatClient and have the Gemini Pro model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions.
This is a powerful technique to connect the LLM capabilities with external tools and APIs.
Read more about xref:api/chat/functions/vertexai-gemini-chat-functions.adoc[Vertex AI Gemini Function Calling].
Expand Down

0 comments on commit 9f32a3c

Please sign in to comment.