From 79b44e48be6890b002bc9daa2f3a45ae59b4b19a Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 23 Sep 2024 16:08:33 +0200 Subject: [PATCH] Add proxy tool calls option to chat models This commit introduces a new proxyToolCalls option for various chat models in the Spring AI project. When enabled, it allows the client to handle function calls externally instead of being processed internally by Spring AI. The change affects multiple chat model implementations, including: AnthropicChatModel AzureOpenAiChatModel MiniMaxChatModel MistralAiChatModel MoonshotChatModel OllamaChatModel OpenAiChatModel VertexAiGeminiChatModel ZhiPuAiChatModel The proxyToolCalls option is added to the respective chat options classes and integrated into the AbstractToolCallSupport class for consistent handling across different implementations. The proxyToolCalls option can be set either programmatically via the ChatOptions.builder().withProxyToolCalls() method or the spring.ai.&lt;model&gt;.chat.options.proxy-tool-calls application property. Documentation for the new option is also updated in the relevant Antora pages. 
Resolves #1367 --- .../ai/anthropic/AnthropicChatModel.java | 5 +- .../ai/anthropic/AnthropicChatOptions.java | 18 + .../ai/azure/openai/AzureOpenAiChatModel.java | 6 +- .../azure/openai/AzureOpenAiChatOptions.java | 19 + .../ai/minimax/MiniMaxChatModel.java | 4 +- .../ai/minimax/MiniMaxChatOptions.java | 25 + .../ai/mistralai/MistralAiChatModel.java | 7 +- .../ai/mistralai/MistralAiChatOptions.java | 124 ++++ .../ai/moonshot/MoonshotChatModel.java | 8 +- .../ai/moonshot/MoonshotChatOptions.java | 24 + .../ai/ollama/OllamaChatModel.java | 3 +- .../ai/ollama/api/OllamaOptions.java | 23 +- .../ai/openai/OpeAiApiAdapter.java | 544 ++++++++++++++++++ .../ai/openai/OpenAiChatModel.java | 35 +- .../ai/openai/OpenAiChatOptions.java | 203 ++----- .../ai/openai/api/OpenAiApi.java | 7 + .../ai/openai/chat/OpenAiChatModelIT.java | 2 +- .../chat/OpenAiChatModelProxyToolCallsIT.java | 235 ++++++++ .../GroqWithOpenAiChatModelIT.java | 4 +- .../OllamaWithOpenAiChatModelIT.java | 4 +- .../gemini/VertexAiGeminiChatModel.java | 5 +- .../gemini/VertexAiGeminiChatOptions.java | 21 +- .../ai/zhipuai/ZhiPuAiChatModel.java | 4 +- .../ai/zhipuai/ZhiPuAiChatOptions.java | 25 + .../ai/chat/client/AdvisedRequest.java | 1 + .../ai/chat/client/ChatClient.java | 4 +- .../ai/chat/client/DefaultChatClient.java | 44 +- .../chat/client/DefaultChatClientBuilder.java | 5 +- .../chat/model/AbstractToolCallSupport.java | 37 +- .../ai/chat/prompt/Prompt.java | 2 +- .../function/FunctionCallingOptions.java | 10 + .../FunctionCallingOptionsBuilder.java | 17 + .../ROOT/pages/api/chat/anthropic-chat.adoc | 5 +- .../pages/api/chat/azure-openai-chat.adoc | 1 + .../ROOT/pages/api/chat/groq-chat.adoc | 1 + .../ROOT/pages/api/chat/mistralai-chat.adoc | 1 + .../ROOT/pages/api/chat/nvidia-chat.adoc | 1 + .../ROOT/pages/api/chat/ollama-chat.adoc | 1 + .../ROOT/pages/api/chat/openai-chat.adoc | 1 + .../pages/api/chat/vertexai-gemini-chat.adoc | 1 + .../ROOT/pages/api/chat/zhipuai-chat.adoc | 1 + 41 files changed, 
1256 insertions(+), 232 deletions(-) create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpeAiApiAdapter.java create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java rename models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/{ => proxy}/GroqWithOpenAiChatModelIT.java (98%) rename models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/{ => proxy}/OllamaWithOpenAiChatModelIT.java (98%) diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 8a3925f3b7..dfa4026741 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -225,7 +225,8 @@ public ChatResponse call(Prompt prompt) { return chatResponse; }); - if (response != null && this.isToolCall(response, Set.of("tool_use"))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null + && this.isToolCall(response, Set.of("tool_use"))) { var toolCallConversation = handleToolCalls(prompt, response); return this.call(new Prompt(toolCallConversation, prompt.getOptions())); } @@ -256,7 +257,7 @@ public Flux stream(Prompt prompt) { Flux chatResponseFlux = response.switchMap(chatCompletionResponse -> { ChatResponse chatResponse = toChatResponse(chatCompletionResponse); - if (this.isToolCall(chatResponse, Set.of("tool_use"))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); } diff --git 
a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index 03daa19524..e79a8f760f 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -77,6 +77,9 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions @NestedConfigurationProperty @JsonIgnore private Set functions = new HashSet<>(); + + @JsonIgnore + private Boolean proxyToolCalls; // @formatter:on public static Builder builder() { @@ -144,6 +147,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public AnthropicChatOptions build() { return this.options; } @@ -246,6 +254,15 @@ public Double getPresencePenalty() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public AnthropicChatOptions copy() { return fromOptions(this); @@ -261,6 +278,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .withTopK(fromOptions.getTopK()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 1b2a7a7248..bffdbe2c74 100644 --- 
a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -151,7 +151,8 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = toChatResponse(chatCompletions); - if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. @@ -199,7 +200,8 @@ public Flux stream(Prompt prompt) { ChatResponse chatResponse = toChatResponse(chatCompletions); - if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, + Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. 
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index fc4c0f1795..6b85eeb966 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -161,6 +161,9 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio @JsonIgnore private Set functions = new HashSet<>(); + @JsonIgnore + private Boolean proxyToolCalls; + public static Builder builder() { return new Builder(); } @@ -250,6 +253,11 @@ public Builder withResponseFormat(AzureOpenAiResponseFormat responseFormat) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } @@ -395,6 +403,15 @@ public Integer getTopK() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public AzureOpenAiChatOptions copy() { return fromOptions(this); } @@ -413,6 +430,8 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .withUser(fromOptions.getUser()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) + 
.withResponseFormat(fromOptions.getResponseFormat()) .build(); } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index ee3f758461..c14f93cbe4 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -190,7 +190,7 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - if (isToolCall(chatResponse, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message @@ -254,7 +254,7 @@ public Flux stream(Prompt prompt) { return chatResponse.flatMap(response -> { - if (isToolCall(response, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 31cae5791f..30426eb6c7 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -142,6 +142,9 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions { @NestedConfigurationProperty @JsonIgnore 
private Set functions = new HashSet<>(); + + @JsonIgnore + private Boolean proxyToolCalls; // @formatter:on public static Builder builder() { @@ -242,6 +245,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public MiniMaxChatOptions build() { return this.options; } @@ -394,6 +402,15 @@ public Integer getTopK() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public int hashCode() { final int prime = 31; @@ -411,6 +428,7 @@ public int hashCode() { result = prime * result + ((maskSensitiveInfo == null) ? 0 : maskSensitiveInfo.hashCode()); result = prime * result + ((tools == null) ? 0 : tools.hashCode()); result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((proxyToolCalls == null) ? 
0 : proxyToolCalls.hashCode()); return result; } @@ -501,6 +519,12 @@ else if (!tools.equals(other.tools)) } else if (!toolChoice.equals(other.toolChoice)) return false; + if (this.proxyToolCalls == null) { + if (other.proxyToolCalls != null) + return false; + } + else if (!proxyToolCalls.equals(other.proxyToolCalls)) + return false; return true; } @@ -525,6 +549,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .withToolChoice(fromOptions.getToolChoice()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index ceea2e869d..edf1da14f4 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -183,8 +183,9 @@ public ChatResponse call(Prompt prompt) { return chatResponse; }); - if (response != null && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - MistralAiApi.ChatCompletionFinishReason.STOP.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null + && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + MistralAiApi.ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. 
@@ -255,7 +256,7 @@ public Flux stream(Prompt prompt) { // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses. diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index dc4fcdc6dc..7053f5ab0b 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -135,6 +135,9 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions @JsonIgnore private Set functions = new HashSet<>(); + @JsonIgnore + private Boolean proxyToolCalls; + public static Builder builder() { return new Builder(); } @@ -215,6 +218,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public MistralAiChatOptions build() { return this.options; } @@ -356,6 +364,15 @@ public Integer getTopK() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public MistralAiChatOptions copy() { return fromOptions(this); @@ -374,7 +391,114 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) 
.withToolChoice(fromOptions.getToolChoice()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); + result = prime * result + ((topP == null) ? 0 : topP.hashCode()); + result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); + result = prime * result + ((safePrompt == null) ? 0 : safePrompt.hashCode()); + result = prime * result + ((randomSeed == null) ? 0 : randomSeed.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((tools == null) ? 0 : tools.hashCode()); + result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((functionCallbacks == null) ? 0 : functionCallbacks.hashCode()); + result = prime * result + ((functions == null) ? 0 : functions.hashCode()); + result = prime * result + ((proxyToolCalls == null) ? 
0 : proxyToolCalls.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + MistralAiChatOptions other = (MistralAiChatOptions) obj; + if (model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (temperature == null) { + if (other.temperature != null) + return false; + } + else if (!temperature.equals(other.temperature)) + return false; + if (topP == null) { + if (other.topP != null) + return false; + } + else if (!topP.equals(other.topP)) + return false; + if (maxTokens == null) { + if (other.maxTokens != null) + return false; + } + else if (!maxTokens.equals(other.maxTokens)) + return false; + if (safePrompt == null) { + if (other.safePrompt != null) + return false; + } + else if (!safePrompt.equals(other.safePrompt)) + return false; + if (randomSeed == null) { + if (other.randomSeed != null) + return false; + } + else if (!randomSeed.equals(other.randomSeed)) + return false; + if (responseFormat == null) { + if (other.responseFormat != null) + return false; + } + else if (!responseFormat.equals(other.responseFormat)) + return false; + if (stop == null) { + if (other.stop != null) + return false; + } + else if (!stop.equals(other.stop)) + return false; + if (tools == null) { + if (other.tools != null) + return false; + } + else if (!tools.equals(other.tools)) + return false; + if (toolChoice != other.toolChoice) + return false; + if (functionCallbacks == null) { + if (other.functionCallbacks != null) + return false; + } + else if (!functionCallbacks.equals(other.functionCallbacks)) + return false; + if (functions == null) { + if (other.functions != null) + return false; + } + else if (!functions.equals(other.functions)) + return false; + if (proxyToolCalls == null) { + if (other.proxyToolCalls != null) + return false; + } + else if 
(!proxyToolCalls.equals(other.proxyToolCalls)) + return false; + return true; + } + } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index 99956e81e9..553eab3713 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -164,8 +164,9 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - if (isToolCall(chatResponse, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - MoonshotApi.ChatCompletionFinishReason.STOP.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(chatResponse, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + MoonshotApi.ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. @@ -228,7 +229,8 @@ public Flux stream(Prompt prompt) { return chatResponse.flatMap(response -> { - if (isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), "stop"))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), "stop"))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses. 
diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java index 4bf51bca52..6eedc5ef1f 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java @@ -137,6 +137,9 @@ public class MoonshotChatOptions implements FunctionCallingOptions, ChatOptions */ private @JsonProperty("user") String user; + @JsonIgnore + private Boolean proxyToolCalls; + @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -244,6 +247,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public MoonshotChatOptions build() { return this.options; } @@ -345,6 +353,15 @@ public Integer getTopK() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public MoonshotChatOptions copy() { return builder().withModel(this.model) @@ -360,6 +377,7 @@ public MoonshotChatOptions copy() { .withToolChoice(this.toolChoice) .withFunctionCallbacks(this.functionCallbacks) .withFunctions(this.functions) + .withProxyToolCalls(this.proxyToolCalls) .build(); } @@ -376,6 +394,7 @@ public int hashCode() { result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); result = prime * result + ((topP == null) ? 0 : topP.hashCode()); result = prime * result + ((user == null) ? 0 : user.hashCode()); + result = prime * result + ((proxyToolCalls == null) ? 
0 : proxyToolCalls.hashCode()); return result; } @@ -441,6 +460,11 @@ else if (!topP.equals(other.topP)) } else if (!this.user.equals(other.user)) return false; + if (this.proxyToolCalls == null) { + return other.proxyToolCalls == null; + } + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) + return false; return true; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 79dced2155..96e2a1267a 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -162,7 +162,8 @@ public ChatResponse call(Prompt prompt) { }); - if (response != null && isToolCall(response, Set.of("stop"))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null + && isToolCall(response, Set.of("stop"))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index 3fd22d03e4..530e4361f1 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -270,7 +270,7 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed /** - * Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded. + * Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded. * Defaults to true. 
*/ @JsonProperty("truncate") private Boolean truncate; @@ -297,6 +297,8 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed @JsonIgnore private Set functions = new HashSet<>(); + @JsonIgnore + private Boolean proxyToolCalls; public static OllamaOptions builder() { return new OllamaOptions(); @@ -495,6 +497,11 @@ public OllamaOptions withFunction(String functionName) { return this; } + public OllamaOptions withProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + return this; + } + // ------------------- // Getters and Setters // ------------------- @@ -816,6 +823,15 @@ public Integer getDimensions() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + /** * Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs. * @return The {@link Map} of key/value pairs. 
@@ -884,6 +900,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .withPenalizeNewline(fromOptions.getPenalizeNewline()) .withStop(fromOptions.getStop()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()); } // @formatter:on @@ -913,7 +930,7 @@ public boolean equals(Object o) { && Objects.equals(mirostatTau, that.mirostatTau) && Objects.equals(mirostatEta, that.mirostatEta) && Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop) && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(functions, that.functions); + && Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions); } @Override @@ -923,7 +940,7 @@ public int hashCode() { this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, - this.penalizeNewline, this.stop, this.functionCallbacks, this.functions); + this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpeAiApiAdapter.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpeAiApiAdapter.java new file mode 100644 index 0000000000..89636cb5e1 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpeAiApiAdapter.java @@ -0,0 +1,544 @@ +package org.springframework.ai.openai; + +import java.net.URL; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Date; +import java.util.List; +import java.util.Map; + +import org.springframework.ai.chat.messages.AssistantMessage; +import 
org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.Media; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.openai.OpenAiChatOptions.Builder; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk.ChunkChoice; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionFinishReason; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ChatCompletionFunction; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; +import org.springframework.ai.openai.api.OpenAiApi.FunctionTool; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.StringUtils; + +public class OpeAiApiAdapter { + + /** + * Helper used to provide only the function definition, 
without the actual function + * call implementation. + */ + public static record FunctionDefinition(String name, String description, + String inputTypeSchema) implements FunctionCallback { + + @Override + public String getName() { + return this.name(); + } + + @Override + public String getDescription() { + return this.description(); + } + + @Override + public String getInputTypeSchema() { + return this.inputTypeSchema(); + } + + @Override + public String call(String functionInput) { + throw new UnsupportedOperationException( + "FunctionDefinition provides only metadata. It doesn't implement the call method."); + } + + } + + /** + * Converts the OpenAI + * Chat + * completion request into Spring AI {@link Prompt} with + * {@link OpenAiChatOptions}. + * @param chatCompletionRequest the OpenAI Chat Completion Request to convert. + * @return the converted Spring AI Prompt. + */ + public static Prompt toPrompt(ChatCompletionRequest chatCompletionRequest) { + + // 1. Convert the Options + var chatOptionsBuilder = toChatOptions(chatCompletionRequest); + + // 2. Convert the OpenAI messages into Spring AI messages. 
+ List apiMessages = chatCompletionRequest.messages(); + + List messages = apiMessages.stream().map(apiMessage -> { + + if (apiMessage.role() == ChatCompletionMessage.Role.USER + || apiMessage.role() == ChatCompletionMessage.Role.SYSTEM) { + + Object rawContent = apiMessage.rawContent(); + String refusal = apiMessage.refusal(); + String name = apiMessage.name(); + + MessageType messageType = MessageType.valueOf(apiMessage.role().name()); + + Map metadata = Map.of(); + // Map metadata = Map.of("refusal", refusal, "name", + // name); + + if (rawContent instanceof String textContent) { + return new UserMessage(messageType, textContent, List.of(), metadata); + } + else if (rawContent instanceof OpenAiApi.ChatCompletionMessage.MediaContent mediaContent) { + try { + var media = new Media(MimeTypeUtils.IMAGE_JPEG, new URL(mediaContent.imageUrl().url())); + return new UserMessage(messageType, mediaContent.text(), List.of(media), metadata); + } + catch (Exception e) { + throw new IllegalArgumentException( + "Unsupported message content type: " + rawContent.getClass()); + } + } + else { + throw new IllegalArgumentException("Unsupported message content type: " + rawContent.getClass()); + } + + } + else if (apiMessage.role() == ChatCompletionMessage.Role.ASSISTANT) { + + List toolCalls = null; + if (!CollectionUtils.isEmpty(apiMessage.toolCalls())) { + toolCalls = apiMessage.toolCalls().stream().map(toolCall -> { + return new AssistantMessage.ToolCall(toolCall.id(), toolCall.type(), toolCall.function().name(), + toolCall.function().arguments()); + }).toList(); + } + return new AssistantMessage(apiMessage.content(), Map.of(), toolCalls); + } + else if (apiMessage.role() == ChatCompletionMessage.Role.TOOL) { + String functionName = apiMessage.name(); + String callId = apiMessage.toolCallId(); + List toolResponses = List + .of(new ToolResponseMessage.ToolResponse(callId, functionName, "" + apiMessage.rawContent())); + return new ToolResponseMessage(toolResponses, Map.of()); + } + 
else { + throw new IllegalArgumentException("Unsupported message type: " + apiMessage.role()); + } + }).map(abstractMessage -> (Message) abstractMessage).toList(); + + return new Prompt(messages, chatOptionsBuilder.build()); + } + + public static OpenAiChatOptions.Builder toChatOptions(ChatCompletionRequest chatCompletionRequest) { + + // 1. Convert the Options + Builder optionsBuilder = OpenAiChatOptions.builder(); + + List tools = chatCompletionRequest.tools(); + + if (!CollectionUtils.isEmpty(tools)) { + List tooDefinitions = tools.stream().map(tool -> { + return new FunctionDefinition(tool.function().name(), tool.function().description(), + ModelOptionsUtils.toJsonString(tool.function().parameters())); + }).map(fd -> (FunctionCallback) fd).toList(); + + optionsBuilder.withFunctionCallbacks(tooDefinitions); + } + + if (chatCompletionRequest.model() != null) { + optionsBuilder.withModel(chatCompletionRequest.model()); + } + if (chatCompletionRequest.frequencyPenalty() != null) { + optionsBuilder.withFrequencyPenalty(chatCompletionRequest.frequencyPenalty()); + } + if (chatCompletionRequest.logitBias() != null) { + optionsBuilder.withLogitBias(chatCompletionRequest.logitBias()); + } + if (chatCompletionRequest.logprobs() != null) { + optionsBuilder.withLogprobs(chatCompletionRequest.logprobs()); + } + if (chatCompletionRequest.topLogprobs() != null) { + optionsBuilder.withTopLogprobs(chatCompletionRequest.topLogprobs()); + } + if (chatCompletionRequest.maxTokens() != null) { + optionsBuilder.withMaxTokens(chatCompletionRequest.maxTokens()); + } + if (chatCompletionRequest.n() != null) { + optionsBuilder.withN(chatCompletionRequest.n()); + } + if (chatCompletionRequest.presencePenalty() != null) { + optionsBuilder.withPresencePenalty(chatCompletionRequest.presencePenalty()); + } + if (chatCompletionRequest.responseFormat() != null) { + optionsBuilder.withResponseFormat(chatCompletionRequest.responseFormat()); + } + if (chatCompletionRequest.seed() != null) { + 
optionsBuilder.withSeed(chatCompletionRequest.seed()); + } + if (chatCompletionRequest.stop() != null) { + optionsBuilder.withStop(chatCompletionRequest.stop()); + } + if (chatCompletionRequest.stream() != null) { + // NOTE(review): stream flag intentionally not mapped — streaming vs blocking is + // selected by the invoked method (call vs stream), not by the options. Confirm. + } + if (chatCompletionRequest.temperature() != null) { + optionsBuilder.withTemperature(chatCompletionRequest.temperature()); + } + if (chatCompletionRequest.topP() != null) { + optionsBuilder.withTopP(chatCompletionRequest.topP()); + } + if (chatCompletionRequest.toolChoice() != null) { + optionsBuilder.withToolChoice("" + chatCompletionRequest.toolChoice()); + } + if (chatCompletionRequest.parallelToolCalls() != null) { + optionsBuilder.withParallelToolCalls(chatCompletionRequest.parallelToolCalls()); + } + if (chatCompletionRequest.user() != null) { + optionsBuilder.withUser(chatCompletionRequest.user()); + } + + return optionsBuilder; + } + + /** + * Converts the Spring AI {@link ChatResponse} into OpenAI {@link ChatCompletion}. + * @param chatResponse the Spring AI Chat Response to convert. + * @return the converted OpenAI Chat Completion. 
+ */ + public static ChatCompletion toChatCompletion(ChatResponse chatResponse) { + + List choices = new ArrayList<>(chatResponse.getResults().size()); + + int index = 0; + + for (Generation generation : chatResponse.getResults()) { + var openAiMessage = toOpenAiMessage(generation.getOutput()); + var finishReason = ChatCompletionFinishReason.valueOf(generation.getMetadata().getFinishReason()); + choices.add(new Choice(finishReason, index, openAiMessage, null)); + index++; + } + + String id = chatResponse.getMetadata().getId(); + String model = chatResponse.getMetadata().getModel(); + Usage springAiUsage = chatResponse.getMetadata().getUsage(); + OpenAiApi.Usage usage = new OpenAiApi.Usage(springAiUsage.getGenerationTokens().intValue(), + springAiUsage.getPromptTokens().intValue(), springAiUsage.getTotalTokens().intValue()); + + return new ChatCompletion(id, choices, new Date().getTime(), model, null, "chat.completion", usage); + } + + public static ChatCompletionChunk toChatCompletionChunk(ChatResponse chatResponse) { + + List choices = new ArrayList<>(chatResponse.getResults().size()); + + int index = 0; + + for (Generation generation : chatResponse.getResults()) { + var openAiMessage = toOpenAiMessage(generation.getOutput()); + ChatCompletionFinishReason finishReason = (!StringUtils.hasText(generation.getMetadata().getFinishReason())) + ? 
null : ChatCompletionFinishReason.valueOf(generation.getMetadata().getFinishReason()); + choices.add(new ChunkChoice(finishReason, index, openAiMessage, null)); + index++; + } + + String id = chatResponse.getMetadata().getId(); + String model = chatResponse.getMetadata().getModel(); + Usage springAiUsage = chatResponse.getMetadata().getUsage(); + OpenAiApi.Usage usage = new OpenAiApi.Usage(springAiUsage.getGenerationTokens().intValue(), + springAiUsage.getPromptTokens().intValue(), springAiUsage.getTotalTokens().intValue()); + + return new ChatCompletionChunk(id, choices, new Date().getTime(), model, null, "chat.completion.chunk", usage); + } + + private static ChatCompletionMessage toOpenAiMessage(Message message) { + + if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) { + Object content = message.getContent(); + if (message instanceof UserMessage userMessage) { + if (!CollectionUtils.isEmpty(userMessage.getMedia())) { + List contentList = new ArrayList<>(List.of(new MediaContent(message.getContent()))); + + contentList.addAll(userMessage.getMedia() + .stream() + .map(media -> new MediaContent( + new MediaContent.ImageUrl(fromMediaData(media.getMimeType(), media.getData())))) + .toList()); + + content = contentList; + } + } + + return new ChatCompletionMessage(content, + ChatCompletionMessage.Role.valueOf(message.getMessageType().name())); + } + else if (message.getMessageType() == MessageType.ASSISTANT) { + var assistantMessage = (AssistantMessage) message; + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = new ArrayList<>(); + for (int toolCallIndex = 0; toolCallIndex < assistantMessage.getToolCalls().size(); toolCallIndex++) { + var toolCall = assistantMessage.getToolCalls().get(toolCallIndex); + var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments()); + toolCalls.add(new ToolCall(toolCallIndex, toolCall.id(), toolCall.type(), 
function)); + } + } + return new ChatCompletionMessage(assistantMessage.getContent(), ChatCompletionMessage.Role.ASSISTANT, null, + null, toolCalls, null); + } + else { + throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType()); + } + } + + private static String fromMediaData(MimeType mimeType, Object mediaContentData) { + if (mediaContentData instanceof byte[] bytes) { + // Assume the bytes are an image. So, convert the bytes to a base64 + // encoded following the prefix pattern. + return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes)); + } + else if (mediaContentData instanceof String text) { + // Assume the text is a URLs or a base64 encoded image prefixed by the + // user. + return text; + } + else { + throw new IllegalArgumentException( + "Unsupported media data type: " + mediaContentData.getClass().getSimpleName()); + } + } + + public static List toOpenAiMessages(List messages) { + List chatCompletionMessages = messages.stream().map(message -> { + if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) { + Object content = message.getContent(); + if (message instanceof UserMessage userMessage) { + if (!CollectionUtils.isEmpty(userMessage.getMedia())) { + List contentList = new ArrayList<>( + List.of(new MediaContent(message.getContent()))); + + contentList.addAll(userMessage.getMedia() + .stream() + .map(media -> new MediaContent( + new MediaContent.ImageUrl(fromMediaData(media.getMimeType(), media.getData())))) + .toList()); + + content = contentList; + } + } + + return List.of(new ChatCompletionMessage(content, + ChatCompletionMessage.Role.valueOf(message.getMessageType().name()))); + } + else if (message.getMessageType() == MessageType.ASSISTANT) { + var assistantMessage = (AssistantMessage) message; + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = 
assistantMessage.getToolCalls().stream().map(toolCall -> { + var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments()); + return new ToolCall(toolCall.id(), toolCall.type(), function); + }).toList(); + } + return List.of(new ChatCompletionMessage(assistantMessage.getContent(), + ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null)); + } + else if (message.getMessageType() == MessageType.TOOL) { + ToolResponseMessage toolMessage = (ToolResponseMessage) message; + + toolMessage.getResponses().forEach(response -> { + Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"); + Assert.isTrue(response.name() != null, "ToolResponseMessage must have a name"); + }); + + return toolMessage.getResponses() + .stream() + .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(), + tr.id(), null, null)) + .toList(); + } + else { + throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType()); + } + }).flatMap(List::stream).toList(); + + return chatCompletionMessages; + } + + public static void main(String[] args) { + String request1 = """ + { + "messages": [ + { + "content": "What's the weather like in San Francisco, Tokyo, and Paris?", + "role": "user" + } + ], + "model": "gpt-4o-mini", + "stream": false, + "tools": [ + { + "type": "function", + "function": { + "description": "Get the weather in location", + "name": "getCurrentWeather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. 
San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": [ + "C", + "F" + ] + } + }, + "required": [ + "location", + "lat", + "lon", + "unit" + ] + } + } + } + ] + } + """; + + String request2 = """ + { + "messages": [ + { + "content": "What's the weather like in San Francisco, Tokyo, and Paris?", + "role": "user" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_rzR55tsCemPcEXcyvXtt9v5H", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\\"location\\": \\"San Francisco, CA\\", \\"lat\\": 37.7749, \\"lon\\": -122.4194, \\"unit\\": \\"C\\"}" + } + }, + { + "id": "call_ZOEyq4knGZxFn9eLYBncHzuE", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\\"location\\": \\"Tokyo, Japan\\", \\"lat\\": 35.682839, \\"lon\\": 139.759455, \\"unit\\": \\"C\\"}" + } + }, + { + "id": "call_tZwspDn3nxkl4yodtAvlfeLt", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\\"location\\": \\"Paris, France\\", \\"lat\\": 48.8566, \\"lon\\": 2.3522, \\"unit\\": \\"C\\"}" + } + } + ] + }, + { + "content": "{\\"temp\\":30.0,\\"feels_like\\":15.0,\\"temp_min\\":20.0,\\"temp_max\\":2.0,\\"pressure\\":53,\\"humidity\\":45,\\"unit\\":\\"C\\"}", + "role": "tool", + "name": "getCurrentWeather", + "tool_call_id": "call_rzR55tsCemPcEXcyvXtt9v5H" + }, + { + "content": "{\\"temp\\":10.0,\\"feels_like\\":15.0,\\"temp_min\\":20.0,\\"temp_max\\":2.0,\\"pressure\\":53,\\"humidity\\":45,\\"unit\\":\\"C\\"}", + "role": "tool", + "name": "getCurrentWeather", + "tool_call_id": "call_ZOEyq4knGZxFn9eLYBncHzuE" + }, + { + "content": "{\\"temp\\":15.0,\\"feels_like\\":15.0,\\"temp_min\\":20.0,\\"temp_max\\":2.0,\\"pressure\\":53,\\"humidity\\":45,\\"unit\\":\\"C\\"}", + "role": "tool", + "name": "getCurrentWeather", + "tool_call_id": 
"call_tZwspDn3nxkl4yodtAvlfeLt" + } + ], + "model": "gpt-4o-mini", + "stream": false, + "tools": [ + { + "type": "function", + "function": { + "description": "Get the weather in location", + "name": "getCurrentWeather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": [ + "C", + "F" + ] + } + }, + "required": [ + "location", + "lat", + "lon", + "unit" + ] + } + } + } + ] + } + """; + + ChatCompletionRequest chatCompletionRequest1 = ModelOptionsUtils.jsonToObject(request1, + ChatCompletionRequest.class); + + ChatCompletionRequest chatCompletionRequest2 = ModelOptionsUtils.jsonToObject(request2, + ChatCompletionRequest.class); + + System.out.println(chatCompletionRequest2); + } + +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 55f999813b..3c04ebbc0b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -15,9 +15,16 @@ */ package org.springframework.ai.openai; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; + import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.AssistantMessage; @@ -63,19 +70,13 @@ import org.springframework.util.MimeType; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; - /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} * backed by {@link OpenAiApi}. @@ -189,6 +190,7 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { + super(functionCallbackContext, options, toolFunctionCallbacks); Assert.notNull(openAiApi, "OpenAiApi must not be null"); @@ -259,8 +261,9 @@ public ChatResponse call(Prompt prompt) { }); - if (response != null && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. 
@@ -330,7 +333,7 @@ public Flux stream(Prompt prompt) { // @formatter:off Flux flux = chatResponse.flatMap(response -> { - if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 3a2d8695b8..e89c560e19 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -20,6 +20,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import org.springframework.ai.chat.prompt.ChatOptions; @@ -171,6 +172,14 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions { @JsonIgnore private Set functions = new HashSet<>(); + /** + * If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. + * It is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. + * If false, the Spring AI will handle the function calls internally. + */ + @JsonIgnore + private Boolean proxyToolCalls; + /** * Optional HTTP headers to be added to the chat completion request. 
*/ @@ -307,8 +316,12 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public Builder withHttpHeaders(Map httpHeaders) { - Assert.notNull(httpHeaders, "HTTP headers must not be null"); this.options.httpHeaders = httpHeaders; return this; } @@ -468,6 +481,15 @@ public String getToolChoice() { return this.toolChoice; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + public void setToolChoice(String toolChoice) { this.toolChoice = toolChoice; } @@ -521,152 +543,6 @@ public Integer getTopK() { return null; } - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); - result = prime * result + ((logitBias == null) ? 0 : logitBias.hashCode()); - result = prime * result + ((logprobs == null) ? 0 : logprobs.hashCode()); - result = prime * result + ((topLogprobs == null) ? 0 : topLogprobs.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((n == null) ? 0 : n.hashCode()); - result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - result = prime * result + ((streamOptions == null) ? 0 : streamOptions.hashCode()); - result = prime * result + ((seed == null) ? 0 : seed.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 
0 : topP.hashCode()); - result = prime * result + ((tools == null) ? 0 : tools.hashCode()); - result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); - result = prime * result + ((user == null) ? 0 : user.hashCode()); - result = prime * result + ((parallelToolCalls == null) ? 0 : parallelToolCalls.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - OpenAiChatOptions other = (OpenAiChatOptions) obj; - if (this.model == null) { - if (other.model != null) - return false; - } - else if (!model.equals(other.model)) - return false; - if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) - return false; - } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) - return false; - if (this.logitBias == null) { - if (other.logitBias != null) - return false; - } - else if (!this.logitBias.equals(other.logitBias)) - return false; - if (this.logprobs == null) { - if (other.logprobs != null) - return false; - } - else if (!this.logprobs.equals(other.logprobs)) - return false; - if (this.topLogprobs == null) { - if (other.topLogprobs != null) - return false; - } - else if (!this.topLogprobs.equals(other.topLogprobs)) - return false; - if (this.maxTokens == null) { - if (other.maxTokens != null) - return false; - } - else if (!this.maxTokens.equals(other.maxTokens)) - return false; - if (this.n == null) { - if (other.n != null) - return false; - } - else if (!this.n.equals(other.n)) - return false; - if (this.presencePenalty == null) { - if (other.presencePenalty != null) - return false; - } - else if (!this.presencePenalty.equals(other.presencePenalty)) - return false; - if (this.responseFormat == null) { - if (other.responseFormat != null) - return false; - } - else if (!this.responseFormat.equals(other.responseFormat)) - return false; - if (this.streamOptions == null) { 
- if (other.streamOptions != null) - return false; - } - else if (!this.streamOptions.equals(other.streamOptions)) - return false; - if (this.seed == null) { - if (other.seed != null) - return false; - } - else if (!this.seed.equals(other.seed)) - return false; - if (this.stop == null) { - if (other.stop != null) - return false; - } - else if (!stop.equals(other.stop)) - return false; - if (this.temperature == null) { - if (other.temperature != null) - return false; - } - else if (!this.temperature.equals(other.temperature)) - return false; - if (this.topP == null) { - if (other.topP != null) - return false; - } - else if (!topP.equals(other.topP)) - return false; - if (this.tools == null) { - if (other.tools != null) - return false; - } - else if (!tools.equals(other.tools)) - return false; - if (this.toolChoice == null) { - if (other.toolChoice != null) - return false; - } - else if (!toolChoice.equals(other.toolChoice)) - return false; - if (this.user == null) { - if (other.user != null) - return false; - } - else if (!this.user.equals(other.user)) - return false; - else if (this.parallelToolCalls == null) { - if (other.parallelToolCalls != null) - return false; - } - else if (!this.parallelToolCalls.equals(other.parallelToolCalls)) - return false; - - return true; - } - @Override public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); @@ -695,9 +571,42 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) .withHttpHeaders(fromOptions.getHttpHeaders()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } + @Override + public int hashCode() { + return Objects.hash(this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, + this.maxTokens, this.n, this.presencePenalty, this.responseFormat, this.streamOptions, this.seed, + this.stop, this.temperature, this.topP, this.tools, 
this.toolChoice, this.user, this.parallelToolCalls, + this.functionCallbacks, this.functions, this.httpHeaders, this.proxyToolCalls); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + OpenAiChatOptions other = (OpenAiChatOptions) o; + return Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) + && Objects.equals(this.logitBias, other.logitBias) && Objects.equals(this.logprobs, other.logprobs) + && Objects.equals(this.topLogprobs, other.topLogprobs) + && Objects.equals(this.maxTokens, other.maxTokens) && Objects.equals(this.n, other.n) + && Objects.equals(this.presencePenalty, other.presencePenalty) + && Objects.equals(this.responseFormat, other.responseFormat) + && Objects.equals(this.streamOptions, other.streamOptions) && Objects.equals(this.seed, other.seed) + && Objects.equals(this.stop, other.stop) && Objects.equals(this.temperature, other.temperature) + && Objects.equals(this.topP, other.topP) && Objects.equals(this.tools, other.tools) + && Objects.equals(this.toolChoice, other.toolChoice) && Objects.equals(this.user, other.user) + && Objects.equals(this.parallelToolCalls, other.parallelToolCalls) + && Objects.equals(this.functionCallbacks, other.functionCallbacks) + && Objects.equals(this.functions, other.functions) + && Objects.equals(this.httpHeaders, other.httpHeaders) + && Objects.equals(this.proxyToolCalls, other.proxyToolCalls); + } + @Override public String toString() { return "OpenAiChatOptions: " + ModelOptionsUtils.toJsonString(this); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 64e0d58cef..b88f795063 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ 
b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -756,6 +756,8 @@ public MediaContent(ImageUrl imageUrl) { /** * The relevant tool call. * + * @param index The index of the tool call in the list of tool calls. Required in + * case of streaming. * @param id The ID of the tool call. This ID must be referenced when you submit * the tool outputs in using the Submit tool outputs to run endpoint. * @param type The type of tool call the output is required for. For now, this is @@ -764,9 +766,14 @@ public MediaContent(ImageUrl imageUrl) { */ @JsonInclude(Include.NON_NULL) public record ToolCall(// @formatter:off + @JsonProperty("index") Integer index, @JsonProperty("id") String id, @JsonProperty("type") String type, @JsonProperty("function") ChatCompletionFunction function) {// @formatter:on + + public ToolCall(String id, String type, ChatCompletionFunction function) { + this(null, id, type, function); + } } /** diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 6bfc6faa5e..37a619ab92 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -64,7 +64,7 @@ @SpringBootTest(classes = OpenAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") -class OpenAiChatModelIT extends AbstractIT { +public class OpenAiChatModelIT extends AbstractIT { private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelIT.class); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java new file mode 100644 
index 0000000000..7d1d362baf --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java @@ -0,0 +1,235 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.openai.chat; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.openai.OpeAiApiAdapter; 
+import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.ai.openai.api.tool.MockWeatherService.Response; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.util.CollectionUtils; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.micrometer.observation.ObservationRegistry; + +@SpringBootTest(classes = OpenAiChatModelProxyToolCallsIT.Config.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +class OpenAiChatModelProxyToolCallsIT { + + private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelIT.class); + + private static final String DEFAULT_MODEL = "gpt-4o-mini"; + + @Autowired + private OpenAiChatModel chatModel; + + @Autowired + private FunctionCallHelper functionCallUtils; + + @Test + void functionCallTest() { + + var weatherService = new MockWeatherService(); + + List messages = List + .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); + + FunctionCallback functionDefinition = new OpeAiApiAdapter.FunctionDefinition("getCurrentWeather", + "Get the weather in location", """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. 
San Francisco, CA" + }, + "lat": { + "type": "number", + "description": "The city latitude" + }, + "lon": { + "type": "number", + "description": "The city longitude" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "lat", "lon", "unit"] + } + """); + + var promptOptions = OpenAiChatOptions.builder().withFunctionCallbacks(List.of(functionDefinition)).build(); + + var prompt = new Prompt(messages, promptOptions); + + boolean isToolCall = false; + + ChatResponse chatResponse = null; + + do { + + chatResponse = chatModel.call(prompt); + + try { + System.out.println("ChatCompletion:\n" + new ObjectMapper().writerWithDefaultPrettyPrinter() + .writeValueAsString(OpeAiApiAdapter.toChatCompletion(chatResponse))); + } + catch (Exception e) { + e.printStackTrace(); + } + + // We will have to convert the chatResponse into OpenAI assistant message. + + // Code that the Python tools will have to implement + isToolCall = functionCallUtils.isToolCall(chatResponse, + Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name())); + + if (isToolCall) { + + Optional toolCallGeneration = chatResponse.getResults() + .stream() + .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) + .findFirst(); + + assertThat(toolCallGeneration).isNotEmpty(); + + AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + + List toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + + var functionName = toolCall.name(); + + assertThat(functionName).isEqualTo("getCurrentWeather"); + + String functionArguments = toolCall.arguments(); + + MockWeatherService.Request functionRequest = ModelOptionsUtils.jsonToObject(functionArguments, + MockWeatherService.Request.class); + + Response functionResponse = weatherService.apply(functionRequest); + + toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), 
functionName, + ModelOptionsUtils.toJsonString(functionResponse))); + } + + ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); + + List toolCallConversation = functionCallUtils + .buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); + + assertThat(toolCallConversation).isNotEmpty(); + + prompt = new Prompt(toolCallConversation, prompt.getOptions()); + } + } + while (isToolCall); + + logger.info("Response: {}", chatResponse); + + assertThat(chatResponse.getResult().getOutput().getContent()).contains("30", "10", "15"); + } + + /** + * Helper class that reuses the {@link AbstractToolCallSupport} to implement the + * function call handling logic on the client side. + */ + public static class FunctionCallHelper extends AbstractToolCallSupport { + + protected FunctionCallHelper(FunctionCallbackContext functionCallbackContext, + FunctionCallingOptions functionCallingOptions, List toolFunctionCallbacks) { + super(functionCallbackContext, functionCallingOptions, toolFunctionCallbacks); + } + + @Override + public boolean isToolCall(ChatResponse chatResponse, Set toolCallFinishReasons) { + return super.isToolCall(chatResponse, toolCallFinishReasons); + } + + @Override + public List buildToolCallConversation(List previousMessages, + AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) { + return super.buildToolCallConversation(previousMessages, assistantMessage, toolResponseMessage); + } + + @Override + public List handleToolCalls(Prompt prompt, ChatResponse response) { + return super.handleToolCalls(prompt, response); + } + + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OpenAiApi chatCompletionApi() { + return new OpenAiApi(System.getenv("OPENAI_API_KEY")); + } + + @Bean + public OpenAiChatModel openAiClient(OpenAiApi openAiApi, List toolFunctionCallbacks) { + // enable the proxy tool calls option. 
+ var options = OpenAiChatOptions.builder().withModel(DEFAULT_MODEL).withProxyToolCalls(true).build(); + + return new OpenAiChatModel(openAiApi, options, null, toolFunctionCallbacks, + RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP); + } + + @Bean + public FunctionCallHelper functionCallUtils(List toolFunctionCallbacks) { + OpenAiChatOptions functionCallingOptions = OpenAiChatOptions.builder().build(); + return new FunctionCallHelper(null, functionCallingOptions, toolFunctionCallbacks); + } + + } + +} \ No newline at end of file diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java similarity index 98% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java index 6d8cb41d17..ac395cf66c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/GroqWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.springframework.ai.openai.chat; +package org.springframework.ai.openai.chat.proxy; import static org.assertj.core.api.Assertions.assertThat; @@ -50,6 +50,8 @@ import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.ai.openai.chat.OpenAiChatModelIT; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OllamaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java similarity index 98% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OllamaWithOpenAiChatModelIT.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java index ea3f603755..892c865c2f 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OllamaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.springframework.ai.openai.chat; +package org.springframework.ai.openai.chat.proxy; import static org.assertj.core.api.Assertions.assertThat; @@ -50,6 +50,8 @@ import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.ai.openai.chat.OpenAiChatModelIT; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 1be9935e67..b7ea5051d3 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -180,7 +180,8 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response)); - if (isToolCall(chatResponse, Set.of(FinishReason.STOP.name()))) { + if (!isProxyToolCalls(prompt, this.defaultOptions) + && isToolCall(chatResponse, Set.of(FinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. 
@@ -209,7 +210,7 @@ public Flux stream(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response)); - if (isToolCall(chatResponse, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, Set.of(FinishReason.STOP.name(), FinishReason.FINISH_REASON_UNSPECIFIED.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the stream method with the tool call message diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 062089cb9d..ce46d4f97c 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -115,6 +115,8 @@ public enum TransportType { @JsonIgnore private boolean googleSearchRetrieval = false; + @JsonIgnore + private Boolean proxyToolCalls; // @formatter:on @@ -194,6 +196,11 @@ public Builder withGoogleSearchRetrieval(boolean googleSearch) { return this; } + public Builder withProxyToolCalls(boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public VertexAiGeminiChatOptions build() { return this.options; } @@ -321,6 +328,15 @@ public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) { this.googleSearchRetrieval = googleSearchRetrieval; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public boolean equals(Object o) { if (this == o) @@ -333,13 +349,13 @@ public boolean equals(Object o) { && Objects.equals(maxOutputTokens, that.maxOutputTokens) && 
Objects.equals(model, that.model) && Objects.equals(responseMimeType, that.responseMimeType) && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(functions, that.functions); + && Objects.equals(functions, that.functions) && Objects.equals(proxyToolCalls, that.proxyToolCalls); } @Override public int hashCode() { return Objects.hash(stopSequences, temperature, topP, topK, candidateCount, maxOutputTokens, model, - responseMimeType, functionCallbacks, functions, googleSearchRetrieval); + responseMimeType, functionCallbacks, functions, googleSearchRetrieval, proxyToolCalls); } @Override @@ -370,6 +386,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setFunctions(fromOptions.getFunctions()); options.setResponseMimeType(fromOptions.getResponseMimeType()); options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); + options.setProxyToolCalls(fromOptions.getProxyToolCalls()); return options; } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 2a9293c355..6bbe65eb9b 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -177,7 +177,7 @@ public ChatResponse call(Prompt prompt) { ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - if (isToolCall(chatResponse, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message @@ -241,7 +241,7 @@ public Flux stream(Prompt prompt) { return 
chatResponse.flatMap(response -> { - if (isToolCall(response, + if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index b495eb6679..d33f04d79f 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -123,6 +123,9 @@ public class ZhiPuAiChatOptions implements FunctionCallingOptions, ChatOptions { @NestedConfigurationProperty @JsonIgnore private Set functions = new HashSet<>(); + + @JsonIgnore + private Boolean proxyToolCalls; // @formatter:on public static Builder builder() { @@ -208,6 +211,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.proxyToolCalls = proxyToolCalls; + return this; + } + public ZhiPuAiChatOptions build() { return this.options; } @@ -346,6 +354,15 @@ public Integer getTopK() { return null; } + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public int hashCode() { final int prime = 31; @@ -358,6 +375,7 @@ public int hashCode() { result = prime * result + ((tools == null) ? 0 : tools.hashCode()); result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); result = prime * result + ((user == null) ? 0 : user.hashCode()); + result = prime * result + ((proxyToolCalls == null) ? 
0 : proxyToolCalls.hashCode()); return result; } @@ -430,6 +448,12 @@ else if (!this.requestId.equals(other.requestId)) } else if (!this.doSample.equals(other.doSample)) return false; + if (this.proxyToolCalls == null) { + if (other.proxyToolCalls != null) + return false; + } + else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) + return false; return true; } @@ -452,6 +476,7 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .withDoSample(fromOptions.getDoSample()) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) + .withProxyToolCalls(fromOptions.getProxyToolCalls()) .build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java index af05a1cfd3..903c304dc3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java @@ -64,6 +64,7 @@ public static Builder from(AdvisedRequest from) { builder.systemParams = from.systemParams; builder.advisors = from.advisors; builder.advisorParams = from.advisorParams; + builder.advisorParams = from.advisorParams; return builder; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 67b81abeef..dc7c7a5a2e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -74,9 +74,9 @@ static Builder builder(ChatModel chatModel, ObservationRegistry observationRegis ChatClientRequestSpec prompt(); - ChatClientPromptRequestSpec prompt(String content); + ChatClientRequestSpec prompt(String content); - ChatClientPromptRequestSpec prompt(Prompt prompt); 
+ ChatClientRequestSpec prompt(Prompt prompt); /** * Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 21785d36d0..13f80769a6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -32,9 +32,9 @@ import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor.StreamResponseMode; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; @@ -85,12 +85,9 @@ public class DefaultChatClient implements ChatClient { private static final ChatClientObservationConvention DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION = new DefaultChatClientObservationConvention(); - private final ChatModel chatModel; - private final DefaultChatClientRequestSpec defaultChatClientRequest; - public DefaultChatClient(ChatModel chatModel, DefaultChatClientRequestSpec defaultChatClientRequest) { - this.chatModel = chatModel; + public DefaultChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) { this.defaultChatClientRequest = defaultChatClientRequest; } @@ -100,13 +97,19 @@ public ChatClientRequestSpec prompt() { } @Override - 
public ChatClientPromptRequestSpec prompt(String content) { - return new DefaultChatClientPromptRequestSpec(this.chatModel, new Prompt(content)); + public ChatClientRequestSpec prompt(String content) { + return prompt(new Prompt(content)); } - @Override - public ChatClientPromptRequestSpec prompt(Prompt prompt) { - return new DefaultChatClientPromptRequestSpec(this.chatModel, prompt); + public ChatClientRequestSpec prompt(Prompt prompt) { + + DefaultChatClientRequestSpec spec = new DefaultChatClientRequestSpec(this.defaultChatClientRequest); + spec.messages(prompt.getInstructions()); + if (prompt.getOptions() != null) { + spec.options(prompt.getOptions()); + } + + return spec; } /** @@ -997,25 +1000,4 @@ public Flux content() { } - public static class DefaultChatClientPromptRequestSpec implements ChatClientPromptRequestSpec { - - private final ChatModel chatModel; - - private final Prompt prompt; - - public DefaultChatClientPromptRequestSpec(ChatModel chatModel, Prompt prompt) { - this.chatModel = chatModel; - this.prompt = prompt; - } - - public CallPromptResponseSpec call() { - return new DefaultCallPromptResponseSpec(this.chatModel, this.prompt); - } - - public StreamPromptResponseSpec stream() { - return new DefaultStreamPromptResponseSpec(this.chatModel, this.prompt); - } - - } - } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 4f22ad2c80..411f19e898 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -50,8 +50,6 @@ public class DefaultChatClientBuilder implements Builder { protected final DefaultChatClientRequestSpec defaultRequest; - private final ChatModel chatModel; - DefaultChatClientBuilder(ChatModel chatModel) { 
this(chatModel, ObservationRegistry.NOOP, null); } @@ -60,14 +58,13 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa ChatClientObservationConvention customObservationConvention) { Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); - this.chatModel = chatModel; this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention); } public ChatClient build() { - return new DefaultChatClient(this.chatModel, this.defaultRequest); + return new DefaultChatClient(this.defaultRequest); } public Builder defaultAdvisors(Advisor... advisor) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java index 6085385646..2b58ebf614 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java @@ -75,18 +75,18 @@ protected AbstractToolCallSupport(FunctionCallbackContext functionCallbackContex } } - private static List merge(FunctionCallingOptions funcitonOptions, + private static List merge(FunctionCallingOptions functionOptions, List toolFunctionCallbacks) { List toolFunctionCallbacksCopy = new ArrayList<>(); if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { toolFunctionCallbacksCopy.addAll(toolFunctionCallbacks); } - if (!CollectionUtils.isEmpty(funcitonOptions.getFunctionCallbacks())) { - toolFunctionCallbacksCopy.addAll(funcitonOptions.getFunctionCallbacks()); + if (!CollectionUtils.isEmpty(functionOptions.getFunctionCallbacks())) { + 
toolFunctionCallbacksCopy.addAll(functionOptions.getFunctionCallbacks()); // Make sure that that function callbacks are are registered directly to the // functionCallbackRegister and not passed in the default options. - funcitonOptions.setFunctionCallbacks(List.of()); + functionOptions.setFunctionCallbacks(List.of()); } return toolFunctionCallbacksCopy; } @@ -220,6 +220,13 @@ protected boolean isToolCall(ChatResponse chatResponse, Set toolCallFini return generations.stream().anyMatch(g -> isToolCall(g, toolCallFinishReasons)); } + /** + * Check if the generation is a tool call. The tool call finish reasons are used to + * determine if the generation is a tool call. + * @param generation the generation to check. + * @param toolCallFinishReasons the tool call finish reasons to check. + * @return true if the generation is a tool call, false otherwise. + */ protected boolean isToolCall(Generation generation, Set toolCallFinishReasons) { var finishReason = (generation.getMetadata().getFinishReason() != null) ? generation.getMetadata().getFinishReason() : ""; @@ -229,4 +236,26 @@ protected boolean isToolCall(Generation generation, Set toolCallFinishRe .contains(finishReason.toLowerCase()); } + /** + * Check if the proxyToolCalls is enabled for the given prompt or the default tool + * call options. The prompt options take precedence over the default options. When the + * proxyToolCalls is enabled the ChatModel implementation will not handle the function + * calling internally. The tool call and tool response messages are exposed outside + * the ChatModel implementation. + * @param prompt the prompt to check. + * @param defaultOptions the default tool call options to check. + * @return true if the proxyToolCalls is enabled, false otherwise. 
+ */ + protected boolean isProxyToolCalls(Prompt prompt, FunctionCallingOptions defaultOptions) { + if (prompt.getOptions() instanceof FunctionCallingOptions functionCallOptions + && functionCallOptions.getProxyToolCalls() != null) { + return functionCallOptions.getProxyToolCalls(); + } + else if (defaultOptions.getProxyToolCalls() != null) { + return defaultOptions.getProxyToolCalls(); + } + + return false; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index feb8e06756..3d36b1bfbf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -50,7 +50,7 @@ public Prompt(Message message) { } public Prompt(List messages) { - this.messages = messages; + this(messages, null); } public Prompt(String contents, ChatOptions chatOptions) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index df603d2b38..f953e907d3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -54,6 +54,16 @@ public interface FunctionCallingOptions { */ void setFunctions(Set functions); + default Boolean getProxyToolCalls() { + return false; + } + + default void setProxyToolCalls(Boolean proxyToolCalls) { + if (proxyToolCalls != null) { + throw new UnsupportedOperationException("Setting Proxy Tool Calls are not supported!"); + } + } + /** * @return Returns FunctionCallingOptionsBuilder to create a new instance of * FunctionCallingOptions. 
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java index 5c1d7c0520..3c00356831 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java @@ -102,6 +102,11 @@ public FunctionCallingOptionsBuilder withTopP(Double topP) { return this; } + public FunctionCallingOptionsBuilder withProxyToolCalls(Boolean proxyToolCalls) { + this.options.setProxyToolCalls(proxyToolCalls); + return this; + } + public PortableFunctionCallingOptions build() { return this.options; } @@ -128,6 +133,8 @@ public static class PortableFunctionCallingOptions implements FunctionCallingOpt private Double topP; + private Boolean proxyToolCalls = false; + @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -220,6 +227,15 @@ public void setTopP(Double topP) { this.topP = topP; } + @Override + public Boolean getProxyToolCalls() { + return proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + @Override public ChatOptions copy() { return new FunctionCallingOptionsBuilder().withModel(this.model) @@ -232,6 +248,7 @@ public ChatOptions copy() { .withTopP(this.topP) .withFunctions(this.functions) .withFunctionCallbacks(this.functionCallbacks) + .withProxyToolCalls(this.proxyToolCalls) .build(); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc index a5fcc2fd00..f4191093f8 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc @@ -103,8 
+103,9 @@ The prefix `spring.ai.anthropic.chat` is the property prefix that lets you confi | spring.ai.anthropic.chat.options.stop-sequence | Custom text sequences that will cause the model to stop generating. Our models will normally stop when they have naturally completed their turn, which will result in a response stop_reason of "end_turn". If you want the model to stop generating when it encounters custom strings of text, you can use the stop_sequences parameter. If the model encounters one of the custom sequences, the response stop_reason value will be "stop_sequence" and the response stop_sequence value will contain the matched stop sequence. | - | spring.ai.anthropic.chat.options.top-p | Use nucleus sampling. In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both. Recommended for advanced use cases only. You usually only need to use temperature. | - | spring.ai.anthropic.chat.options.top-k | Only sample from the top K options for each subsequent token. Used to remove "long tail" low probability responses. Learn more technical details here. Recommended for advanced use cases only. You usually only need to use temperature. | - -| spring.ai.mistralai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | - -| spring.ai.mistralai.chat.options.functionCallbacks | MistralAI Tool Function Callbacks to register with the ChatModel. | - +| spring.ai.anthropic.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. 
| - +| spring.ai.anthropic.chat.options.functionCallbacks | Tool Function Callbacks to register with the ChatModel. | - +| spring.ai.anthropic.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false |==== TIP: All properties prefixed with `spring.ai.anthropic.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc index 92e7e3e5a7..63dafabb60 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc @@ -143,6 +143,7 @@ Deployments model name to provide as part of this completions request. | gpt-4o | spring.ai.azure.openai.chat.options.presencePenalty | A value that influences the probability of generated tokens appearing based on their existing presence in generated text. Positive values will make tokens less likely to appear when they already exist and increase the model's likelihood to output new topics. | - | spring.ai.azure.openai.chat.options.responseFormat | An object specifying the format that the model must output. Using `AzureOpenAiResponseFormat.JSON` enables JSON mode, which guarantees the message the model generates is valid JSON.
Using AzureOpenAiResponseFormat.TEXT enables TEXT mode.| - | spring.ai.azure.openai.chat.options.frequencyPenalty | A value that influences the probability of generated tokens appearing based on their cumulative frequency in generated text. Positive values will make tokens less likely to appear as their frequency increases and decrease the likelihood of the model repeating the same statements verbatim. | - +| spring.ai.azure.openai.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false |==== TIP: All properties prefixed with `spring.ai.azure.openai.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc index 76a733f0fd..c4d68799f1 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc @@ -125,6 +125,7 @@ The prefix `spring.ai.openai.chat` is the property prefix that lets you configur | spring.ai.openai.chat.options.user | A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | - | spring.ai.openai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | - | spring.ai.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request.
The `choices` field for this chunk is an empty array and all other chunks will also include a usage field, but with a null value. | false +| spring.ai.openai.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false |==== TIP: All properties prefixed with `spring.ai.openai.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc index 9e0a6e0471..5d2680a493 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc @@ -102,6 +102,7 @@ The prefix `spring.ai.mistralai.chat` is the property prefix that lets you confi | spring.ai.mistralai.chat.options.toolChoice | Controls which (if any) function is called by the model. `none` means the model will not call a function and instead generates a message. `auto` means the model can pick between generating a message or calling a function. Specifying a particular function via `{"type: "function", "function": {"name": "my_function"}}` forces the model to call that function. `none` is the default when no functions are present. `auto` is the default if functions are present. | - | spring.ai.mistralai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry.
| - | spring.ai.mistralai.chat.options.functionCallbacks | Mistral AI Tool Function Callbacks to register with the ChatModel. | - +| spring.ai.mistralai.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false |==== NOTE: You can override the common `spring.ai.mistralai.base-url` and `spring.ai.mistralai.api-key` for the `ChatModel` and `EmbeddingModel` implementations. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc index 56b135a72b..ee7a02033a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc @@ -101,6 +101,7 @@ The prefix `spring.ai.openai.chat` is the property prefix that lets you configur | spring.ai.openai.chat.options.user | A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | - | spring.ai.openai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | - | spring.ai.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request. The `choices` field for this chunk is an empty array and all other chunks will also include a usage field, but with a null value.
| false +| spring.ai.openai.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false |==== TIP: All properties prefixed with `spring.ai.openai.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc index ae73ee916c..d06b148b24 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/ollama-chat.adoc @@ -105,6 +105,7 @@ The remaining `options` properties are based on the link:https://github.com/olla | spring.ai.ollama.chat.options.penalize-newline | - | true | spring.ai.ollama.chat.options.stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile. | - | spring.ai.ollama.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | - +| spring.ai.ollama.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally.
Applicable only for chat models with function calling support | false |==== TIP: All properties prefixed with `spring.ai.ollama.chat.options` can be overridden at runtime by adding request-specific <> to the `Prompt` call. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc index 1e731a770e..f5c783b461 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc @@ -117,6 +117,7 @@ The `JSON_SCHEMA` type enables link:https://platform.openai.com/docs/guides/stru | spring.ai.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request. The `choices` field for this chunk is an empty array and all other chunks will also include a usage field, but with a null value. | false | spring.ai.openai.chat.options.parallel-tool-calls | Whether to enable link:https://platform.openai.com/docs/guides/function-calling/parallel-function-calling[parallel function calling] during tool use. | true | spring.ai.openai.chat.options.http-headers | Optional HTTP headers to be added to the chat completion request. To override the `api-key` you need to use an `Authorization` header key, and you have to prefix the key value with the `Bearer ` prefix. | - +| spring.ai.openai.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally.
Applicable only for chat models with function calling support | false |==== NOTE: You can override the common `spring.ai.openai.base-url` and `spring.ai.openai.api-key` for the `ChatModel` and `EmbeddingModel` implementations. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc index 73dc4b2a6b..dbbcdabba7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc @@ -78,6 +78,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets yo | spring.ai.vertex.ai.gemini.chat.options.frequencyPenalty | | - | spring.ai.vertex.ai.gemini.chat.options.presencePenalty | | - | spring.ai.vertex.ai.gemini.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | - +| spring.ai.vertex.ai.gemini.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally.
Applicable only for chat models with function calling support | false |==== diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc index d946274232..328dc9a2bf 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/zhipuai-chat.adoc @@ -98,6 +98,7 @@ The prefix `spring.ai.zhipuai.chat` is the property prefix that lets you configu | spring.ai.zhipuai.chat.options.user | A unique identifier representing your end-user, which can help ZhiPuAI to monitor and detect abuse. | - | spring.ai.zhipuai.chat.options.requestId | The parameter is passed by the client and must ensure uniqueness. It is used to distinguish the unique identifier for each request. If the client does not provide it, the platform will generate it by default. | - | spring.ai.zhipuai.chat.options.doSample | When do_sample is set to true, the sampling strategy is enabled. If do_sample is false, the sampling strategy parameters temperature and top_p will not take effect. | true +| spring.ai.zhipuai.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then it is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false |==== NOTE: You can override the common `spring.ai.zhipuai.base-url` and `spring.ai.zhipuai.api-key` for the `ChatModel` implementations.