Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion core/src/main/java/com/google/adk/agents/RunConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.adk.agents;

import com.google.adk.tools.MissingToolResolutionStrategy;
import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
Expand Down Expand Up @@ -70,6 +71,8 @@ public enum ToolExecutionMode {

public abstract int maxLlmCalls();

public abstract MissingToolResolutionStrategy missingToolResolutionStrategy();

public abstract Builder toBuilder();

public static Builder builder() {
Expand All @@ -78,6 +81,7 @@ public static Builder builder() {
.setResponseModalities(ImmutableList.of())
.setStreamingMode(StreamingMode.NONE)
.setToolExecutionMode(ToolExecutionMode.NONE)
.setMissingToolResolutionStrategy(MissingToolResolutionStrategy.THROW_EXCEPTION)
.setMaxLlmCalls(500);
}

Expand All @@ -90,7 +94,8 @@ public static Builder builder(RunConfig runConfig) {
.setResponseModalities(runConfig.responseModalities())
.setSpeechConfig(runConfig.speechConfig())
.setOutputAudioTranscription(runConfig.outputAudioTranscription())
.setInputAudioTranscription(runConfig.inputAudioTranscription());
.setInputAudioTranscription(runConfig.inputAudioTranscription())
.setMissingToolResolutionStrategy(runConfig.missingToolResolutionStrategy());
}

/** Builder for {@link RunConfig}. */
Expand Down Expand Up @@ -123,6 +128,10 @@ public abstract Builder setInputAudioTranscription(
@CanIgnoreReturnValue
public abstract Builder setMaxLlmCalls(int maxLlmCalls);

@CanIgnoreReturnValue
public abstract Builder setMissingToolResolutionStrategy(
MissingToolResolutionStrategy missingToolResolutionStrategy);

abstract RunConfig autoBuild();

public RunConfig build() {
Expand Down
43 changes: 31 additions & 12 deletions core/src/main/java/com/google/adk/flows/llmflows/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import com.google.adk.events.EventActions;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.FunctionTool;
import com.google.adk.tools.MissingToolResolutionStrategy;
import com.google.adk.tools.ToolConfirmation;
import com.google.adk.tools.ToolContext;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Content;
Expand Down Expand Up @@ -137,10 +137,16 @@ public static Maybe<Event> handleFunctionCalls(
Map<String, BaseTool> tools,
Map<String, ToolConfirmation> toolConfirmations) {
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();

MissingToolResolutionStrategy missingToolResolutionStrategy =
invocationContext.runConfig().missingToolResolutionStrategy();
ImmutableList.Builder<Maybe<Event>> missingTools = ImmutableList.builder();
ImmutableList.Builder<FunctionCall> validCalls = ImmutableList.builder();
for (FunctionCall functionCall : functionCalls) {
if (!tools.containsKey(functionCall.name().get())) {
throw new VerifyException("Tool not found: " + functionCall.name().get());
missingTools.add(
missingToolResolutionStrategy.onMissingTool(invocationContext, functionCall));
} else {
validCalls.add(functionCall);
}
}

Expand Down Expand Up @@ -202,12 +208,16 @@ public static Maybe<Event> handleFunctionCalls(
Flowable<Event> functionResponseEventsFlowable;
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
functionResponseEventsFlowable =
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
Flowable.fromIterable(validCalls.build()).concatMapMaybe(functionCallMapper);
} else {
functionResponseEventsFlowable =
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
Flowable.fromIterable(validCalls.build()).flatMapMaybe(functionCallMapper);
}
return functionResponseEventsFlowable
Flowable<Event> missingToolsFlowable =
Flowable.fromIterable(missingTools.build()).concatMapMaybe(maybe -> maybe);
Flowable<Event> allEventsFlowable =
Flowable.concat(missingToolsFlowable, functionResponseEventsFlowable);
return allEventsFlowable
.toList()
.flatMapMaybe(
events -> {
Expand Down Expand Up @@ -242,13 +252,19 @@ public static Maybe<Event> handleFunctionCalls(
public static Maybe<Event> handleFunctionCallsLive(
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
MissingToolResolutionStrategy missingToolResolutionStrategy =
invocationContext.runConfig().missingToolResolutionStrategy();

ImmutableList.Builder<Maybe<Event>> missingTools = ImmutableList.builder();
ImmutableList.Builder<FunctionCall> validCalls = ImmutableList.builder();
for (FunctionCall functionCall : functionCalls) {
if (!tools.containsKey(functionCall.name().get())) {
throw new VerifyException("Tool not found: " + functionCall.name().get());
missingTools.add(
missingToolResolutionStrategy.onMissingTool(invocationContext, functionCall));
} else {
validCalls.add(functionCall);
}
}

Function<FunctionCall, Maybe<Event>> functionCallMapper =
functionCall -> {
BaseTool tool = tools.get(functionCall.name().get());
Expand Down Expand Up @@ -311,17 +327,20 @@ public static Maybe<Event> handleFunctionCallsLive(
};

Flowable<Event> responseEventsFlowable;

if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
responseEventsFlowable =
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
Flowable.fromIterable(validCalls.build()).concatMapMaybe(functionCallMapper);

} else {
responseEventsFlowable =
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
Flowable.fromIterable(validCalls.build()).flatMapMaybe(functionCallMapper);
}
Flowable<Event> missingToolsFlowable =
Flowable.fromIterable(missingTools.build()).concatMapMaybe(maybe -> maybe);
Flowable<Event> allEventsFlowable =
Flowable.concat(missingToolsFlowable, responseEventsFlowable);

return responseEventsFlowable
return allEventsFlowable
.toList()
.flatMapMaybe(
events -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.google.adk.tools;

import com.google.adk.agents.InvocationContext;
import com.google.adk.events.Event;
import com.google.common.base.VerifyException;
import com.google.genai.types.FunctionCall;
import io.reactivex.rxjava3.core.Maybe;
import java.util.function.BiFunction;

public interface MissingToolResolutionStrategy {
public static final MissingToolResolutionStrategy THROW_EXCEPTION =
new MissingToolResolutionStrategy() {
@Override
public Maybe<Event> onMissingTool(
InvocationContext invocationContext, FunctionCall functionCall) {
throw new VerifyException(
"Tool not found: " + functionCall.name().orElse(functionCall.toJson()));
}
};

public static final MissingToolResolutionStrategy RETURN_ERROR =
new MissingToolResolutionStrategy() {
@Override
public Maybe<Event> onMissingTool(
InvocationContext invocationContext, FunctionCall functionCall) {
return Maybe.error(
new VerifyException(
"Tool not found: " + functionCall.name().orElse(functionCall.toJson())));
}
};

public static final MissingToolResolutionStrategy IGNORE =
new MissingToolResolutionStrategy() {
@Override
public Maybe<Event> onMissingTool(
InvocationContext invocationContext, FunctionCall functionCall) {
return Maybe.empty();
}
};

public static MissingToolResolutionStrategy respondWithEvent(
BiFunction<InvocationContext, FunctionCall, Maybe<Event>> eventFactory) {
return new MissingToolResolutionStrategy() {
@Override
public Maybe<Event> onMissingTool(
InvocationContext invocationContext, FunctionCall functionCall) {
return eventFactory.apply(invocationContext, functionCall);
}
};
}

public static MissingToolResolutionStrategy respondWithEventSync(
BiFunction<InvocationContext, FunctionCall, Event> eventFactory) {
return respondWithEvent(
(invocationContext, functionCall) ->
Maybe.just(eventFactory.apply(invocationContext, functionCall)));
}

Maybe<Event> onMissingTool(InvocationContext invocationContext, FunctionCall functionCall);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import static org.junit.Assert.assertThrows;

import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.RunConfig;
import com.google.adk.events.Event;
import com.google.adk.testing.TestUtils;
import com.google.adk.tools.MissingToolResolutionStrategy;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Content;
Expand Down Expand Up @@ -67,6 +69,37 @@ public void handleFunctionCalls_missingTool() {
invocationContext, event, /* tools= */ ImmutableMap.of()));
}

@Test
public void handleFunctionCalls_missingTool_recoveryStrategy() {
InvocationContext invocationContext =
createInvocationContext(
createRootAgent(),
RunConfig.builder()
.setMissingToolResolutionStrategy(
MissingToolResolutionStrategy.respondWithEventSync(
(ctx, call) ->
Event.builder()
.content(
Content.fromParts(
Part.fromText("tool missing: " + call.name().get())))
.build()))
.build());
Event event =
createEvent("event").toBuilder()
.content(
Content.fromParts(
Part.fromText("..."), Part.fromFunctionCall("missing_tool", ImmutableMap.of())))
.build();

Event functionResponseEvent =
Functions.handleFunctionCalls(invocationContext, event, /* tools= */ ImmutableMap.of())
.blockingGet();

assertThat(functionResponseEvent).isNotNull();
assertThat(functionResponseEvent.content().get().parts().get())
.containsExactly(Part.fromText("tool missing: missing_tool"));
}

@Test
public void handleFunctionCalls_singleFunctionCall() {
InvocationContext invocationContext = createInvocationContext(createRootAgent());
Expand Down