From 9f32a3ca9cd119116b112e050b942e5e94855914 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 30 Apr 2024 14:34:53 +0300 Subject: [PATCH] Add doc waring about gemini pro model function calling function calling degradation --- .../gemini/function/MockWeatherService.java | 5 +- ...exAiGeminiChatClientFunctionCallingIT.java | 118 +++++++----------- .../vertexai-gemini-chat-functions.adoc | 4 + .../pages/api/chat/vertexai-gemini-chat.adoc | 4 + 4 files changed, 56 insertions(+), 75 deletions(-) diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java index a77a8f86a9..ff62411a98 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java @@ -70,8 +70,7 @@ private Unit(String text) { /** * Weather Function response. */ - public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity, - Unit unit) { + public record Response(double temp, Unit unit) { } @Override @@ -89,7 +88,7 @@ else if (request.location().contains("San Francisco")) { } logger.info("Request is {}, response temperature is {}", request, temperature); - return new Response(temperature, 15, 20, 2, 53, 45, Unit.C); + return new Response(temperature, Unit.C); } } \ No newline at end of file diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatClientFunctionCallingIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatClientFunctionCallingIT.java index 980d6644cb..4b424db2c5 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatClientFunctionCallingIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatClientFunctionCallingIT.java @@ -15,13 +15,21 @@ */ package org.springframework.ai.vertexai.gemini.function; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + import com.google.cloud.vertexai.Transport; import com.google.cloud.vertexai.VertexAI; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.messages.AssistantMessage; @@ -36,15 +44,8 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import reactor.core.publisher.Flux; - -import java.util.ArrayList; -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertNotNull; @SpringBootTest @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @@ -67,10 +68,14 @@ public void afterEach() { } @Test + @Disabled("Google Vertex AI degraded support for parallel function calls") public void functionCallExplicitOpenApiSchema() { UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations."); + "What's the weather like in San Francisco, in Paris and in Tokyo, Japan?" + + " Use Celsius units. Answer for all requested locations."); + // " Use Celsius units. Use Multi-turn function calling. Provide answer for all + // requested locations."); List messages = new ArrayList<>(List.of(userMessage)); @@ -95,7 +100,7 @@ public void functionCallExplicitOpenApiSchema() { var promptOptions = VertexAiGeminiChatOptions.builder() .withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue()) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withName("getCurrentWeather") + .withName("get_current_weather") .withDescription("Get the current weather in a given location") .withInputTypeSchema(openApiSchema) .build())) @@ -115,39 +120,48 @@ public void functionCallExplicitOpenApiSchema() { @Test public void functionCallTestInferredOpenApiSchema() { - // UserMessage userMessage = new UserMessage("What's the weather like in San - // Francisco, Paris and Tokyo?"); - UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); + UserMessage userMessage = new UserMessage("What's the weather like in Paris? Use Celsius units."); List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = VertexAiGeminiChatOptions.builder() .withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue()) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) - .withSchemaType(SchemaType.OPEN_API_SCHEMA) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .build())) + .withFunctionCallbacks(List.of( + FunctionCallbackWrapper.builder(new MockWeatherService()) + .withSchemaType(SchemaType.OPEN_API_SCHEMA) + .withName("get_current_weather") + .withDescription("Get the current weather in a given location.") + .build(), + FunctionCallbackWrapper.builder(new PaymentStatus()) + .withSchemaType(SchemaType.OPEN_API_SCHEMA) + .withName("get_payment_status") + .withDescription( + "Retrieves the payment status for transaction. For example what is the payment status for transaction 700?") + .build())) .build(); ChatResponse response = vertexGeminiClient.call(new Prompt(messages, promptOptions)); logger.info("Response: {}", response); - // System.out.println(response.getResult().getOutput().getContent()); - // assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", - // "30"); - // assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", - // "10"); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15"); + ChatResponse response2 = vertexGeminiClient + .call(new Prompt("What is the payment status for transaction 696?", promptOptions)); + + logger.info("Response: {}", response2); + + assertThat(response2.getResult().getOutput().getContent()).containsIgnoringCase("transaction 696 is PAYED"); + } @Test public void functionCallTestInferredOpenApiSchemaStream() { - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations."); + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco in Celsius units?"); + // UserMessage userMessage = new UserMessage( + // "What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use + // Multi-turn function calling. Provide answer for all requested locations."); List messages = new ArrayList<>(List.of(userMessage)); @@ -173,63 +187,23 @@ public void functionCallTestInferredOpenApiSchemaStream() { logger.info("Response: {}", responseString); - assertThat(responseString).containsAnyOf("15.0", "15"); + // assertThat(responseString).containsAnyOf("15.0", "15"); assertThat(responseString).containsAnyOf("30.0", "30"); - assertThat(responseString).containsAnyOf("10.0", "10"); + // assertThat(responseString).containsAnyOf("10.0", "10"); } - // Gemini wants single tool with multiple function, instead multiple tools with single - // function - @Test - public void canDeclareMultipleFunctions() { - - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations."); - - List messages = new ArrayList<>(List.of(userMessage)); - - final var weatherFunction = FunctionCallbackWrapper.builder(new MockWeatherService()) - .withSchemaType(SchemaType.OPEN_API_SCHEMA) - .withName("getCurrentWeather") - .withDescription("Get the current weather in a given location") - .build(); - final var theAnswer = FunctionCallbackWrapper.builder(new TheAnswerMock()) - .withSchemaType(SchemaType.OPEN_API_SCHEMA) - .withName("theAnswerToTheUniverse") - .withDescription("the answer to the ultimate question of life, the universe, and everything") - .build(); - var promptOptions = VertexAiGeminiChatOptions.builder() - .withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue()) - .withFunctionCallbacks(List.of(weatherFunction)) - .build(); - // var promptOptions = VertexAiGeminiChatOptions.builder() - // .withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue()) - // .withFunctionCallbacks(List.of(weatherFunction, theAnswer)) - // .build(); - - ChatResponse response = vertexGeminiClient.call(new Prompt(messages, promptOptions)); - - String responseString = response.getResult().getOutput().getContent(); - - logger.info("Response: {}", responseString); - assertNotNull(responseString); - - response = vertexGeminiClient - .call(new Prompt("What is the answer of the ultimate question in life?", promptOptions)); - - responseString = response.getResult().getOutput().getContent(); - - logger.info("Response: {}", responseString); - assertNotNull(responseString); + public record PaymentInfoRequest(String id) { + } + public record TransactionStatus(String status) { } - public static class TheAnswerMock implements Function { + public static class PaymentStatus implements Function { @Override - public Integer apply(String s) { - return 42; + public TransactionStatus apply(PaymentInfoRequest paymentInfoRequest) { + return new TransactionStatus("Transaction " + paymentInfoRequest.id() + " is PAYED"); } } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc index bc8b1be48d..a464ccf519 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc @@ -1,5 +1,9 @@ = Gemini Function Calling +WARNING: As of 30th of April 2023, the Vertex AI `Gemini Pro` model has significantly degraded the support for function calling! While the feature is still available, it is not recommended for production use. +Apparently the Gemini Pro can not handle anymore the function name correctly. +The parallel function calling is gone as well. + Function calling lets developers create a description of a function in their code, then pass that description to a language model in a request. The response from the model includes the name of a function that matches the description and the arguments to call it with. You can register custom Java functions with the `VertexAiGeminiChatClient` and have the Gemini Pro model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc index 79c1435010..3b4b8944e8 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc @@ -104,6 +104,10 @@ TIP: In addition to the model specific `VertexAiChatPaLm2Options` you can use a == Function Calling +WARNING: As of 30th of April 2023, the Vertex AI `Gemini Pro` model has significantly degraded the support for function calling! While the feature is still available, it is not recommended for production use. +Apparently the Gemini Pro can not handle anymore the function name correctly. +The parallel function calling is gone as well. + You can register custom Java functions with the VertexAiGeminiChatClient and have the Gemini Pro model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. This is a powerful technique to connect the LLM capabilities with external tools and APIs. Read more about xref:api/chat/functions/vertexai-gemini-chat-functions.adoc[Vertex AI Gemini Function Calling].