Skip to content

Commit

Permalink
Align AzureOpenAiChatOptions with Azure ChatCompletionsOptions
Browse files Browse the repository at this point in the history
Add missing options from Azure ChatCompletionsOptions to Spring AI
AzureOpenAiChatOptions. The following fields have been added:

- seed
- logprobs
- topLogprobs
- enhancements

This change ensures better alignment between the two option sets,
improving compatibility and feature parity.

Resolves #889
  • Loading branch information
sobychacko authored and Mark Pollack committed Sep 24, 2024
1 parent 35e6113 commit 42dcb45
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import java.util.ArrayList;
Expand Down Expand Up @@ -92,6 +93,7 @@
* @author Thomas Vitale
* @author luocongqiu
* @author timostark
* @author Soby Chacko
* @see ChatModel
* @see com.azure.ai.openai.OpenAIClient
*/
Expand Down Expand Up @@ -456,6 +458,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel()
: toSpringAiOptions.getDeploymentName());

mergedAzureOptions
.setSeed(fromAzureOptions.getSeed() != null ? fromAzureOptions.getSeed() : toSpringAiOptions.getSeed());

mergedAzureOptions.setLogprobs((fromAzureOptions.isLogprobs() != null && fromAzureOptions.isLogprobs())
|| (toSpringAiOptions.isLogprobs() != null && toSpringAiOptions.isLogprobs()));

mergedAzureOptions.setTopLogprobs(fromAzureOptions.getTopLogprobs() != null ? fromAzureOptions.getTopLogprobs()
: toSpringAiOptions.getTopLogProbs());

mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null
? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements());

return mergedAzureOptions;
}

Expand Down Expand Up @@ -520,6 +534,22 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
mergedAzureOptions.setResponseFormat(toAzureResponseFormat(fromSpringAiOptions.getResponseFormat()));
}

if (fromSpringAiOptions.getSeed() != null) {
mergedAzureOptions.setSeed(fromSpringAiOptions.getSeed());
}

if (fromSpringAiOptions.isLogprobs() != null) {
mergedAzureOptions.setLogprobs(fromSpringAiOptions.isLogprobs());
}

if (fromSpringAiOptions.getTopLogProbs() != null) {
mergedAzureOptions.setTopLogprobs(fromSpringAiOptions.getTopLogProbs());
}

if (fromSpringAiOptions.getEnhancements() != null) {
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
}

return mergedAzureOptions;
}

Expand Down Expand Up @@ -566,6 +596,19 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
if (fromOptions.getResponseFormat() != null) {
copyOptions.setResponseFormat(fromOptions.getResponseFormat());
}
if (fromOptions.getSeed() != null) {
copyOptions.setSeed(fromOptions.getSeed());
}

copyOptions.setLogprobs(fromOptions.isLogprobs());

if (fromOptions.getTopLogprobs() != null) {
copyOptions.setTopLogprobs(fromOptions.getTopLogprobs());
}

if (fromOptions.getEnhancements() != null) {
copyOptions.setEnhancements(fromOptions.getEnhancements());
}

return copyOptions;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import java.util.ArrayList;
Expand All @@ -21,6 +22,7 @@
import java.util.Map;
import java.util.Set;

import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
Expand All @@ -40,6 +42,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Soby Chacko
*/
@JsonInclude(Include.NON_NULL)
public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
Expand Down Expand Up @@ -165,6 +168,37 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
@JsonIgnore
private Boolean proxyToolCalls;

/**
* Seed value for deterministic sampling such that the same seed and parameters return
* the same result.
*/
@JsonProperty(value = "seed")
private Long seed;

/**
* Whether to return log probabilities of the output tokens or not. If true, returns
* the log probabilities of each output token returned in the `content` of `message`.
* This option is currently not available on the `gpt-4-vision-preview` model.
*/
@JsonProperty(value = "log_probs")
private Boolean logprobs;

/*
* An integer between 0 and 5 specifying the number of most likely tokens to return at
* each token position, each with an associated log probability. `logprobs` must be
* set to `true` if this parameter is used.
*/
@JsonProperty(value = "top_log_probs")
private Integer topLogProbs;

/*
* If provided, the configuration options for available Azure OpenAI chat
* enhancements.
*/
@NestedConfigurationProperty
@JsonIgnore
private AzureChatEnhancementConfiguration enhancements;

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -259,6 +293,30 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) {
return this;
}

public Builder withSeed(Long seed) {
Assert.notNull(seed, "seed must not be null");
this.options.seed = seed;
return this;
}

public Builder withLogprobs(Boolean logprobs) {
Assert.notNull(logprobs, "logprobs must not be null");
this.options.logprobs = logprobs;
return this;
}

public Builder withTopLogprobs(Integer topLogprobs) {
Assert.notNull(topLogprobs, "topLogprobs must not be null");
this.options.topLogProbs = topLogprobs;
return this;
}

public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) {
Assert.notNull(enhancements, "enhancements must not be null");
this.options.enhancements = enhancements;
return this;
}

public AzureOpenAiChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -404,6 +462,38 @@ public Integer getTopK() {
return null;
}

public Long getSeed() {
return this.seed;
}

public void setSeed(Long seed) {
this.seed = seed;
}

public Boolean isLogprobs() {
return this.logprobs;
}

public void setLogprobs(Boolean logprobs) {
this.logprobs = logprobs;
}

public Integer getTopLogProbs() {
return this.topLogProbs;
}

public void setTopLogProbs(Integer topLogProbs) {
this.topLogProbs = topLogProbs;
}

public AzureChatEnhancementConfiguration getEnhancements() {
return this.enhancements;
}

public void setEnhancements(AzureChatEnhancementConfiguration enhancements) {
this.enhancements = enhancements;
}

@Override
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
Expand Down Expand Up @@ -432,6 +522,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.withResponseFormat(fromOptions.getResponseFormat())
.withSeed(fromOptions.getSeed())
.withLogprobs(fromOptions.isLogprobs())
.withTopLogprobs(fromOptions.getTopLogProbs())
.withEnhancements(fromOptions.getEnhancements())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
import org.junit.jupiter.api.Test;
Expand All @@ -34,6 +37,7 @@

/**
* @author Christian Tzolov
* @author Soby Chacko
*/
public class AzureChatCompletionsOptionsTests {

Expand All @@ -42,6 +46,9 @@ public void createRequestWithChatOptions() {

OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);

AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito
.mock(AzureChatEnhancementConfiguration.class);

var defaultOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("DEFAULT_MODEL")
.withTemperature(66.6)
Expand All @@ -53,6 +60,10 @@ public void createRequestWithChatOptions() {
.withStop(List.of("foo", "bar"))
.withTopP(0.69)
.withUser("user")
.withSeed(123L)
.withLogprobs(true)
.withTopLogprobs(5)
.withEnhancements(mockAzureChatEnhancementConfiguration)
.withResponseFormat(AzureOpenAiResponseFormat.TEXT)
.build();

Expand All @@ -72,8 +83,15 @@ public void createRequestWithChatOptions() {
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
assertThat(requestOptions.getTopP()).isEqualTo(0.69);
assertThat(requestOptions.getUser()).isEqualTo("user");
assertThat(requestOptions.getSeed()).isEqualTo(123L);
assertThat(requestOptions.isLogprobs()).isTrue();
assertThat(requestOptions.getTopLogprobs()).isEqualTo(5);
assertThat(requestOptions.getEnhancements()).isEqualTo(mockAzureChatEnhancementConfiguration);
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsTextResponseFormat.class);

AzureChatEnhancementConfiguration anotherMockAzureChatEnhancementConfiguration = Mockito
.mock(AzureChatEnhancementConfiguration.class);

var runtimeOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("PROMPT_MODEL")
.withTemperature(99.9)
Expand All @@ -85,6 +103,10 @@ public void createRequestWithChatOptions() {
.withStop(List.of("foo", "bar"))
.withTopP(0.111)
.withUser("user2")
.withSeed(1234L)
.withLogprobs(true)
.withTopLogprobs(4)
.withEnhancements(anotherMockAzureChatEnhancementConfiguration)
.withResponseFormat(AzureOpenAiResponseFormat.JSON)
.build();

Expand All @@ -102,6 +124,10 @@ public void createRequestWithChatOptions() {
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
assertThat(requestOptions.getTopP()).isEqualTo(0.111);
assertThat(requestOptions.getUser()).isEqualTo("user2");
assertThat(requestOptions.getSeed()).isEqualTo(1234L);
assertThat(requestOptions.isLogprobs()).isTrue();
assertThat(requestOptions.getTopLogprobs()).isEqualTo(4);
assertThat(requestOptions.getEnhancements()).isEqualTo(anotherMockAzureChatEnhancementConfiguration);
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class);
}

Expand Down

0 comments on commit 42dcb45

Please sign in to comment.