Use Double instead of Float for portable ChatOptions
This change updates the type of portable chat options from Float to
Double. Affected options include:
- frequencyPenalty
- presencePenalty
- temperature
- topP

The motivation for this change is to simplify coding. In Java, Float/float
literals require an "f" suffix (e.g., 0.5f), while double literals need no
suffix. This makes Double easier to type and removes a potential source of
errors from forgetting the "f" suffix.

APIs, tests, and documentation have been updated to reflect this
change.

Fixes gh-712

Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
ThomasVitale authored and Mark Pollack committed Sep 16, 2024
1 parent 40714c9 commit 4b123a7
Showing 146 changed files with 731 additions and 717 deletions.
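
For illustration, a minimal sketch of the new usage (a sketch only; it assumes the AnthropicChatOptions builder API shown in the diff below, with the class in its usual org.springframework.ai.anthropic package):

import org.springframework.ai.anthropic.AnthropicChatOptions;

class ChatOptionsLiteralExample {

    public static void main(String[] args) {
        // With the Double-typed options introduced by this commit, plain decimal
        // literals compile as-is; the previous Float-typed signatures required the
        // "f" suffix (e.g. withTemperature(0.7f)), which was easy to forget.
        var options = AnthropicChatOptions.builder()
            .withTemperature(0.7)
            .withTopP(1.0)
            .build();

        System.out.println(options.getTemperature()); // 0.7
        System.out.println(options.getTopP());        // 1.0
    }

}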
@@ -82,7 +82,7 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM

public static final Integer DEFAULT_MAX_TOKENS = 500;

- public static final Float DEFAULT_TEMPERATURE = 0.8f;
+ public static final Double DEFAULT_TEMPERATURE = 0.8;

/**
* The lower-level API for the Anthropic service.
@@ -48,8 +48,8 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
private @JsonProperty("max_tokens") Integer maxTokens;
private @JsonProperty("metadata") ChatCompletionRequest.Metadata metadata;
private @JsonProperty("stop_sequences") List<String> stopSequences;
- private @JsonProperty("temperature") Float temperature;
- private @JsonProperty("top_p") Float topP;
+ private @JsonProperty("temperature") Double temperature;
+ private @JsonProperty("top_p") Double topP;
private @JsonProperty("top_k") Integer topK;

/**
@@ -112,12 +112,12 @@ public Builder withStopSequences(List<String> stopSequences) {
return this;
}

- public Builder withTemperature(Float temperature) {
+ public Builder withTemperature(Double temperature) {
this.options.temperature = temperature;
return this;
}

- public Builder withTopP(Float topP) {
+ public Builder withTopP(Double topP) {
this.options.topP = topP;
return this;
}
@@ -186,20 +186,20 @@ public void setStopSequences(List<String> stopSequences) {
}

@Override
- public Float getTemperature() {
+ public Double getTemperature() {
return this.temperature;
}

- public void setTemperature(Float temperature) {
+ public void setTemperature(Double temperature) {
this.temperature = temperature;
}

@Override
- public Float getTopP() {
+ public Double getTopP() {
return this.topP;
}

- public void setTopP(Float topP) {
+ public void setTopP(Double topP) {
this.topP = topP;
}

@@ -236,13 +236,13 @@ public void setFunctions(Set<String> functions) {

@Override
@JsonIgnore
- public Float getFrequencyPenalty() {
+ public Double getFrequencyPenalty() {
return null;
}

@Override
@JsonIgnore
- public Float getPresencePenalty() {
+ public Double getPresencePenalty() {
return null;
}

@@ -241,19 +241,19 @@ public record ChatCompletionRequest( // @formatter:off
@JsonProperty("metadata") Metadata metadata,
@JsonProperty("stop_sequences") List<String> stopSequences,
@JsonProperty("stream") Boolean stream,
- @JsonProperty("temperature") Float temperature,
- @JsonProperty("top_p") Float topP,
+ @JsonProperty("temperature") Double temperature,
+ @JsonProperty("top_p") Double topP,
@JsonProperty("top_k") Integer topK,
@JsonProperty("tools") List<Tool> tools) {
// @formatter:on

public ChatCompletionRequest(String model, List<AnthropicMessage> messages, String system, Integer maxTokens,
- Float temperature, Boolean stream) {
+ Double temperature, Boolean stream) {
this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null);
}

public ChatCompletionRequest(String model, List<AnthropicMessage> messages, String system, Integer maxTokens,
- List<String> stopSequences, Float temperature, Boolean stream) {
+ List<String> stopSequences, Double temperature, Boolean stream) {
this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null);
}

@@ -292,9 +292,9 @@ public static class ChatCompletionRequestBuilder {

private Boolean stream = false;

- private Float temperature;
+ private Double temperature;

- private Float topP;
+ private Double topP;

private Integer topK;

@@ -357,12 +357,12 @@ public ChatCompletionRequestBuilder withStream(Boolean stream) {
return this;
}

- public ChatCompletionRequestBuilder withTemperature(Float temperature) {
+ public ChatCompletionRequestBuilder withTemperature(Double temperature) {
this.temperature = temperature;
return this;
}

- public ChatCompletionRequestBuilder withTopP(Float topP) {
+ public ChatCompletionRequestBuilder withTopP(Double topP) {
this.topP = topP;
return this;
}
@@ -120,7 +120,7 @@ void testMessageHistory() {

@Test
void streamingWithTokenUsage() {
- var promptOptions = AnthropicChatOptions.builder().withTemperature(0f).build();
+ var promptOptions = AnthropicChatOptions.builder().withTemperature(0.0).build();

var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
@@ -71,9 +71,9 @@ void observationForChatOperation() {
.withModel(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue())
.withMaxTokens(2048)
.withStopSequences(List.of("this-is-the-end"))
- .withTemperature(0.7f)
+ .withTemperature(0.7)
.withTopK(1)
- .withTopP(1f)
+ .withTopP(1.0)
.build();

Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -93,9 +93,9 @@ void observationForStreamingChatOperation() {
.withModel(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue())
.withMaxTokens(2048)
.withStopSequences(List.of("this-is-the-end"))
- .withTemperature(0.7f)
+ .withTemperature(0.7)
.withTopK(1)
- .withTopP(1f)
+ .withTopP(1.0)
.build();

Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -31,24 +31,24 @@ public class ChatCompletionRequestTests {
public void createRequestWithChatOptions() {

var client = new AnthropicChatModel(new AnthropicApi("TEST"),
- AnthropicChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
+ AnthropicChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build());

var request = client.createRequest(new Prompt("Test message content"), false);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isFalse();

assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
- assertThat(request.temperature()).isEqualTo(66.6f);
+ assertThat(request.temperature()).isEqualTo(66.6);

request = client.createRequest(new Prompt("Test message content",
- AnthropicChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9f).build()), true);
+ AnthropicChatOptions.builder().withModel("PROMPT_MODEL").withTemperature(99.9).build()), true);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isTrue();

assertThat(request.model()).isEqualTo("PROMPT_MODEL");
- assertThat(request.temperature()).isEqualTo(99.9f);
+ assertThat(request.temperature()).isEqualTo(99.9);
}

}
@@ -45,7 +45,7 @@ void chatCompletionEntity() {
Role.USER);
ResponseEntity<ChatCompletionResponse> response = anthropicApi
.chatCompletionEntity(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(),
- List.of(chatCompletionMessage), null, 100, 0.8f, false));
+ List.of(chatCompletionMessage), null, 100, 0.8, false));

System.out.println(response);
assertThat(response).isNotNull();
@@ -58,9 +58,8 @@ void chatCompletionStream() {
AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?")),
Role.USER);

- Flux<ChatCompletionResponse> response = anthropicApi
- .chatCompletionStream(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(),
- List.of(chatCompletionMessage), null, 100, 0.8f, true));
+ Flux<ChatCompletionResponse> response = anthropicApi.chatCompletionStream(new ChatCompletionRequest(
+ AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, true));

assertThat(response).isNotNull();

@@ -107,8 +107,8 @@ void toolCalls() {
Role.USER);

ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(
- AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), systemPrompt, 500,
- 0.8f, false);
+ AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(), List.of(chatCompletionMessage), systemPrompt, 500, 0.8,
+ false);

ResponseEntity<ChatCompletionResponse> chatCompletion = doCall(chatCompletionRequest);

@@ -147,7 +147,7 @@ private ResponseEntity<ChatCompletionResponse> doCall(ChatCompletionRequest chat
AnthropicMessage chatCompletionMessage2 = new AnthropicMessage(List.of(new ContentBlock(content)), Role.USER);

return doCall(new ChatCompletionRequest(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue(),
- List.of(chatCompletionMessage2), null, 500, 0.8f, false));
+ List.of(chatCompletionMessage2), null, 500, 0.8, false));
}

}
@@ -108,7 +108,7 @@ private ResponseEntity<ChatCompletionResponse> doCall(List<AnthropicMessage> mes
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS)
.withMessages(messageConversation)
.withMaxTokens(1500)
- .withTemperature(0.8f)
+ .withTemperature(0.8)
.withTools(tools)
.build();

@@ -89,6 +89,7 @@
* @author Christian Tzolov
* @author Grogdunn
* @author Benoit Moussaud
+ * @author Thomas Vitale
* @author luocongqiu
* @author timostark
* @see ChatModel
@@ -98,7 +99,7 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha

private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-4o";

- private static final Float DEFAULT_TEMPERATURE = 0.7f;
+ private static final Double DEFAULT_TEMPERATURE = 0.7;

/**
* The {@link OpenAIClient} used to interact with the Azure OpenAI service.
@@ -422,22 +423,22 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,

mergedAzureOptions.setTemperature(fromAzureOptions.getTemperature());
if (mergedAzureOptions.getTemperature() == null && toSpringAiOptions.getTemperature() != null) {
- mergedAzureOptions.setTemperature(toSpringAiOptions.getTemperature().doubleValue());
+ mergedAzureOptions.setTemperature(toSpringAiOptions.getTemperature());
}

mergedAzureOptions.setTopP(fromAzureOptions.getTopP());
if (mergedAzureOptions.getTopP() == null && toSpringAiOptions.getTopP() != null) {
- mergedAzureOptions.setTopP(toSpringAiOptions.getTopP().doubleValue());
+ mergedAzureOptions.setTopP(toSpringAiOptions.getTopP());
}

mergedAzureOptions.setFrequencyPenalty(fromAzureOptions.getFrequencyPenalty());
if (mergedAzureOptions.getFrequencyPenalty() == null && toSpringAiOptions.getFrequencyPenalty() != null) {
- mergedAzureOptions.setFrequencyPenalty(toSpringAiOptions.getFrequencyPenalty().doubleValue());
+ mergedAzureOptions.setFrequencyPenalty(toSpringAiOptions.getFrequencyPenalty());
}

mergedAzureOptions.setPresencePenalty(fromAzureOptions.getPresencePenalty());
if (mergedAzureOptions.getPresencePenalty() == null && toSpringAiOptions.getPresencePenalty() != null) {
- mergedAzureOptions.setPresencePenalty(toSpringAiOptions.getPresencePenalty().doubleValue());
+ mergedAzureOptions.setPresencePenalty(toSpringAiOptions.getPresencePenalty());
}

mergedAzureOptions.setResponseFormat(fromAzureOptions.getResponseFormat());
@@ -486,19 +487,19 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
}

if (fromSpringAiOptions.getTemperature() != null) {
- mergedAzureOptions.setTemperature(fromSpringAiOptions.getTemperature().doubleValue());
+ mergedAzureOptions.setTemperature(fromSpringAiOptions.getTemperature());
}

if (fromSpringAiOptions.getTopP() != null) {
- mergedAzureOptions.setTopP(fromSpringAiOptions.getTopP().doubleValue());
+ mergedAzureOptions.setTopP(fromSpringAiOptions.getTopP());
}

if (fromSpringAiOptions.getFrequencyPenalty() != null) {
- mergedAzureOptions.setFrequencyPenalty(fromSpringAiOptions.getFrequencyPenalty().doubleValue());
+ mergedAzureOptions.setFrequencyPenalty(fromSpringAiOptions.getFrequencyPenalty());
}

if (fromSpringAiOptions.getPresencePenalty() != null) {
- mergedAzureOptions.setPresencePenalty(fromSpringAiOptions.getPresencePenalty().doubleValue());
+ mergedAzureOptions.setPresencePenalty(fromSpringAiOptions.getPresencePenalty());
}

if (fromSpringAiOptions.getN() != null) {