Skip to content

Commit

Permalink
Merge branch 'spring-projects:main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhangbin committed Sep 25, 2024
2 parents 1589d56 + c5f07e5 commit 36d62c6
Show file tree
Hide file tree
Showing 16 changed files with 409 additions and 90 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import java.util.ArrayList;
Expand Down Expand Up @@ -92,6 +93,7 @@
* @author Thomas Vitale
* @author luocongqiu
* @author timostark
* @author Soby Chacko
* @see ChatModel
* @see com.azure.ai.openai.OpenAIClient
*/
Expand Down Expand Up @@ -456,6 +458,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel()
: toSpringAiOptions.getDeploymentName());

mergedAzureOptions
.setSeed(fromAzureOptions.getSeed() != null ? fromAzureOptions.getSeed() : toSpringAiOptions.getSeed());

mergedAzureOptions.setLogprobs((fromAzureOptions.isLogprobs() != null && fromAzureOptions.isLogprobs())
|| (toSpringAiOptions.isLogprobs() != null && toSpringAiOptions.isLogprobs()));

mergedAzureOptions.setTopLogprobs(fromAzureOptions.getTopLogprobs() != null ? fromAzureOptions.getTopLogprobs()
: toSpringAiOptions.getTopLogProbs());

mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null
? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements());

return mergedAzureOptions;
}

Expand Down Expand Up @@ -520,6 +534,22 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
mergedAzureOptions.setResponseFormat(toAzureResponseFormat(fromSpringAiOptions.getResponseFormat()));
}

if (fromSpringAiOptions.getSeed() != null) {
mergedAzureOptions.setSeed(fromSpringAiOptions.getSeed());
}

if (fromSpringAiOptions.isLogprobs() != null) {
mergedAzureOptions.setLogprobs(fromSpringAiOptions.isLogprobs());
}

if (fromSpringAiOptions.getTopLogProbs() != null) {
mergedAzureOptions.setTopLogprobs(fromSpringAiOptions.getTopLogProbs());
}

if (fromSpringAiOptions.getEnhancements() != null) {
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
}

return mergedAzureOptions;
}

Expand Down Expand Up @@ -566,6 +596,19 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
if (fromOptions.getResponseFormat() != null) {
copyOptions.setResponseFormat(fromOptions.getResponseFormat());
}
if (fromOptions.getSeed() != null) {
copyOptions.setSeed(fromOptions.getSeed());
}

copyOptions.setLogprobs(fromOptions.isLogprobs());

if (fromOptions.getTopLogprobs() != null) {
copyOptions.setTopLogprobs(fromOptions.getTopLogprobs());
}

if (fromOptions.getEnhancements() != null) {
copyOptions.setEnhancements(fromOptions.getEnhancements());
}

return copyOptions;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import java.util.ArrayList;
Expand All @@ -21,6 +22,7 @@
import java.util.Map;
import java.util.Set;

import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
Expand All @@ -40,6 +42,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Soby Chacko
*/
@JsonInclude(Include.NON_NULL)
public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
Expand Down Expand Up @@ -165,6 +168,37 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
@JsonIgnore
private Boolean proxyToolCalls;

/**
* Seed value for deterministic sampling such that the same seed and parameters return
* the same result.
*/
@JsonProperty(value = "seed")
private Long seed;

/**
* Whether to return log probabilities of the output tokens or not. If true, returns
* the log probabilities of each output token returned in the `content` of `message`.
* This option is currently not available on the `gpt-4-vision-preview` model.
*/
@JsonProperty(value = "log_probs")
private Boolean logprobs;

/**
 * An integer between 0 and 5 specifying the number of most likely tokens to return at
 * each token position, each with an associated log probability. `logprobs` must be
 * set to `true` if this parameter is used.
 */
@JsonProperty(value = "top_log_probs")
private Integer topLogProbs;

/**
 * If provided, the configuration options for available Azure OpenAI chat
 * enhancements.
 */
@NestedConfigurationProperty
@JsonIgnore
private AzureChatEnhancementConfiguration enhancements;

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -259,6 +293,30 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) {
return this;
}

/**
 * Sets the seed used for deterministic sampling.
 * @param seed the seed value; may be {@code null} to leave the option unset
 * @return this builder
 */
public Builder withSeed(Long seed) {
	// No null check: fromOptions(..) copies every field through this builder,
	// and an unset seed must copy through as null instead of throwing.
	this.options.seed = seed;
	return this;
}

/**
 * Sets whether to return log probabilities of the output tokens.
 * @param logprobs {@code true} to request log probabilities; may be {@code null} to
 * leave the option unset
 * @return this builder
 */
public Builder withLogprobs(Boolean logprobs) {
	// No null check: fromOptions(..) copies every field through this builder,
	// and an unset flag must copy through as null instead of throwing.
	this.options.logprobs = logprobs;
	return this;
}

/**
 * Sets the number of most likely tokens (0-5) to return at each token position.
 * Requires {@code logprobs} to be enabled when used.
 * @param topLogprobs the count; may be {@code null} to leave the option unset
 * @return this builder
 */
public Builder withTopLogprobs(Integer topLogprobs) {
	// No null check: fromOptions(..) copies every field through this builder,
	// and an unset value must copy through as null instead of throwing.
	this.options.topLogProbs = topLogprobs;
	return this;
}

/**
 * Sets the configuration options for Azure OpenAI chat enhancements.
 * @param enhancements the enhancement configuration; may be {@code null} to leave
 * the option unset
 * @return this builder
 */
public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) {
	// No null check: fromOptions(..) copies every field through this builder,
	// and an absent configuration must copy through as null instead of throwing.
	this.options.enhancements = enhancements;
	return this;
}

public AzureOpenAiChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -404,6 +462,38 @@ public Integer getTopK() {
return null;
}

/**
 * Returns the seed used for deterministic sampling, or {@code null} if unset.
 */
public Long getSeed() {
return this.seed;
}

/**
 * Sets the seed used for deterministic sampling; {@code null} leaves it unset.
 */
public void setSeed(Long seed) {
this.seed = seed;
}

/**
 * Returns whether output-token log probabilities are requested, or {@code null} if
 * unset.
 */
public Boolean isLogprobs() {
return this.logprobs;
}

/**
 * Sets whether output-token log probabilities are requested; {@code null} leaves it
 * unset.
 */
public void setLogprobs(Boolean logprobs) {
this.logprobs = logprobs;
}

/**
 * Returns the number of most likely tokens to return per position, or {@code null}
 * if unset.
 */
public Integer getTopLogProbs() {
return this.topLogProbs;
}

/**
 * Sets the number of most likely tokens to return per position; {@code null} leaves
 * it unset.
 */
public void setTopLogProbs(Integer topLogProbs) {
this.topLogProbs = topLogProbs;
}

/**
 * Returns the Azure OpenAI chat enhancement configuration, or {@code null} if unset.
 */
public AzureChatEnhancementConfiguration getEnhancements() {
return this.enhancements;
}

/**
 * Sets the Azure OpenAI chat enhancement configuration; {@code null} leaves it
 * unset.
 */
public void setEnhancements(AzureChatEnhancementConfiguration enhancements) {
this.enhancements = enhancements;
}

@Override
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
Expand Down Expand Up @@ -432,6 +522,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.withResponseFormat(fromOptions.getResponseFormat())
.withSeed(fromOptions.getSeed())
.withLogprobs(fromOptions.isLogprobs())
.withTopLogprobs(fromOptions.getTopLogProbs())
.withEnhancements(fromOptions.getEnhancements())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,9 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
import org.junit.jupiter.api.Test;
Expand All @@ -34,6 +37,7 @@

/**
* @author Christian Tzolov
* @author Soby Chacko
*/
public class AzureChatCompletionsOptionsTests {

Expand All @@ -42,6 +46,9 @@ public void createRequestWithChatOptions() {

OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);

AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito
.mock(AzureChatEnhancementConfiguration.class);

var defaultOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("DEFAULT_MODEL")
.withTemperature(66.6)
Expand All @@ -53,6 +60,10 @@ public void createRequestWithChatOptions() {
.withStop(List.of("foo", "bar"))
.withTopP(0.69)
.withUser("user")
.withSeed(123L)
.withLogprobs(true)
.withTopLogprobs(5)
.withEnhancements(mockAzureChatEnhancementConfiguration)
.withResponseFormat(AzureOpenAiResponseFormat.TEXT)
.build();

Expand All @@ -72,8 +83,15 @@ public void createRequestWithChatOptions() {
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
assertThat(requestOptions.getTopP()).isEqualTo(0.69);
assertThat(requestOptions.getUser()).isEqualTo("user");
assertThat(requestOptions.getSeed()).isEqualTo(123L);
assertThat(requestOptions.isLogprobs()).isTrue();
assertThat(requestOptions.getTopLogprobs()).isEqualTo(5);
assertThat(requestOptions.getEnhancements()).isEqualTo(mockAzureChatEnhancementConfiguration);
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsTextResponseFormat.class);

AzureChatEnhancementConfiguration anotherMockAzureChatEnhancementConfiguration = Mockito
.mock(AzureChatEnhancementConfiguration.class);

var runtimeOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("PROMPT_MODEL")
.withTemperature(99.9)
Expand All @@ -85,6 +103,10 @@ public void createRequestWithChatOptions() {
.withStop(List.of("foo", "bar"))
.withTopP(0.111)
.withUser("user2")
.withSeed(1234L)
.withLogprobs(true)
.withTopLogprobs(4)
.withEnhancements(anotherMockAzureChatEnhancementConfiguration)
.withResponseFormat(AzureOpenAiResponseFormat.JSON)
.build();

Expand All @@ -102,6 +124,10 @@ public void createRequestWithChatOptions() {
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
assertThat(requestOptions.getTopP()).isEqualTo(0.111);
assertThat(requestOptions.getUser()).isEqualTo("user2");
assertThat(requestOptions.getSeed()).isEqualTo(1234L);
assertThat(requestOptions.isLogprobs()).isTrue();
assertThat(requestOptions.getTopLogprobs()).isEqualTo(4);
assertThat(requestOptions.getEnhancements()).isEqualTo(anotherMockAzureChatEnhancementConfiguration);
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -936,12 +936,28 @@ public record TopLogProbs(// @formatter:off
* @param promptTokens Number of tokens in the prompt.
* @param totalTokens Total number of tokens used in the request (prompt +
* completion).
* @param completionTokenDetails Breakdown of tokens used in a completion
*/
@JsonInclude(Include.NON_NULL)
public record Usage(// @formatter:off
@JsonProperty("completion_tokens") Integer completionTokens,
@JsonProperty("prompt_tokens") Integer promptTokens,
@JsonProperty("total_tokens") Integer totalTokens) {// @formatter:on
@JsonProperty("total_tokens") Integer totalTokens,
@JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {// @formatter:on

public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
this(completionTokens, promptTokens, totalTokens, null);
}

/**
* Breakdown of tokens used in a completion
*
* @param reasoningTokens Number of tokens generated by the model for reasoning.
*/
@JsonInclude(Include.NON_NULL)
public record CompletionTokenDetails(// @formatter:off
@JsonProperty("reasoning_tokens") Integer reasoningTokens) {// @formatter:on
}

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ public Long getGenerationTokens() {
return generationTokens != null ? generationTokens.longValue() : 0;
}

/**
 * Returns the number of reasoning tokens reported in the completion-token
 * breakdown, or {@code 0} when the breakdown or the count is absent.
 */
public Long getReasoningTokens() {
	OpenAiApi.Usage.CompletionTokenDetails details = getUsage().completionTokenDetails();
	if (details == null) {
		return 0L;
	}
	Integer reasoning = details.reasoningTokens();
	return (reasoning != null) ? reasoning.longValue() : 0L;
}

@Override
public Long getTotalTokens() {
Integer totalTokens = getUsage().totalTokens();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,28 @@ void whenTotalTokensIsNull() {
assertThat(usage.getTotalTokens()).isEqualTo(300);
}

@Test
void whenCompletionTokenDetailsIsNull() {
	// No completion-token breakdown at all: totals pass through, reasoning is 0.
	var apiUsage = new OpenAiApi.Usage(100, 200, 300, null);
	var mappedUsage = OpenAiUsage.from(apiUsage);
	assertThat(mappedUsage.getTotalTokens()).isEqualTo(300);
	assertThat(mappedUsage.getReasoningTokens()).isEqualTo(0);
}

@Test
void whenReasoningTokensIsNull() {
	// Breakdown present but reasoning count absent: reported as 0.
	var details = new OpenAiApi.Usage.CompletionTokenDetails(null);
	var apiUsage = new OpenAiApi.Usage(100, 200, 300, details);
	var mappedUsage = OpenAiUsage.from(apiUsage);
	assertThat(mappedUsage.getReasoningTokens()).isEqualTo(0);
}

@Test
void whenCompletionTokenDetailsIsPresent() {
	// Breakdown with a reasoning count: the count passes through unchanged.
	var details = new OpenAiApi.Usage.CompletionTokenDetails(50);
	var apiUsage = new OpenAiApi.Usage(100, 200, 300, details);
	var mappedUsage = OpenAiUsage.from(apiUsage);
	assertThat(mappedUsage.getReasoningTokens()).isEqualTo(50);
}

}
Loading

0 comments on commit 36d62c6

Please sign in to comment.