Skip to content

Commit 787a7c2

Browse files
committed
feat: add missing tool resolution strategy
1 parent 493c064 commit 787a7c2

File tree

4 files changed

+134
-13
lines changed

4 files changed

+134
-13
lines changed

core/src/main/java/com/google/adk/agents/RunConfig.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.agents;
1818

19+
import com.google.adk.tools.MissingToolResolutionStrategy;
1920
import com.google.auto.value.AutoValue;
2021
import com.google.common.collect.ImmutableList;
2122
import com.google.errorprone.annotations.CanIgnoreReturnValue;
@@ -70,6 +71,8 @@ public enum ToolExecutionMode {
7071

7172
public abstract int maxLlmCalls();
7273

74+
public abstract MissingToolResolutionStrategy missingToolResolutionStrategy();
75+
7376
public abstract Builder toBuilder();
7477

7578
public static Builder builder() {
@@ -78,6 +81,7 @@ public static Builder builder() {
7881
.setResponseModalities(ImmutableList.of())
7982
.setStreamingMode(StreamingMode.NONE)
8083
.setToolExecutionMode(ToolExecutionMode.NONE)
84+
.setMissingToolResolutionStrategy(MissingToolResolutionStrategy.THROW_EXCEPTION)
8185
.setMaxLlmCalls(500);
8286
}
8387

@@ -90,7 +94,8 @@ public static Builder builder(RunConfig runConfig) {
9094
.setResponseModalities(runConfig.responseModalities())
9195
.setSpeechConfig(runConfig.speechConfig())
9296
.setOutputAudioTranscription(runConfig.outputAudioTranscription())
93-
.setInputAudioTranscription(runConfig.inputAudioTranscription());
97+
.setInputAudioTranscription(runConfig.inputAudioTranscription())
98+
.setMissingToolResolutionStrategy(runConfig.missingToolResolutionStrategy());
9499
}
95100

96101
/** Builder for {@link RunConfig}. */
@@ -123,6 +128,10 @@ public abstract Builder setInputAudioTranscription(
123128
@CanIgnoreReturnValue
124129
public abstract Builder setMaxLlmCalls(int maxLlmCalls);
125130

131+
@CanIgnoreReturnValue
132+
public abstract Builder setMissingToolResolutionStrategy(
133+
MissingToolResolutionStrategy missingToolResolutionStrategy);
134+
126135
abstract RunConfig autoBuild();
127136

128137
public RunConfig build() {

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
import com.google.adk.events.EventActions;
3131
import com.google.adk.tools.BaseTool;
3232
import com.google.adk.tools.FunctionTool;
33+
import com.google.adk.tools.MissingToolResolutionStrategy;
3334
import com.google.adk.tools.ToolConfirmation;
3435
import com.google.adk.tools.ToolContext;
35-
import com.google.common.base.VerifyException;
3636
import com.google.common.collect.ImmutableList;
3737
import com.google.common.collect.ImmutableMap;
3838
import com.google.genai.types.Content;
@@ -137,10 +137,16 @@ public static Maybe<Event> handleFunctionCalls(
137137
Map<String, BaseTool> tools,
138138
Map<String, ToolConfirmation> toolConfirmations) {
139139
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
140-
140+
MissingToolResolutionStrategy missingToolResolutionStrategy =
141+
invocationContext.runConfig().missingToolResolutionStrategy();
142+
ImmutableList.Builder<Maybe<Event>> missingTools = ImmutableList.builder();
143+
ImmutableList.Builder<FunctionCall> validCalls = ImmutableList.builder();
141144
for (FunctionCall functionCall : functionCalls) {
142145
if (!tools.containsKey(functionCall.name().get())) {
143-
throw new VerifyException("Tool not found: " + functionCall.name().get());
146+
missingTools.add(
147+
missingToolResolutionStrategy.onMissingTool(invocationContext, functionCall));
148+
} else {
149+
validCalls.add(functionCall);
144150
}
145151
}
146152

@@ -202,12 +208,16 @@ public static Maybe<Event> handleFunctionCalls(
202208
Flowable<Event> functionResponseEventsFlowable;
203209
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
204210
functionResponseEventsFlowable =
205-
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
211+
Flowable.fromIterable(validCalls.build()).concatMapMaybe(functionCallMapper);
206212
} else {
207213
functionResponseEventsFlowable =
208-
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
214+
Flowable.fromIterable(validCalls.build()).flatMapMaybe(functionCallMapper);
209215
}
210-
return functionResponseEventsFlowable
216+
Flowable<Event> missingToolsFlowable =
217+
Flowable.fromIterable(missingTools.build()).concatMapMaybe(maybe -> maybe);
218+
Flowable<Event> allEventsFlowable =
219+
Flowable.concat(missingToolsFlowable, functionResponseEventsFlowable);
220+
return allEventsFlowable
211221
.toList()
212222
.flatMapMaybe(
213223
events -> {
@@ -242,13 +252,19 @@ public static Maybe<Event> handleFunctionCalls(
242252
public static Maybe<Event> handleFunctionCallsLive(
243253
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
244254
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
255+
MissingToolResolutionStrategy missingToolResolutionStrategy =
256+
invocationContext.runConfig().missingToolResolutionStrategy();
245257

258+
ImmutableList.Builder<Maybe<Event>> missingTools = ImmutableList.builder();
259+
ImmutableList.Builder<FunctionCall> validCalls = ImmutableList.builder();
246260
for (FunctionCall functionCall : functionCalls) {
247261
if (!tools.containsKey(functionCall.name().get())) {
248-
throw new VerifyException("Tool not found: " + functionCall.name().get());
262+
missingTools.add(
263+
missingToolResolutionStrategy.onMissingTool(invocationContext, functionCall));
264+
} else {
265+
validCalls.add(functionCall);
249266
}
250267
}
251-
252268
Function<FunctionCall, Maybe<Event>> functionCallMapper =
253269
functionCall -> {
254270
BaseTool tool = tools.get(functionCall.name().get());
@@ -311,17 +327,20 @@ public static Maybe<Event> handleFunctionCallsLive(
311327
};
312328

313329
Flowable<Event> responseEventsFlowable;
314-
315330
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
316331
responseEventsFlowable =
317-
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
332+
Flowable.fromIterable(validCalls.build()).concatMapMaybe(functionCallMapper);
318333

319334
} else {
320335
responseEventsFlowable =
321-
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
336+
Flowable.fromIterable(validCalls.build()).flatMapMaybe(functionCallMapper);
322337
}
338+
Flowable<Event> missingToolsFlowable =
339+
Flowable.fromIterable(missingTools.build()).concatMapMaybe(maybe -> maybe);
340+
Flowable<Event> allEventsFlowable =
341+
Flowable.concat(missingToolsFlowable, responseEventsFlowable);
323342

324-
return responseEventsFlowable
343+
return allEventsFlowable
325344
.toList()
326345
.flatMapMaybe(
327346
events -> {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package com.google.adk.tools;
2+
3+
import com.google.adk.agents.InvocationContext;
4+
import com.google.adk.events.Event;
5+
import com.google.common.base.VerifyException;
6+
import com.google.genai.types.FunctionCall;
7+
import io.reactivex.rxjava3.core.Maybe;
8+
import java.util.function.BiFunction;
9+
10+
public interface MissingToolResolutionStrategy {
11+
public static final MissingToolResolutionStrategy THROW_EXCEPTION =
12+
new MissingToolResolutionStrategy() {
13+
@Override
14+
public Maybe<Event> onMissingTool(
15+
InvocationContext invocationContext, FunctionCall functionCall) {
16+
throw new VerifyException(
17+
"Tool not found: " + functionCall.name().orElse(functionCall.toJson()));
18+
}
19+
};
20+
21+
public static final MissingToolResolutionStrategy RETURN_ERROR =
22+
new MissingToolResolutionStrategy() {
23+
@Override
24+
public Maybe<Event> onMissingTool(
25+
InvocationContext invocationContext, FunctionCall functionCall) {
26+
return Maybe.error(
27+
new VerifyException(
28+
"Tool not found: " + functionCall.name().orElse(functionCall.toJson())));
29+
}
30+
};
31+
32+
public static final MissingToolResolutionStrategy IGNORE =
33+
new MissingToolResolutionStrategy() {
34+
@Override
35+
public Maybe<Event> onMissingTool(
36+
InvocationContext invocationContext, FunctionCall functionCall) {
37+
return Maybe.empty();
38+
}
39+
};
40+
41+
public static MissingToolResolutionStrategy respondWithEvent(
42+
BiFunction<InvocationContext, FunctionCall, Maybe<Event>> eventFactory) {
43+
return new MissingToolResolutionStrategy() {
44+
@Override
45+
public Maybe<Event> onMissingTool(
46+
InvocationContext invocationContext, FunctionCall functionCall) {
47+
return eventFactory.apply(invocationContext, functionCall);
48+
}
49+
};
50+
}
51+
52+
public static MissingToolResolutionStrategy respondWithEventSync(
53+
BiFunction<InvocationContext, FunctionCall, Event> eventFactory) {
54+
return respondWithEvent(
55+
(invocationContext, functionCall) ->
56+
Maybe.just(eventFactory.apply(invocationContext, functionCall)));
57+
}
58+
59+
Maybe<Event> onMissingTool(InvocationContext invocationContext, FunctionCall functionCall);
60+
}

core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
import static org.junit.Assert.assertThrows;
2424

2525
import com.google.adk.agents.InvocationContext;
26+
import com.google.adk.agents.RunConfig;
2627
import com.google.adk.events.Event;
2728
import com.google.adk.testing.TestUtils;
29+
import com.google.adk.tools.MissingToolResolutionStrategy;
2830
import com.google.common.collect.ImmutableList;
2931
import com.google.common.collect.ImmutableMap;
3032
import com.google.genai.types.Content;
@@ -67,6 +69,37 @@ public void handleFunctionCalls_missingTool() {
6769
invocationContext, event, /* tools= */ ImmutableMap.of()));
6870
}
6971

72+
@Test
73+
public void handleFunctionCalls_missingTool_recoveryStrategy() {
74+
InvocationContext invocationContext =
75+
createInvocationContext(
76+
createRootAgent(),
77+
RunConfig.builder()
78+
.setMissingToolResolutionStrategy(
79+
MissingToolResolutionStrategy.respondWithEventSync(
80+
(ctx, call) ->
81+
Event.builder()
82+
.content(
83+
Content.fromParts(
84+
Part.fromText("tool missing: " + call.name().get())))
85+
.build()))
86+
.build());
87+
Event event =
88+
createEvent("event").toBuilder()
89+
.content(
90+
Content.fromParts(
91+
Part.fromText("..."), Part.fromFunctionCall("missing_tool", ImmutableMap.of())))
92+
.build();
93+
94+
Event functionResponseEvent =
95+
Functions.handleFunctionCalls(invocationContext, event, /* tools= */ ImmutableMap.of())
96+
.blockingGet();
97+
98+
assertThat(functionResponseEvent).isNotNull();
99+
assertThat(functionResponseEvent.content().get().parts().get())
100+
.containsExactly(Part.fromText("tool missing: missing_tool"));
101+
}
102+
70103
@Test
71104
public void handleFunctionCalls_singleFunctionCall() {
72105
InvocationContext invocationContext = createInvocationContext(createRootAgent());

0 commit comments

Comments
 (0)