From 42dcb45f32373d7c978d4cb7baad0b708a91eb67 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Fri, 20 Sep 2024 17:09:28 -0400 Subject: [PATCH] Align AzureOpenAiChatOptions with Azure ChatCompletionsOptions Add missing options from Azure ChatCompletionsOptions to Spring AI AzureOpenAiChatOptions. The following fields have been added: - seed - logprobs - topLogprobs - enhancements This change ensures better alignment between the two option sets, improving compatibility and feature parity. Resolves https://github.com/spring-projects/spring-ai/issues/889 --- .../ai/azure/openai/AzureOpenAiChatModel.java | 45 ++++++++- .../azure/openai/AzureOpenAiChatOptions.java | 96 ++++++++++++++++++- .../AzureChatCompletionsOptionsTests.java | 28 +++++- 3 files changed, 166 insertions(+), 3 deletions(-) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index bffdbe2c74..5829d4e18a 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.ArrayList; @@ -92,6 +93,7 @@ * @author Thomas Vitale * @author luocongqiu * @author timostark + * @author Soby Chacko * @see ChatModel * @see com.azure.ai.openai.OpenAIClient */ @@ -456,6 +458,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel() : toSpringAiOptions.getDeploymentName()); + mergedAzureOptions + .setSeed(fromAzureOptions.getSeed() != null ? fromAzureOptions.getSeed() : toSpringAiOptions.getSeed()); + + mergedAzureOptions.setLogprobs((fromAzureOptions.isLogprobs() != null && fromAzureOptions.isLogprobs()) + || (toSpringAiOptions.isLogprobs() != null && toSpringAiOptions.isLogprobs())); + + mergedAzureOptions.setTopLogprobs(fromAzureOptions.getTopLogprobs() != null ? fromAzureOptions.getTopLogprobs() + : toSpringAiOptions.getTopLogProbs()); + + mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null + ? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements()); + return mergedAzureOptions; } @@ -520,6 +534,22 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, mergedAzureOptions.setResponseFormat(toAzureResponseFormat(fromSpringAiOptions.getResponseFormat())); } + if (fromSpringAiOptions.getSeed() != null) { + mergedAzureOptions.setSeed(fromSpringAiOptions.getSeed()); + } + + if (fromSpringAiOptions.isLogprobs() != null) { + mergedAzureOptions.setLogprobs(fromSpringAiOptions.isLogprobs()); + } + + if (fromSpringAiOptions.getTopLogProbs() != null) { + mergedAzureOptions.setTopLogprobs(fromSpringAiOptions.getTopLogProbs()); + } + + if (fromSpringAiOptions.getEnhancements() != null) { + mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements()); + } + return mergedAzureOptions; } @@ -566,6 +596,19 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { if (fromOptions.getResponseFormat() != null) { copyOptions.setResponseFormat(fromOptions.getResponseFormat()); } + if (fromOptions.getSeed() != null) { + copyOptions.setSeed(fromOptions.getSeed()); + } + + copyOptions.setLogprobs(fromOptions.isLogprobs()); + + if (fromOptions.getTopLogprobs() != null) { + copyOptions.setTopLogprobs(fromOptions.getTopLogprobs()); + } + + if (fromOptions.getEnhancements() != null) { + copyOptions.setEnhancements(fromOptions.getEnhancements()); + } return copyOptions; } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 6b85eeb966..5faa64ebb1 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.ArrayList; @@ -21,6 +22,7 @@ import java.util.Map; import java.util.Set; +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -40,6 +42,7 @@ * * @author Christian Tzolov * @author Thomas Vitale + * @author Soby Chacko */ @JsonInclude(Include.NON_NULL) public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions { @@ -165,6 +168,37 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio @JsonIgnore private Boolean proxyToolCalls; + /** + * Seed value for deterministic sampling such that the same seed and parameters return + * the same result. + */ + @JsonProperty(value = "seed") + private Long seed; + + /** + * Whether to return log probabilities of the output tokens or not. If true, returns + * the log probabilities of each output token returned in the `content` of `message`. + * This option is currently not available on the `gpt-4-vision-preview` model. + */ + @JsonProperty(value = "log_probs") + private Boolean logprobs; + + /* + * An integer between 0 and 5 specifying the number of most likely tokens to return at + * each token position, each with an associated log probability. `logprobs` must be + * set to `true` if this parameter is used. + */ + @JsonProperty(value = "top_log_probs") + private Integer topLogProbs; + + /* + * If provided, the configuration options for available Azure OpenAI chat + * enhancements. + */ + @NestedConfigurationProperty + @JsonIgnore + private AzureChatEnhancementConfiguration enhancements; + public static Builder builder() { return new Builder(); } @@ -259,6 +293,30 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) { return this; } + public Builder withSeed(Long seed) { + Assert.notNull(seed, "seed must not be null"); + this.options.seed = seed; + return this; + } + + public Builder withLogprobs(Boolean logprobs) { + Assert.notNull(logprobs, "logprobs must not be null"); + this.options.logprobs = logprobs; + return this; + } + + public Builder withTopLogprobs(Integer topLogprobs) { + Assert.notNull(topLogprobs, "topLogprobs must not be null"); + this.options.topLogProbs = topLogprobs; + return this; + } + + public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) { + Assert.notNull(enhancements, "enhancements must not be null"); + this.options.enhancements = enhancements; + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } @@ -404,6 +462,38 @@ public Integer getTopK() { return null; } + public Long getSeed() { + return this.seed; + } + + public void setSeed(Long seed) { + this.seed = seed; + } + + public Boolean isLogprobs() { + return this.logprobs; + } + + public void setLogprobs(Boolean logprobs) { + this.logprobs = logprobs; + } + + public Integer getTopLogProbs() { + return this.topLogProbs; + } + + public void setTopLogProbs(Integer topLogProbs) { + this.topLogProbs = topLogProbs; + } + + public AzureChatEnhancementConfiguration getEnhancements() { + return this.enhancements; + } + + public void setEnhancements(AzureChatEnhancementConfiguration enhancements) { + this.enhancements = enhancements; + } + @Override public Boolean getProxyToolCalls() { return this.proxyToolCalls; @@ -432,6 +522,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) .withResponseFormat(fromOptions.getResponseFormat()) + .withSeed(fromOptions.getSeed()) + .withLogprobs(fromOptions.isLogprobs()) + .withTopLogprobs(fromOptions.getTopLogProbs()) + .withEnhancements(fromOptions.getEnhancements()) .build(); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index 6e8d8bd531..f7edea989b 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; +import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; import org.junit.jupiter.api.Test; @@ -34,6 +37,7 @@ /** * @author Christian Tzolov + * @author Soby Chacko */ public class AzureChatCompletionsOptionsTests { @@ -42,6 +46,9 @@ public void createRequestWithChatOptions() { OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito + .mock(AzureChatEnhancementConfiguration.class); + var defaultOptions = AzureOpenAiChatOptions.builder() .withDeploymentName("DEFAULT_MODEL") .withTemperature(66.6) @@ -53,6 +60,10 @@ public void createRequestWithChatOptions() { .withStop(List.of("foo", "bar")) .withTopP(0.69) .withUser("user") + .withSeed(123L) + .withLogprobs(true) + .withTopLogprobs(5) + .withEnhancements(mockAzureChatEnhancementConfiguration) .withResponseFormat(AzureOpenAiResponseFormat.TEXT) .build(); @@ -72,8 +83,15 @@ public void createRequestWithChatOptions() { assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar")); assertThat(requestOptions.getTopP()).isEqualTo(0.69); assertThat(requestOptions.getUser()).isEqualTo("user"); + assertThat(requestOptions.getSeed()).isEqualTo(123L); + assertThat(requestOptions.isLogprobs()).isTrue(); + assertThat(requestOptions.getTopLogprobs()).isEqualTo(5); + assertThat(requestOptions.getEnhancements()).isEqualTo(mockAzureChatEnhancementConfiguration); assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsTextResponseFormat.class); + AzureChatEnhancementConfiguration anotherMockAzureChatEnhancementConfiguration = Mockito + .mock(AzureChatEnhancementConfiguration.class); + var runtimeOptions = AzureOpenAiChatOptions.builder() .withDeploymentName("PROMPT_MODEL") .withTemperature(99.9) @@ -85,6 +103,10 @@ public void createRequestWithChatOptions() { .withStop(List.of("foo", "bar")) .withTopP(0.111) .withUser("user2") + .withSeed(1234L) + .withLogprobs(true) + .withTopLogprobs(4) + .withEnhancements(anotherMockAzureChatEnhancementConfiguration) .withResponseFormat(AzureOpenAiResponseFormat.JSON) .build(); @@ -102,6 +124,10 @@ public void createRequestWithChatOptions() { assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar")); assertThat(requestOptions.getTopP()).isEqualTo(0.111); assertThat(requestOptions.getUser()).isEqualTo("user2"); + assertThat(requestOptions.getSeed()).isEqualTo(1234L); + assertThat(requestOptions.isLogprobs()).isTrue(); + assertThat(requestOptions.getTopLogprobs()).isEqualTo(4); + assertThat(requestOptions.getEnhancements()).isEqualTo(anotherMockAzureChatEnhancementConfiguration); assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class); }