diff --git a/core/src/main/java/com/google/adk/agents/RunConfig.java b/core/src/main/java/com/google/adk/agents/RunConfig.java index 2f8e417d..5b68ca36 100644 --- a/core/src/main/java/com/google/adk/agents/RunConfig.java +++ b/core/src/main/java/com/google/adk/agents/RunConfig.java @@ -16,6 +16,7 @@ package com.google.adk.agents; +import com.google.adk.tools.MissingToolResolutionStrategy; import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -70,6 +71,8 @@ public enum ToolExecutionMode { public abstract int maxLlmCalls(); + public abstract MissingToolResolutionStrategy missingToolResolutionStrategy(); + public abstract Builder toBuilder(); public static Builder builder() { @@ -78,6 +81,7 @@ public static Builder builder() { .setResponseModalities(ImmutableList.of()) .setStreamingMode(StreamingMode.NONE) .setToolExecutionMode(ToolExecutionMode.NONE) + .setMissingToolResolutionStrategy(MissingToolResolutionStrategy.THROW_EXCEPTION) .setMaxLlmCalls(500); } @@ -90,7 +94,8 @@ public static Builder builder(RunConfig runConfig) { .setResponseModalities(runConfig.responseModalities()) .setSpeechConfig(runConfig.speechConfig()) .setOutputAudioTranscription(runConfig.outputAudioTranscription()) - .setInputAudioTranscription(runConfig.inputAudioTranscription()); + .setInputAudioTranscription(runConfig.inputAudioTranscription()) + .setMissingToolResolutionStrategy(runConfig.missingToolResolutionStrategy()); } /** Builder for {@link RunConfig}. */ @@ -123,6 +128,10 @@ public abstract Builder setInputAudioTranscription( @CanIgnoreReturnValue public abstract Builder setMaxLlmCalls(int maxLlmCalls); + @CanIgnoreReturnValue + public abstract Builder setMissingToolResolutionStrategy( + MissingToolResolutionStrategy missingToolResolutionStrategy); + abstract RunConfig autoBuild(); public RunConfig build() { diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index a952d602..bc8fbd82 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -30,9 +30,9 @@ import com.google.adk.events.EventActions; import com.google.adk.tools.BaseTool; import com.google.adk.tools.FunctionTool; +import com.google.adk.tools.MissingToolResolutionStrategy; import com.google.adk.tools.ToolConfirmation; import com.google.adk.tools.ToolContext; -import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -72,6 +72,84 @@ public static String generateClientFunctionCallId() { return AF_FUNCTION_CALL_ID_PREFIX + UUID.randomUUID(); } + /** Container for separated valid and missing tool calls. */ + private static class ToolCallSeparation { + private final ImmutableList validCalls; + private final Flowable missingToolsFlowable; + + ToolCallSeparation( + ImmutableList validCalls, Flowable missingToolsFlowable) { + this.validCalls = validCalls; + this.missingToolsFlowable = missingToolsFlowable; + } + + ImmutableList validCalls() { + return validCalls; + } + + Flowable missingToolsFlowable() { + return missingToolsFlowable; + } + } + + /** + * Separates function calls into valid calls and missing tool events. + * + * @param invocationContext The invocation context. + * @param functionCalls The list of function calls to separate. + * @param tools The available tools. + * @return A ToolCallSeparation containing valid calls and a flowable for missing tools. + */ + private static ToolCallSeparation separateValidAndMissingToolCalls( + InvocationContext invocationContext, + ImmutableList functionCalls, + Map tools) { + MissingToolResolutionStrategy missingToolResolutionStrategy = + invocationContext.runConfig().missingToolResolutionStrategy(); + ImmutableList.Builder> missingTools = ImmutableList.builder(); + ImmutableList.Builder validCalls = ImmutableList.builder(); + + for (FunctionCall functionCall : functionCalls) { + if (!tools.containsKey(functionCall.name().get())) { + missingTools.add( + missingToolResolutionStrategy.onMissingTool(invocationContext, functionCall)); + } else { + validCalls.add(functionCall); + } + } + + Flowable missingToolsFlowable = + Flowable.fromIterable(missingTools.build()).concatMapMaybe(maybe -> maybe); + + return new ToolCallSeparation(validCalls.build(), missingToolsFlowable); + } + + /** + * Creates a combined flowable of function response events based on execution mode. + * + * @param invocationContext The invocation context. + * @param validCalls The list of valid function calls. + * @param missingToolsFlowable The flowable for missing tool events. + * @param functionCallMapper The mapper to convert function calls to events. + * @return A combined flowable of all events. + */ + private static Flowable createCombinedFlowable( + InvocationContext invocationContext, + ImmutableList validCalls, + Flowable missingToolsFlowable, + Function> functionCallMapper) { + Flowable functionResponseEventsFlowable; + if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { + functionResponseEventsFlowable = + Flowable.fromIterable(validCalls).concatMapMaybe(functionCallMapper); + } else { + functionResponseEventsFlowable = + Flowable.fromIterable(validCalls).flatMapMaybe(functionCallMapper); + } + + return Flowable.concat(missingToolsFlowable, functionResponseEventsFlowable); + } + /** * Populates missing function call IDs in the provided event's content. * @@ -137,12 +215,8 @@ public static Maybe handleFunctionCalls( Map tools, Map toolConfirmations) { ImmutableList functionCalls = functionCallEvent.functionCalls(); - - for (FunctionCall functionCall : functionCalls) { - if (!tools.containsKey(functionCall.name().get())) { - throw new VerifyException("Tool not found: " + functionCall.name().get()); - } - } + ToolCallSeparation separation = + separateValidAndMissingToolCalls(invocationContext, functionCalls, tools); Function> functionCallMapper = functionCall -> { @@ -199,15 +273,13 @@ public static Maybe handleFunctionCalls( }); }; - Flowable functionResponseEventsFlowable; - if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { - functionResponseEventsFlowable = - Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper); - } else { - functionResponseEventsFlowable = - Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper); - } - return functionResponseEventsFlowable + Flowable allEventsFlowable = + createCombinedFlowable( + invocationContext, + separation.validCalls(), + separation.missingToolsFlowable(), + functionCallMapper); + return allEventsFlowable .toList() .flatMapMaybe( events -> { @@ -242,12 +314,8 @@ public static Maybe handleFunctionCalls( public static Maybe handleFunctionCallsLive( InvocationContext invocationContext, Event functionCallEvent, Map tools) { ImmutableList functionCalls = functionCallEvent.functionCalls(); - - for (FunctionCall functionCall : functionCalls) { - if (!tools.containsKey(functionCall.name().get())) { - throw new VerifyException("Tool not found: " + functionCall.name().get()); - } - } + ToolCallSeparation separation = + separateValidAndMissingToolCalls(invocationContext, functionCalls, tools); Function> functionCallMapper = functionCall -> { @@ -310,18 +378,14 @@ public static Maybe handleFunctionCallsLive( }); }; - Flowable responseEventsFlowable; - - if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { - responseEventsFlowable = - Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper); - - } else { - responseEventsFlowable = - Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper); - } + Flowable allEventsFlowable = + createCombinedFlowable( + invocationContext, + separation.validCalls(), + separation.missingToolsFlowable(), + functionCallMapper); - return responseEventsFlowable + return allEventsFlowable .toList() .flatMapMaybe( events -> { diff --git a/core/src/main/java/com/google/adk/tools/MissingToolResolutionStrategy.java b/core/src/main/java/com/google/adk/tools/MissingToolResolutionStrategy.java new file mode 100644 index 00000000..b33297b2 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/MissingToolResolutionStrategy.java @@ -0,0 +1,39 @@ +package com.google.adk.tools; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.common.base.VerifyException; +import com.google.genai.types.FunctionCall; +import io.reactivex.rxjava3.core.Maybe; +import java.util.function.BiFunction; + +public interface MissingToolResolutionStrategy { + public static final MissingToolResolutionStrategy THROW_EXCEPTION = + (invocationContext, functionCall) -> { + throw new VerifyException( + "Tool not found: " + functionCall.name().orElse(functionCall.toJson())); + }; + + public static final MissingToolResolutionStrategy RETURN_ERROR = + (invocationContext, functionCall) -> + Maybe.error( + new VerifyException( + "Tool not found: " + functionCall.name().orElse(functionCall.toJson()))); + + public static final MissingToolResolutionStrategy IGNORE = + (invocationContext, functionCall) -> Maybe.empty(); + + public static MissingToolResolutionStrategy respondWithEvent( + BiFunction> eventFactory) { + return eventFactory::apply; + } + + public static MissingToolResolutionStrategy respondWithEventSync( + BiFunction eventFactory) { + return respondWithEvent( + (invocationContext, functionCall) -> + Maybe.just(eventFactory.apply(invocationContext, functionCall))); + } + + Maybe onMissingTool(InvocationContext invocationContext, FunctionCall functionCall); +} diff --git a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java index d880d7d8..e2aa66c5 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java @@ -23,8 +23,10 @@ import static org.junit.Assert.assertThrows; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; import com.google.adk.testing.TestUtils; +import com.google.adk.tools.MissingToolResolutionStrategy; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -67,6 +69,37 @@ public void handleFunctionCalls_missingTool() { invocationContext, event, /* tools= */ ImmutableMap.of())); } + @Test + public void handleFunctionCalls_missingTool_recoveryStrategy() { + InvocationContext invocationContext = + createInvocationContext( + createRootAgent(), + RunConfig.builder() + .setMissingToolResolutionStrategy( + MissingToolResolutionStrategy.respondWithEventSync( + (ctx, call) -> + Event.builder() + .content( + Content.fromParts( + Part.fromText("tool missing: " + call.name().get()))) + .build())) + .build()); + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.fromText("..."), Part.fromFunctionCall("missing_tool", ImmutableMap.of()))) + .build(); + + Event functionResponseEvent = + Functions.handleFunctionCalls(invocationContext, event, /* tools= */ ImmutableMap.of()) + .blockingGet(); + + assertThat(functionResponseEvent).isNotNull(); + assertThat(functionResponseEvent.content().get().parts().get()) + .containsExactly(Part.fromText("tool missing: missing_tool")); + } + @Test public void handleFunctionCalls_singleFunctionCall() { InvocationContext invocationContext = createInvocationContext(createRootAgent());