From 3f4c064d743350486786835d331ad96ace5be192 Mon Sep 17 00:00:00 2001 From: dafriz Date: Thu, 26 Sep 2024 00:43:40 +1000 Subject: [PATCH] Add support for max_completion_tokens in OpenAI chat options request An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens. Replaces the max_tokens field, which is now deprecated. --- .../ai/openai/OpenAiChatOptions.java | 31 ++++++++++++++++--- .../ai/openai/api/OpenAiApi.java | 13 +++++--- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index e89c560e19..6303ffe3bc 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -82,6 +82,11 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions { * tokens and generated tokens is limited by the model's context length. */ private @JsonProperty("max_tokens") Integer maxTokens; + /** + * An upper bound for the number of tokens that can be generated for a completion, + * including visible output tokens and reasoning tokens. + */ + private @JsonProperty("max_completion_tokens") Integer maxCompletionTokens; /** * How many chat completion choices to generate for each input message. Note that you will be charged based * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. 
@@ -239,6 +244,11 @@ public Builder withMaxTokens(Integer maxTokens) { return this; } + public Builder withMaxCompletionTokens(Integer maxCompletionTokens) { + this.options.maxCompletionTokens = maxCompletionTokens; + return this; + } + public Builder withN(Integer n) { this.options.n = n; return this; @@ -391,6 +401,14 @@ public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } + public Integer getMaxCompletionTokens() { + return maxCompletionTokens; + } + + public void setMaxCompletionTokens(Integer maxCompletionTokens) { + this.maxCompletionTokens = maxCompletionTokens; + } + public Integer getN() { return this.n; } @@ -556,6 +574,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .withLogprobs(fromOptions.getLogprobs()) .withTopLogprobs(fromOptions.getTopLogprobs()) .withMaxTokens(fromOptions.getMaxTokens()) + .withMaxCompletionTokens(fromOptions.getMaxCompletionTokens()) .withN(fromOptions.getN()) .withPresencePenalty(fromOptions.getPresencePenalty()) .withResponseFormat(fromOptions.getResponseFormat()) @@ -578,9 +597,10 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { @Override public int hashCode() { return Objects.hash(this.model, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, - this.maxTokens, this.n, this.presencePenalty, this.responseFormat, this.streamOptions, this.seed, - this.stop, this.temperature, this.topP, this.tools, this.toolChoice, this.user, this.parallelToolCalls, - this.functionCallbacks, this.functions, this.httpHeaders, this.proxyToolCalls); + this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, this.responseFormat, + this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, + this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders, + this.proxyToolCalls); } @Override @@ -593,8 +613,9 @@ public boolean equals(Object o) { return 
Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) && Objects.equals(this.logitBias, other.logitBias) && Objects.equals(this.logprobs, other.logprobs) && Objects.equals(this.topLogprobs, other.topLogprobs) - && Objects.equals(this.maxTokens, other.maxTokens) && Objects.equals(this.n, other.n) - && Objects.equals(this.presencePenalty, other.presencePenalty) + && Objects.equals(this.maxTokens, other.maxTokens) + && Objects.equals(this.maxCompletionTokens, other.maxCompletionTokens) + && Objects.equals(this.n, other.n) && Objects.equals(this.presencePenalty, other.presencePenalty) && Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.streamOptions, other.streamOptions) && Objects.equals(this.seed, other.seed) && Objects.equals(this.stop, other.stop) && Objects.equals(this.temperature, other.temperature) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 4eb00c374e..144c54f715 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -390,6 +390,8 @@ public Function(String description, String name, String jsonSchema) { * @param maxTokens The maximum number of tokens to generate in the chat completion. * The total length of input tokens and generated tokens is limited by the model's * context length. + * @param maxCompletionTokens An upper bound for the number of tokens that can be + * generated for a completion, including visible output tokens and reasoning tokens. * @param n How many chat completion choices to generate for each input message. Note * that you will be charged based on the number of generated tokens across all the * choices. Keep n as 1 to minimize costs. 
@@ -442,6 +444,7 @@ public record ChatCompletionRequest(// @formatter:off @JsonProperty("logprobs") Boolean logprobs, @JsonProperty("top_logprobs") Integer topLogprobs, @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("max_completion_tokens") Integer maxCompletionTokens, @JsonProperty("n") Integer n, @JsonProperty("presence_penalty") Double presencePenalty, @JsonProperty("response_format") ResponseFormat responseFormat, @@ -464,7 +467,7 @@ public record ChatCompletionRequest(// @formatter:off * @param temperature What sampling temperature to use, between 0 and 1. */ public ChatCompletionRequest(List messages, String model, Double temperature) { - this(messages, model, null, null, null, null, null, null, null, + this(messages, model, null, null, null, null, null, null, null, null, null, null, null, false, null, temperature, null, null, null, null, null); } @@ -479,7 +482,7 @@ public ChatCompletionRequest(List messages, String model, * as they become available, with the stream terminated by a data: [DONE] message. 
*/ public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { - this(messages, model, null, null, null, null, null, null, null, + this(messages, model, null, null, null, null, null, null, null, null, null, null, null, stream, null, temperature, null, null, null, null, null); } @@ -495,7 +498,7 @@ public ChatCompletionRequest(List messages, String model, */ public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { - this(messages, model, null, null, null, null, null, null, null, + this(messages, model, null, null, null, null, null, null, null, null, null, null, null, false, null, 0.8, null, tools, toolChoice, null, null); } @@ -509,7 +512,7 @@ public ChatCompletionRequest(List messages, String model, */ public ChatCompletionRequest(List messages, Boolean stream) { this(messages, null, null, null, null, null, null, null, null, - null, null, null, stream, null, null, null, + null, null, null, null, stream, null, null, null, null, null, null, null); } @@ -520,7 +523,7 @@ public ChatCompletionRequest(List messages, Boolean strea * @return A new {@link ChatCompletionRequest} with the specified stream options. */ public ChatCompletionRequest withStreamOptions(StreamOptions streamOptions) { - return new ChatCompletionRequest(messages, model, frequencyPenalty, logitBias, logprobs, topLogprobs, maxTokens, n, presencePenalty, + return new ChatCompletionRequest(messages, model, frequencyPenalty, logitBias, logprobs, topLogprobs, maxTokens, maxCompletionTokens, n, presencePenalty, responseFormat, seed, stop, stream, streamOptions, temperature, topP, tools, toolChoice, parallelToolCalls, user); }