diff --git a/core/src/main/java/com/google/adk/agents/Callbacks.java b/core/src/main/java/com/google/adk/agents/Callbacks.java index 1c26897b..6c9c2376 100644 --- a/core/src/main/java/com/google/adk/agents/Callbacks.java +++ b/core/src/main/java/com/google/adk/agents/Callbacks.java @@ -129,14 +129,14 @@ public interface BeforeToolCallback extends BeforeToolCallbackBase { * @param invocationContext Invocation context. * @param baseTool Tool instance. * @param input Tool input arguments. - * @param toolContext Tool context. + * @param toolContext Tool context builder. * @return override result, or empty to continue. */ Maybe> call( InvocationContext invocationContext, BaseTool baseTool, Map input, - ToolContext toolContext); + ToolContext.Builder toolContext); } /** @@ -149,7 +149,7 @@ Optional> call( InvocationContext invocationContext, BaseTool baseTool, Map input, - ToolContext toolContext); + ToolContext.Builder toolContext); } interface AfterToolCallbackBase {} @@ -162,7 +162,7 @@ public interface AfterToolCallback extends AfterToolCallbackBase { * @param invocationContext Invocation context. * @param baseTool Tool instance. * @param input Tool input arguments. - * @param toolContext Tool context. + * @param toolContext Tool context builder. * @param response Raw tool response. * @return processed result, or empty to keep original. */ @@ -170,7 +170,7 @@ Maybe> call( InvocationContext invocationContext, BaseTool baseTool, Map input, - ToolContext toolContext, + ToolContext.Builder toolContext, Object response); } @@ -184,7 +184,7 @@ Optional> call( InvocationContext invocationContext, BaseTool baseTool, Map input, - ToolContext toolContext, + ToolContext.Builder toolContext, Object response); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 55f406d3..e03271bc 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -231,7 +231,7 @@ private Flowable callLlm( .runOnModelErrorCallback( new CallbackContext( context, eventForCallbackUsage.actions()), - llmRequest, + llmRequestBuilder, exception) .switchIfEmpty(Single.error(exception)) .toFlowable()) @@ -239,7 +239,10 @@ private Flowable callLlm( llmResp -> { try (Scope innerScope = llmCallSpan.makeCurrent()) { Telemetry.traceCallLlm( - context, eventForCallbackUsage.id(), llmRequest, llmResp); + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp); } }) .doOnError( @@ -269,7 +272,7 @@ private Single> handleBeforeModelCallback( CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); Maybe pluginResult = - context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder.build()); + context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder); LlmAgent agent = (LlmAgent) context.agent(); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index a952d602..e8b97011 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -147,17 +147,17 @@ public static Maybe handleFunctionCalls( Function> functionCallMapper = functionCall -> { BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = + ToolContext.Builder toolContextBuilder = ToolContext.builder(invocationContext) .functionCallId(functionCall.id().orElse("")) - .toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null))) - .build(); + .toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null))); Map functionArgs = functionCall.args().orElse(ImmutableMap.of()); Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty(Maybe.defer(() -> callTool(tool, functionArgs, toolContext))); + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContextBuilder) + .switchIfEmpty( + Maybe.defer(() -> callTool(tool, functionArgs, toolContextBuilder.build()))); return maybeFunctionResult .map(Optional::of) @@ -166,7 +166,7 @@ public static Maybe handleFunctionCalls( t -> invocationContext .pluginManager() - .runOnToolErrorCallback(tool, functionArgs, toolContext, t) + .runOnToolErrorCallback(tool, functionArgs, toolContextBuilder, t) .map(Optional::of) .switchIfEmpty(Single.error(t))) .flatMapMaybe( @@ -178,7 +178,7 @@ public static Maybe handleFunctionCalls( invocationContext, tool, functionArgs, - toolContext, + toolContextBuilder, initialFunctionResult); return afterToolResultMaybe @@ -193,7 +193,10 @@ public static Maybe handleFunctionCalls( } Event functionResponseEvent = buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext); + tool, + finalFunctionResult, + toolContextBuilder.build(), + invocationContext); return Maybe.just(functionResponseEvent); }); }); @@ -252,21 +255,19 @@ public static Maybe handleFunctionCallsLive( Function> functionCallMapper = functionCall -> { BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .build(); + ToolContext.Builder toolContextBuilder = + ToolContext.builder(invocationContext).functionCallId(functionCall.id().orElse("")); Map functionArgs = functionCall.args().orElse(new HashMap<>()); Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContextBuilder) .switchIfEmpty( Maybe.defer( () -> processFunctionLive( invocationContext, tool, - toolContext, + toolContextBuilder.build(), functionCall, functionArgs))); @@ -277,7 +278,7 @@ public static Maybe handleFunctionCallsLive( t -> invocationContext .pluginManager() - .runOnToolErrorCallback(tool, functionArgs, toolContext, t) + .runOnToolErrorCallback(tool, functionArgs, toolContextBuilder, t) .map(Optional::ofNullable) .switchIfEmpty(Single.error(t))) .flatMapMaybe( @@ -289,7 +290,7 @@ public static Maybe handleFunctionCallsLive( invocationContext, tool, functionArgs, - toolContext, + toolContextBuilder, initialFunctionResult); return afterToolResultMaybe @@ -304,7 +305,10 @@ public static Maybe handleFunctionCallsLive( } Event functionResponseEvent = buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext); + tool, + finalFunctionResult, + toolContextBuilder.build(), + invocationContext); return Maybe.just(functionResponseEvent); }); }); @@ -466,7 +470,7 @@ private static Maybe> maybeInvokeBeforeToolCall( InvocationContext invocationContext, BaseTool tool, Map functionArgs, - ToolContext toolContext) { + ToolContext.Builder toolContext) { if (invocationContext.agent() instanceof LlmAgent) { LlmAgent agent = (LlmAgent) invocationContext.agent(); @@ -497,7 +501,7 @@ private static Maybe> maybeInvokeAfterToolCall( InvocationContext invocationContext, BaseTool tool, Map functionArgs, - ToolContext toolContext, + ToolContext.Builder toolContext, Map functionResult) { if (invocationContext.agent() instanceof LlmAgent) { LlmAgent agent = (LlmAgent) invocationContext.agent(); diff --git a/core/src/main/java/com/google/adk/plugins/BasePlugin.java b/core/src/main/java/com/google/adk/plugins/BasePlugin.java index 7dd22c42..0a650e75 100644 --- a/core/src/main/java/com/google/adk/plugins/BasePlugin.java +++ b/core/src/main/java/com/google/adk/plugins/BasePlugin.java @@ -122,11 +122,12 @@ public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callba * Callback executed before a request is sent to the model. * * @param callbackContext The context for the current agent call. - * @param llmRequest The prepared request object to be sent to the model. + * @param llmRequest The mutable request builder, allowing modification of the request before it + * is sent to the model. * @return An optional LlmResponse to trigger an early exit. Returning Empty to proceed normally. */ public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest llmRequest) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return Maybe.empty(); } @@ -147,13 +148,13 @@ public Maybe afterModelCallback( * Callback executed when a model call encounters an error. * * @param callbackContext The context for the current agent call. - * @param llmRequest The request that was sent to the model. + * @param llmRequest The mutable request builder for the request that failed. * @param error The exception that was raised. * @return An optional LlmResponse to use instead of propagating the error. Returning Empty to * allow the original error to be raised. */ public Maybe onModelErrorCallback( - CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return Maybe.empty(); } @@ -162,12 +163,13 @@ public Maybe onModelErrorCallback( * * @param tool The tool instance that is about to be executed. * @param toolArgs The dictionary of arguments to be used for invoking the tool. - * @param toolContext The context specific to the tool execution. + * @param toolContext The mutable tool context builder, allowing modification of the context + * before tool execution. * @return An optional Map to stop the tool execution and return this response immediately. * Returning Empty to proceed normally. */ public Maybe> beforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { + BaseTool tool, Map toolArgs, ToolContext.Builder toolContext) { return Maybe.empty(); } @@ -176,7 +178,7 @@ public Maybe> beforeToolCallback( * * @param tool The tool instance that has just been executed. * @param toolArgs The original arguments that were passed to the tool. - * @param toolContext The context specific to the tool execution. + * @param toolContext The mutable tool context builder used for tool execution. * @param result The dictionary returned by the tool invocation. * @return An optional Map to replace the original result from the tool. Returning Empty to use * the original result. @@ -184,7 +186,7 @@ public Maybe> beforeToolCallback( public Maybe> afterToolCallback( BaseTool tool, Map toolArgs, - ToolContext toolContext, + ToolContext.Builder toolContext, Map result) { return Maybe.empty(); } @@ -194,13 +196,16 @@ public Maybe> afterToolCallback( * * @param tool The tool instance that encountered an error. * @param toolArgs The arguments that were passed to the tool. - * @param toolContext The context specific to the tool execution. + * @param toolContext The mutable tool context builder for the tool call that failed. * @param error The exception that was raised during tool execution. * @return An optional Map to be used as the tool response instead of propagating the error. * Returning Empty to allow the original error to be raised. */ public Maybe> onToolErrorCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + BaseTool tool, + Map toolArgs, + ToolContext.Builder toolContext, + Throwable error) { return Maybe.empty(); } } diff --git a/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java b/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java index d0ce3ab4..63b287cb 100644 --- a/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java @@ -151,14 +151,15 @@ public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callba @Override public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest llmRequest) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return Maybe.fromAction( () -> { + LlmRequest request = llmRequest.build(); log("🧠 LLM REQUEST"); - log(" Model: " + llmRequest.model().orElse("default")); + log(" Model: " + request.model().orElse("default")); log(" Agent: " + callbackContext.agentName()); - llmRequest + request .getFirstSystemInstruction() .ifPresent( sysInstruction -> { @@ -170,8 +171,8 @@ public Maybe beforeModelCallback( log(" System Instruction: '" + truncatedInstruction + "'"); }); - if (!llmRequest.tools().isEmpty()) { - String toolNames = String.join(", ", llmRequest.tools().keySet()); + if (!request.tools().isEmpty()) { + String toolNames = String.join(", ", request.tools().keySet()); log(" Available Tools: [" + toolNames + "]"); } }); @@ -211,7 +212,7 @@ public Maybe afterModelCallback( @Override public Maybe onModelErrorCallback( - CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return Maybe.fromAction( () -> { log("🧠 LLM ERROR"); @@ -223,13 +224,14 @@ public Maybe onModelErrorCallback( @Override public Maybe> beforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { + BaseTool tool, Map toolArgs, ToolContext.Builder toolContext) { return Maybe.fromAction( () -> { + ToolContext tc = toolContext.build(); log("🔧 TOOL STARTING"); log(" Tool Name: " + tool.name()); - log(" Agent: " + toolContext.agentName()); - toolContext.functionCallId().ifPresent(id -> log(" Function Call ID: " + id)); + log(" Agent: " + tc.agentName()); + tc.functionCallId().ifPresent(id -> log(" Function Call ID: " + id)); log(" Arguments: " + formatArgs(toolArgs)); }); } @@ -238,27 +240,32 @@ public Maybe> beforeToolCallback( public Maybe> afterToolCallback( BaseTool tool, Map toolArgs, - ToolContext toolContext, + ToolContext.Builder toolContext, Map result) { return Maybe.fromAction( () -> { + ToolContext tc = toolContext.build(); log("🔧 TOOL COMPLETED"); log(" Tool Name: " + tool.name()); - log(" Agent: " + toolContext.agentName()); - toolContext.functionCallId().ifPresent(id -> log(" Function Call ID: " + id)); + log(" Agent: " + tc.agentName()); + tc.functionCallId().ifPresent(id -> log(" Function Call ID: " + id)); log(" Result: " + formatArgs(result)); }); } @Override public Maybe> onToolErrorCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + BaseTool tool, + Map toolArgs, + ToolContext.Builder toolContext, + Throwable error) { return Maybe.fromAction( () -> { + ToolContext tc = toolContext.build(); log("🔧 TOOL ERROR"); log(" Tool Name: " + tool.name()); - log(" Agent: " + toolContext.agentName()); - toolContext.functionCallId().ifPresent(id -> log(" Function Call ID: " + id)); + log(" Agent: " + tc.agentName()); + tc.functionCallId().ifPresent(id -> log(" Function Call ID: " + id)); log(" Arguments: " + formatArgs(toolArgs)); log(" Error: " + error.getMessage()); logger.error("[{}] Tool Error", name, error); diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 135168e9..8ecee467 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -127,7 +127,7 @@ public Maybe runAfterAgentCallback(BaseAgent agent, CallbackContext cal } public Maybe runBeforeModelCallback( - CallbackContext callbackContext, LlmRequest llmRequest) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return runMaybeCallbacks( plugin -> plugin.beforeModelCallback(callbackContext, llmRequest), "beforeModelCallback"); } @@ -139,14 +139,14 @@ public Maybe runAfterModelCallback( } public Maybe runOnModelErrorCallback( - CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return runMaybeCallbacks( plugin -> plugin.onModelErrorCallback(callbackContext, llmRequest, error), "onModelErrorCallback"); } public Maybe> runBeforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { + BaseTool tool, Map toolArgs, ToolContext.Builder toolContext) { return runMaybeCallbacks( plugin -> plugin.beforeToolCallback(tool, toolArgs, toolContext), "beforeToolCallback"); } @@ -154,7 +154,7 @@ public Maybe> runBeforeToolCallback( public Maybe> runAfterToolCallback( BaseTool tool, Map toolArgs, - ToolContext toolContext, + ToolContext.Builder toolContext, Map result) { return runMaybeCallbacks( plugin -> plugin.afterToolCallback(tool, toolArgs, toolContext, result), @@ -162,7 +162,10 @@ public Maybe> runAfterToolCallback( } public Maybe> runOnToolErrorCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + BaseTool tool, + Map toolArgs, + ToolContext.Builder toolContext, + Throwable error) { return runMaybeCallbacks( plugin -> plugin.onToolErrorCallback(tool, toolArgs, toolContext, error), "onToolErrorCallback"); diff --git a/core/src/test/java/com/google/adk/agents/CallbacksTest.java b/core/src/test/java/com/google/adk/agents/CallbacksTest.java index 85c64106..91d52a6a 100644 --- a/core/src/test/java/com/google/adk/agents/CallbacksTest.java +++ b/core/src/test/java/com/google/adk/agents/CallbacksTest.java @@ -987,7 +987,7 @@ public void handleFunctionCalls_withChainedToolCallbacks_overridesResultAndPasse Callbacks.BeforeToolCallbackSync bc2 = (invCtx, toolName, args, currentToolCtx) -> { - currentToolCtx.state().putAll(stateAddedByBc2); + currentToolCtx.build().state().putAll(stateAddedByBc2); return Optional.empty(); }; diff --git a/core/src/test/java/com/google/adk/plugins/BasePluginTest.java b/core/src/test/java/com/google/adk/plugins/BasePluginTest.java index 9a4a243c..56f77022 100644 --- a/core/src/test/java/com/google/adk/plugins/BasePluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/BasePluginTest.java @@ -43,9 +43,9 @@ private static class TestPlugin extends BasePlugin { private final CallbackContext callbackContext = Mockito.mock(CallbackContext.class); private final Content content = Content.builder().build(); private final Event event = Mockito.mock(Event.class); - private final LlmRequest llmRequest = LlmRequest.builder().build(); + private final LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); private final LlmResponse llmResponse = LlmResponse.builder().build(); - private final ToolContext toolContext = Mockito.mock(ToolContext.class); + private final ToolContext.Builder toolContextBuilder = ToolContext.builder(invocationContext); @Test public void onUserMessageCallback_returnsEmptyMaybe() { @@ -79,7 +79,7 @@ public void afterAgentCallback_returnsEmptyMaybe() { @Test public void beforeModelCallback_returnsEmptyMaybe() { - plugin.beforeModelCallback(callbackContext, llmRequest).test().assertResult(); + plugin.beforeModelCallback(callbackContext, llmRequestBuilder).test().assertResult(); } @Test @@ -90,20 +90,20 @@ public void afterModelCallback_returnsEmptyMaybe() { @Test public void onModelErrorCallback_returnsEmptyMaybe() { plugin - .onModelErrorCallback(callbackContext, llmRequest, new RuntimeException()) + .onModelErrorCallback(callbackContext, llmRequestBuilder, new RuntimeException()) .test() .assertResult(); } @Test public void beforeToolCallback_returnsEmptyMaybe() { - plugin.beforeToolCallback(null, new HashMap<>(), toolContext).test().assertResult(); + plugin.beforeToolCallback(null, new HashMap<>(), toolContextBuilder).test().assertResult(); } @Test public void afterToolCallback_returnsEmptyMaybe() { plugin - .afterToolCallback(null, new HashMap<>(), toolContext, new HashMap<>()) + .afterToolCallback(null, new HashMap<>(), toolContextBuilder, new HashMap<>()) .test() .assertResult(); } @@ -111,7 +111,7 @@ public void afterToolCallback_returnsEmptyMaybe() { @Test public void onToolErrorCallback_returnsEmptyMaybe() { plugin - .onToolErrorCallback(null, new HashMap<>(), toolContext, new RuntimeException()) + .onToolErrorCallback(null, new HashMap<>(), toolContextBuilder, new RuntimeException()) .test() .assertResult(); } diff --git a/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java b/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java index 4c90c11b..ae8f5be7 100644 --- a/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java @@ -37,6 +37,7 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; import java.util.Optional; import org.junit.Before; import org.junit.Rule; @@ -57,7 +58,6 @@ public class LoggingPluginTest { @Mock private BaseAgent mockAgent; @Mock private CallbackContext mockCallbackContext; @Mock private BaseTool mockTool; - @Mock private ToolContext mockToolContext; private final Content content = Content.builder().build(); private final Session session = Session.builder("session_id").build(); @@ -69,8 +69,30 @@ public class LoggingPluginTest { .actions(EventActions.builder().build()) .longRunningToolIds(Optional.empty()) .build(); - private final LlmRequest llmRequest = - LlmRequest.builder().model("default").contents(ImmutableList.of()).build(); + private final ToolContext.Builder toolContextBuilder = + ToolContext.builder( + InvocationContext.builder() + .agent( + new BaseAgent( + "agent_name", + "description", + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of()) { + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + }) + .session(session) + .build()); + private final LlmRequest.Builder llmRequestBuilder = + LlmRequest.builder().model("default").contents(ImmutableList.of()); private final LlmResponse llmResponse = LlmResponse.builder().build(); private final ImmutableMap toolArgs = ImmutableMap.of(); private final ImmutableMap toolResult = ImmutableMap.of(); @@ -90,8 +112,6 @@ public void setUp() { when(mockCallbackContext.branch()).thenReturn(Optional.empty()); when(mockTool.name()).thenReturn("tool_name"); - when(mockToolContext.agentName()).thenReturn("agent_name"); - when(mockToolContext.functionCallId()).thenReturn(Optional.empty()); } @Test @@ -175,7 +195,10 @@ public void afterAgentCallback_runsWithoutError() { @Test public void beforeModelCallback_runsWithoutError() { - loggingPlugin.beforeModelCallback(mockCallbackContext, llmRequest).test().assertComplete(); + loggingPlugin + .beforeModelCallback(mockCallbackContext, llmRequestBuilder) + .test() + .assertComplete(); } @Test @@ -184,8 +207,7 @@ public void beforeModelCallback_longSystemInstruction() { .beforeModelCallback( mockCallbackContext, LlmRequest.builder() - .appendInstructions(ImmutableList.of("all work and no play".repeat(1000))) - .build()) + .appendInstructions(ImmutableList.of("all work and no play".repeat(1000)))) .test() .assertComplete(); } @@ -194,8 +216,7 @@ public void beforeModelCallback_longSystemInstruction() { public void beforeModelCallback_tools() { loggingPlugin .beforeModelCallback( - mockCallbackContext, - LlmRequest.builder().appendTools(ImmutableList.of(mockTool)).build()) + mockCallbackContext, LlmRequest.builder().appendTools(ImmutableList.of(mockTool))) .test() .assertComplete(); } @@ -231,20 +252,23 @@ public void afterModelCallback_usageMetadata() { @Test public void onModelErrorCallback_runsWithoutError() { loggingPlugin - .onModelErrorCallback(mockCallbackContext, llmRequest, throwable) + .onModelErrorCallback(mockCallbackContext, llmRequestBuilder, throwable) .test() .assertComplete(); } @Test public void beforeToolCallback_runsWithoutError() { - loggingPlugin.beforeToolCallback(mockTool, toolArgs, mockToolContext).test().assertComplete(); + loggingPlugin + .beforeToolCallback(mockTool, toolArgs, toolContextBuilder) + .test() + .assertComplete(); } @Test public void afterToolCallback_runsWithoutError() { loggingPlugin - .afterToolCallback(mockTool, toolArgs, mockToolContext, toolResult) + .afterToolCallback(mockTool, toolArgs, toolContextBuilder, toolResult) .test() .assertComplete(); } @@ -252,7 +276,7 @@ public void afterToolCallback_runsWithoutError() { @Test public void onToolErrorCallback_runsWithoutError() { loggingPlugin - .onToolErrorCallback(mockTool, toolArgs, mockToolContext, throwable) + .onToolErrorCallback(mockTool, toolArgs, toolContextBuilder, throwable) .test() .assertComplete(); } diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index 4737d6cd..bd73fe05 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -236,18 +236,18 @@ public void runAfterAgentCallback_singlePlugin() { @Test public void runBeforeModelCallback_singlePlugin() { CallbackContext mockCallbackContext = mock(CallbackContext.class); - LlmRequest llmRequest = LlmRequest.builder().build(); + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); LlmResponse llmResponse = LlmResponse.builder().build(); when(plugin1.beforeModelCallback(any(), any())).thenReturn(Maybe.just(llmResponse)); pluginManager.registerPlugin(plugin1); pluginManager - .runBeforeModelCallback(mockCallbackContext, llmRequest) + .runBeforeModelCallback(mockCallbackContext, llmRequestBuilder) .test() .assertResult(llmResponse); - verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequest); + verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequestBuilder); } @Test @@ -269,7 +269,7 @@ public void runAfterModelCallback_singlePlugin() { @Test public void runOnModelErrorCallback_singlePlugin() { CallbackContext mockCallbackContext = mock(CallbackContext.class); - LlmRequest llmRequest = LlmRequest.builder().build(); + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); Throwable mockThrowable = mock(Throwable.class); LlmResponse llmResponse = LlmResponse.builder().build(); @@ -277,63 +277,63 @@ public void runOnModelErrorCallback_singlePlugin() { pluginManager.registerPlugin(plugin1); pluginManager - .runOnModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable) + .runOnModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable) .test() .assertResult(llmResponse); - verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable); + verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable); } @Test public void runBeforeToolCallback_singlePlugin() { BaseTool mockTool = mock(BaseTool.class); ImmutableMap toolArgs = ImmutableMap.of(); - ToolContext mockToolContext = mock(ToolContext.class); + ToolContext.Builder toolContextBuilder = ToolContext.builder(mockInvocationContext); when(plugin1.beforeToolCallback(any(), any(), any())).thenReturn(Maybe.just(toolArgs)); pluginManager.registerPlugin(plugin1); pluginManager - .runBeforeToolCallback(mockTool, toolArgs, mockToolContext) + .runBeforeToolCallback(mockTool, toolArgs, toolContextBuilder) .test() .assertResult(toolArgs); - verify(plugin1).beforeToolCallback(mockTool, toolArgs, mockToolContext); + verify(plugin1).beforeToolCallback(mockTool, toolArgs, toolContextBuilder); } @Test public void runAfterToolCallback_singlePlugin() { BaseTool mockTool = mock(BaseTool.class); ImmutableMap toolArgs = ImmutableMap.of(); - ToolContext mockToolContext = mock(ToolContext.class); + ToolContext.Builder toolContextBuilder = ToolContext.builder(mockInvocationContext); ImmutableMap result = ImmutableMap.of(); when(plugin1.afterToolCallback(any(), any(), any(), any())).thenReturn(Maybe.just(result)); pluginManager.registerPlugin(plugin1); pluginManager - .runAfterToolCallback(mockTool, toolArgs, mockToolContext, result) + .runAfterToolCallback(mockTool, toolArgs, toolContextBuilder, result) .test() .assertResult(result); - verify(plugin1).afterToolCallback(mockTool, toolArgs, mockToolContext, result); + verify(plugin1).afterToolCallback(mockTool, toolArgs, toolContextBuilder, result); } @Test public void runOnToolErrorCallback_singlePlugin() { BaseTool mockTool = mock(BaseTool.class); ImmutableMap toolArgs = ImmutableMap.of(); - ToolContext mockToolContext = mock(ToolContext.class); + ToolContext.Builder toolContextBuilder = ToolContext.builder(mockInvocationContext); Throwable mockThrowable = mock(Throwable.class); ImmutableMap result = ImmutableMap.of(); when(plugin1.onToolErrorCallback(any(), any(), any(), any())).thenReturn(Maybe.just(result)); pluginManager.registerPlugin(plugin1); pluginManager - .runOnToolErrorCallback(mockTool, toolArgs, mockToolContext, mockThrowable) + .runOnToolErrorCallback(mockTool, toolArgs, toolContextBuilder, mockThrowable) .test() .assertResult(result); - verify(plugin1).onToolErrorCallback(mockTool, toolArgs, mockToolContext, mockThrowable); + verify(plugin1).onToolErrorCallback(mockTool, toolArgs, toolContextBuilder, mockThrowable); } } diff --git a/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java b/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java index 1164709a..70f90f42 100644 --- a/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java +++ b/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java @@ -71,7 +71,7 @@ public Maybe beforeRunCallback(InvocationContext invocationContext) { @Override public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest llmRequest) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { if (!isReplayModeOn(callbackContext)) { return Maybe.empty(); } @@ -95,18 +95,19 @@ public Maybe beforeModelCallback( @Override public Maybe> beforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { - if (!isReplayModeOn(toolContext)) { + BaseTool tool, Map toolArgs, ToolContext.Builder toolContext) { + ToolContext tc = toolContext.build(); + if (!isReplayModeOn(tc)) { return Maybe.empty(); } - InvocationReplayState state = getInvocationState(toolContext); + InvocationReplayState state = getInvocationState(tc); if (state == null) { throw new ReplayConfigError( "Replay state not initialized. Ensure beforeRunCallback created it."); } - String agentName = toolContext.agentName(); + String agentName = tc.agentName(); // Verify and get the next tool recording for this specific agent ToolRecording recording = @@ -116,7 +117,7 @@ public Maybe> beforeToolCallback( // TODO: support replay requests and responses from AgentTool. // For now, execute the tool normally to maintain side effects try { - Map liveResult = tool.runAsync(toolArgs, toolContext).blockingGet(); + Map liveResult = tool.runAsync(toolArgs, tc).blockingGet(); logger.debug("Tool {} executed during replay with result: {}", tool.name(), liveResult); } catch (Exception e) { logger.warn("Error executing tool {} during replay", tool.name(), e); @@ -261,7 +262,7 @@ private Recording getNextRecordingForAgent(InvocationReplayState state, String a } private LlmRecording verifyAndGetNextLlmRecordingForAgent( - InvocationReplayState state, String agentName, LlmRequest llmRequest) { + InvocationReplayState state, String agentName, LlmRequest.Builder llmRequest) { int currentAgentIndex = state.getAgentReplayIndex(agentName); Recording expectedRecording = getNextRecordingForAgent(state, agentName); @@ -278,7 +279,7 @@ private LlmRecording verifyAndGetNextLlmRecordingForAgent( // Strict verification of LLM request if (llmRecording.llmRequest().isPresent()) { verifyLlmRequestMatch( - llmRecording.llmRequest().get(), llmRequest, agentName, currentAgentIndex); + llmRecording.llmRequest().get(), llmRequest.build(), agentName, currentAgentIndex); } return llmRecording; diff --git a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java index fe4d2a0b..f29298bc 100644 --- a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java +++ b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java @@ -107,7 +107,7 @@ void beforeModelCallback_withMatchingRecording_returnsRecordedResponse() throws when(callbackContext.invocationId()).thenReturn("test-invocation"); when(callbackContext.agentName()).thenReturn("test_agent"); - LlmRequest request = + var request = LlmRequest.builder() .model("gemini-2.0-flash") .contents( @@ -115,8 +115,7 @@ void beforeModelCallback_withMatchingRecording_returnsRecordedResponse() throws Content.builder() .role("user") .parts(Part.builder().text("Hello").build()) - .build())) - .build(); + .build())); // Step 4: Verify expected response is returned var result = plugin.beforeModelCallback(callbackContext, request).blockingGet(); @@ -162,7 +161,7 @@ void beforeModelCallback_requestMismatch_returnsEmpty() throws Exception { when(callbackContext.invocationId()).thenReturn("test-invocation"); when(callbackContext.agentName()).thenReturn("test_agent"); - LlmRequest request = + var request = LlmRequest.builder() .model("gemini-2.0-flash") // Different model .contents( @@ -170,8 +169,7 @@ void beforeModelCallback_requestMismatch_returnsEmpty() throws Exception { Content.builder() .role("user") .parts(Part.builder().text("Hello").build()) - .build())) - .build(); + .build())); // Step 4: Verify result is empty var result = plugin.beforeModelCallback(callbackContext, request).blockingGet();