Skip to content

Commit

Permalink
✨ support and identify Functions<I, Void> or Consumer<I> to avoid sec…
Browse files Browse the repository at this point in the history
…ond round trip
  • Loading branch information
Grogdunn committed Apr 30, 2024
1 parent 493e2ea commit d57914f
Show file tree
Hide file tree
Showing 19 changed files with 173 additions and 233 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,7 @@ public ChatResponse call(Prompt prompt) {

ChatCompletionRequest request = createRequest(prompt, false);
return this.retryTemplate.execute(ctx -> {
boolean completeRoundTrip = true;
if (prompt.getOptions() instanceof AnthropicChatOptions anthropicChatOptions) {
completeRoundTrip = anthropicChatOptions.isCompleteRoundTrip();
}
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request, completeRoundTrip);
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
return toChatResponse(completionEntity.getBody());
});
}
Expand Down Expand Up @@ -395,9 +391,10 @@ public ChatCompletion build() {
}

@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
RequestMessage responseMessage, List<RequestMessage> conversationHistory) {

protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
ChatCompletionRequest previousRequest, RequestMessage responseMessage,
List<RequestMessage> conversationHistory) {
boolean needCompleteRoundTrip = false;
List<MediaContent> toolToUseList = responseMessage.content()
.stream()
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
Expand All @@ -417,16 +414,19 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques

String functionResponse = this.functionCallbackRegister.get(functionName)
.call(ModelOptionsUtils.toJsonString(functionArguments));

toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
if (functionResponse != null) {
needCompleteRoundTrip = true;
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
}
}

// Add the function response to the conversation.
conversationHistory.add(new RequestMessage(toolResults, Role.USER));

// Recursively call chatCompletionWithTools until the model doesn't call a
// functions anymore.
return ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
final var build = ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
return new CompleteRoundTripBox<>(needCompleteRoundTrip, build);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
@JsonIgnore
private Set<String> functions = new HashSet<>();

@JsonIgnore
private boolean completeRoundTrip = true;
// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -140,11 +138,6 @@ public Builder withFunction(String functionName) {
return this;
}

public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
this.options.completeRoundTrip = completeRoundTrip;
return this;
}

public AnthropicChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -231,14 +224,4 @@ public void setFunctions(Set<String> functions) {
this.functions = functions;
}

@Override
public boolean isCompleteRoundTrip() {
return completeRoundTrip;
}

@Override
public void setCompleteRoundTrip(boolean completeRoundTrip) {
this.completeRoundTrip = completeRoundTrip;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
*/
package org.springframework.ai.azure.openai;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.*;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatChoice;
Expand Down Expand Up @@ -134,11 +131,7 @@ public ChatResponse call(Prompt prompt) {
options.setStream(false);

logger.trace("Azure ChatCompletionsOptions: {}", options);
boolean completeRoundTrip = true;
if (prompt.getOptions() instanceof AzureOpenAiChatOptions azureOpenAiChatOptions) {
completeRoundTrip = azureOpenAiChatOptions.isCompleteRoundTrip();
}
ChatCompletions chatCompletions = this.callWithFunctionSupport(options, completeRoundTrip);
ChatCompletions chatCompletions = this.callWithFunctionSupport(options);
logger.trace("Azure ChatCompletions: {}", chatCompletions);

List<Generation> generations = chatCompletions.getChoices()
Expand Down Expand Up @@ -431,9 +424,11 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCom
}

@Override
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseRequest(
ChatCompletionsOptions previousRequest, ChatRequestMessage responseMessage,
List<ChatRequestMessage> conversationHistory) {

boolean needCompleteRoundTrip = false;
// Every tool-call item requires a separate function call and a response (TOOL)
// message.
for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) {
Expand All @@ -447,8 +442,11 @@ protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOpti

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);

// Add the function response to the conversation.
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
if (functionResponse != null) {
needCompleteRoundTrip = true;
// Add the function response to the conversation.
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
}
}

// Recursively call chatCompletionWithTools until the model doesn't call a
Expand All @@ -457,7 +455,7 @@ protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOpti

newRequest = merge(previousRequest, newRequest);

return newRequest;
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,6 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
@JsonIgnore
private Set<String> functions = new HashSet<>();

@JsonIgnore
private boolean completeRoundTrip = true;

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -242,11 +239,6 @@ public Builder withFunction(String functionName) {
return this;
}

public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
this.options.completeRoundTrip = completeRoundTrip;
return this;
}

public AzureOpenAiChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -364,14 +356,4 @@ public void setFunctions(Set<String> functions) {
this.functions = functions;
}

@Override
public boolean isCompleteRoundTrip() {
return completeRoundTrip;
}

@Override
public void setCompleteRoundTrip(boolean completeRoundTrip) {
this.completeRoundTrip = completeRoundTrip;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
Expand Down Expand Up @@ -50,6 +52,8 @@ class AzureOpenAiChatClientFunctionCallIT {
@Autowired
private AzureOpenAiChatClient chatClient;

@Autowired
private String selectedModel;
@Test
void functionCallTest() {

Expand All @@ -58,7 +62,7 @@ void functionCallTest() {
List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("gpt-4-0125-preview")
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
Expand All @@ -84,13 +88,11 @@ void functionCallWithoutCompleteRoundTrip() {

final var spyingMockWeatherService = new SpyingMockWeatherService();
var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName("gpt-4-0125-preview")
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(spyingMockWeatherService)
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.withCompleteRoundTrip(false)
.build();

ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
Expand All @@ -111,12 +113,14 @@ public OpenAIClient openAIClient() {
}

@Bean
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient) {
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient, String selectedModel) {
return new AzureOpenAiChatClient(openAIClient,
AzureOpenAiChatOptions.builder()
.withDeploymentName("gpt-4-0125-preview")
.withMaxTokens(500)
.build());
AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build());
}

@Bean
public String selectedModel() {
return Optional.ofNullable(System.getenv("AZURE_OPENAI_MODEL")).orElse("gpt-4-0125-preview");
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@

import java.util.function.Function;

public class SpyingMockWeatherService implements Function<MockWeatherService.Request, MockWeatherService.Response> {

private final MockWeatherService inner = new MockWeatherService();
public class SpyingMockWeatherService implements Function<MockWeatherService.Request, Void> {

private MockWeatherService.Request interceptedRequest = null;

@Override
public MockWeatherService.Response apply(MockWeatherService.Request request) {
public Void apply(MockWeatherService.Request request) {
interceptedRequest = request;
return inner.apply(request);
return null;
}

public MockWeatherService.Request getInterceptedRequest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,7 @@ public ChatResponse call(Prompt prompt) {
var request = createRequest(prompt, false);

return retryTemplate.execute(ctx -> {
boolean completeRoundTrip = true;
if (prompt.getOptions() instanceof MistralAiChatOptions mistralAiChatOptions) {
completeRoundTrip = mistralAiChatOptions.isCompleteRoundTrip();
}
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request, completeRoundTrip);
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);

var chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
Expand Down Expand Up @@ -152,12 +148,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

return completionChunks.map(chunk -> toChatCompletion(chunk)).map(chatCompletion -> {
boolean completeRoundTrip = true;
if (prompt.getOptions() instanceof MistralAiChatOptions mistralAiChatOptions) {
completeRoundTrip = mistralAiChatOptions.isCompleteRoundTrip();
}
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)),
completeRoundTrip)
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)))
.getBody();

@SuppressWarnings("null")
Expand Down Expand Up @@ -255,9 +246,10 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
// Function Calling Support
//
@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {

protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage,
List<ChatCompletionMessage> conversationHistory) {
boolean needCompleteRoundTrip = false;
// Every tool-call item requires a separate function call and a response (TOOL)
// message.
for (ToolCall toolCall : responseMessage.toolCalls()) {
Expand All @@ -270,18 +262,21 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques
}

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
if (functionResponse != null) {
needCompleteRoundTrip = true;
// Add the function response to the conversation.
conversationHistory.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL,
functionName, null));
}

// Add the function response to the conversation.
conversationHistory
.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL, functionName, null));
}

// Recursively call chatCompletionWithTools until the model doesn't call a
// functions anymore.
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);

return newRequest;
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
@JsonIgnore
private Set<String> functions = new HashSet<>();

@JsonIgnore
private boolean completeRoundTrip = true;

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -199,11 +196,6 @@ public Builder withFunction(String functionName) {
return this;
}

public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
this.options.completeRoundTrip = completeRoundTrip;
return this;
}

public MistralAiChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -317,14 +309,4 @@ public void setFunctions(Set<String> functions) {
this.functions = functions;
}

@Override
public boolean isCompleteRoundTrip() {
return completeRoundTrip;
}

@Override
public void setCompleteRoundTrip(boolean completeRoundTrip) {
this.completeRoundTrip = completeRoundTrip;
}

}
Loading

0 comments on commit d57914f

Please sign in to comment.