From 8e69a58df4eadeccbb100b7264bb518a46b61fd7 Mon Sep 17 00:00:00 2001
From: lwangverizon
Date: Tue, 13 Jan 2026 21:23:22 -0800
Subject: [PATCH 1/3] feat: Add support to automatically create a session if
 one does not exist

feature/auto-create-new-session

Merge https://github.com/google/adk-python/pull/4072

### Description of Change

**Problem:**

When building frontend applications with ADK, the frontend cannot always
guarantee that `create_session` is called before a conversation starts. This
creates friction in the user experience because:

- Users may refresh the page or navigate directly to a conversation URL with
  a specific session_id
- Frontend state management may lose track of whether a session was already
  created
- Mobile apps and single-page applications have complex lifecycles, so
  ensuring `create_session` is called first adds unnecessary complexity
- Developers are forced to implement extra logic to check session existence
  before every conversation

Currently, if `get_session` is called with a non-existent session_id, it
returns `None`, requiring the frontend to handle this case explicitly and
call `create_session` separately.

**Solution:**

Added an opt-in `auto_create_session` flag to `Runner`, e.g.
`Runner(app_name=..., agent=..., session_service=...,
auto_create_session=True)`. When the flag is set, the runner automatically
creates the session if it does not exist. This "get or create" pattern is
common in many frameworks and provides a more developer-friendly API.

The implementation:

1. Attempts to fetch the session from the session service
2. If the session does not exist and `auto_create_session` is True, calls
   `create_session` with the provided identifiers and returns the newly
   created session
3. If the session does not exist and the flag is False (the default), raises
   `ValueError` with a helpful message, as before
4. Maintains backward compatibility - existing code continues to work
   without changes

This allows frontends to simply run a conversation with a session_id and be
confident that the session will be available, regardless of whether it was
previously created.

**Benefits:**

- Simplifies frontend integration by removing the need to track session
  creation state
- Reduces API calls (no need to check for existence before every run)
- Follows the principle of least surprise - running with a session ID works
  reliably
- No breaking changes: the flag defaults to False, so existing error
  handling continues to work

### Testing Plan

**Unit Tests:**

- [x] I have added or updated unit tests for my change.
- [x] All unit tests pass locally.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4072 from lwangverizon:feature/auto-create-new-session 5475c6ae91d12332598d521302736eb1db79a8be
PiperOrigin-RevId: 856019482
---
 src/google/adk/runners.py       |  64 ++++++++++++-----
 tests/unittests/test_runners.py | 121 ++++++++++++++++++++++++++++++++
 2 files changed, 167 insertions(+), 18 deletions(-)

diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py
index cbf2c59548..97eb85dfdc 100644
--- a/src/google/adk/runners.py
+++ b/src/google/adk/runners.py
@@ -149,6 +149,7 @@ def __init__(
       memory_service: Optional[BaseMemoryService] = None,
       credential_service: Optional[BaseCredentialService] = None,
       plugin_close_timeout: float = 5.0,
+      auto_create_session: bool = False,
   ):
     """Initializes the Runner.

@@ -175,6 +176,9 @@ def __init__(
       memory_service: The memory service for the runner.
credential_service: The credential service for the runner. plugin_close_timeout: The timeout in seconds for plugin close methods. + auto_create_session: Whether to automatically create a session when + not found. Defaults to False. If False, a missing session raises + ValueError with a helpful message. Raises: ValueError: If `app` is provided along with `agent` or `plugins`, or if @@ -195,6 +199,7 @@ def __init__( self.plugin_manager = PluginManager( plugins=plugins, close_timeout=plugin_close_timeout ) + self.auto_create_session = auto_create_session ( self._agent_origin_app_name, self._agent_origin_dir, @@ -343,9 +348,43 @@ def _format_session_not_found_message(self, session_id: str) -> str: return message return ( f'{message}. {self._app_name_alignment_hint} ' - 'The mismatch prevents the runner from locating the session.' + 'The mismatch prevents the runner from locating the session. ' + 'To automatically create a session when missing, set ' + 'auto_create_session=True when constructing the runner.' ) + async def _get_or_create_session( + self, *, user_id: str, session_id: str + ) -> Session: + """Gets the session or creates it if auto-creation is enabled. + + This helper first attempts to retrieve the session. If not found and + auto_create_session is True, it creates a new session with the provided + identifiers. Otherwise, it raises a ValueError with a helpful message. + + Args: + user_id: The user ID of the session. + session_id: The session ID of the session. + + Returns: + The existing or newly created `Session`. + + Raises: + ValueError: If the session is not found and auto_create_session is False. + """ + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + if not session: + if self.auto_create_session: + session = await self.session_service.create_session( + app_name=self.app_name, user_id=user_id, session_id=session_id + ) + else: + message = self._format_session_not_found_message(session_id) + raise ValueError(message) + return session + def run( self, *, @@ -455,12 +494,9 @@ async def _run_with_trace( invocation_id: Optional[str] = None, ) -> AsyncGenerator[Event, None]: with tracer.start_as_current_span('invocation'): - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id + session = await self._get_or_create_session( + user_id=user_id, session_id=session_id ) - if not session: - message = self._format_session_not_found_message(session_id) - raise ValueError(message) if not invocation_id and not new_message: raise ValueError( 'Running an agent requires either a new_message or an ' @@ -534,12 +570,9 @@ async def rewind_async( rewind_before_invocation_id: str, ) -> None: """Rewinds the session to before the specified invocation.""" - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id + session = await self._get_or_create_session( + user_id=user_id, session_id=session_id ) - if not session: - raise ValueError(f'Session not found: {session_id}') - rewind_event_index = -1 for i, event in enumerate(session.events): if event.invocation_id == rewind_before_invocation_id: @@ -967,14 +1000,9 @@ async def run_live( stacklevel=2, ) if not session: - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id + session = await self._get_or_create_session( + user_id=user_id, session_id=session_id ) - if not session: - raise ValueError( - 
f'Session not found for user id: {user_id} and session id:' - f' {session_id}' - ) invocation_context = self._new_invocation_context_for_live( session, live_request_queue=live_request_queue, diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 710ca90d31..bb44ce73d6 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -68,6 +68,24 @@ async def _run_async_impl( ) +class MockLiveAgent(BaseAgent): + """Mock live agent for unit testing.""" + + def __init__(self, name: str): + super().__init__(name=name, sub_agents=[]) + + async def _run_live_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + content=types.Content( + role="model", parts=[types.Part(text="live hello")] + ), + ) + + class MockLlmAgent(LlmAgent): """Mock LLM agent for unit testing.""" @@ -237,6 +255,109 @@ def _infer_agent_origin( assert "Ensure the runner app_name matches" in message +@pytest.mark.asyncio +async def test_session_auto_creation(): + + class RunnerWithMismatch(Runner): + + def _infer_agent_origin( + self, agent: BaseAgent + ) -> tuple[Optional[str], Optional[Path]]: + del agent + return "expected_app", Path("/workspace/agents/expected_app") + + session_service = InMemorySessionService() + runner = RunnerWithMismatch( + app_name="expected_app", + agent=MockLlmAgent("test_agent"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + auto_create_session=True, + ) + + agen = runner.run_async( + user_id="user", + session_id="missing", + new_message=types.Content(role="user", parts=[types.Part(text="hi")]), + ) + + event = await agen.__anext__() + await agen.aclose() + + # Verify that session_id="missing" doesn't error out - session is auto-created + assert event.author == "test_agent" + assert event.content.parts[0].text == "Test LLM response" + + +@pytest.mark.asyncio +async def test_rewind_auto_create_session_on_missing_session(): + """When auto_create_session=True, rewind should create session if missing. + + The newly created session won't contain the target invocation, so + `rewind_async` should raise an Invocation ID not found error (rather than + a session not found error), demonstrating auto-creation occurred. + """ + session_service = InMemorySessionService() + runner = Runner( + app_name="auto_create_app", + agent=MockLlmAgent("agent_for_rewind"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + auto_create_session=True, + ) + + with pytest.raises(ValueError, match=r"Invocation ID not found: inv_missing"): + await runner.rewind_async( + user_id="user", + session_id="missing", + rewind_before_invocation_id="inv_missing", + ) + + # Verify the session actually exists now due to auto-creation. + session = await session_service.get_session( + app_name="auto_create_app", user_id="user", session_id="missing" + ) + assert session is not None + assert session.app_name == "auto_create_app" + + +@pytest.mark.asyncio +async def test_run_live_auto_create_session(): + """run_live should auto-create session when missing and yield events.""" + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name="live_app", + agent=MockLiveAgent("live_agent"), + session_service=session_service, + artifact_service=artifact_service, + auto_create_session=True, + ) + + # An empty LiveRequestQueue is sufficient for our mock agent. 
+    from google.adk.agents.live_request_queue import LiveRequestQueue
+
+    live_queue = LiveRequestQueue()
+
+    agen = runner.run_live(
+        user_id="user",
+        session_id="missing",
+        live_request_queue=live_queue,
+    )
+
+    event = await agen.__anext__()
+    await agen.aclose()
+
+    assert event.author == "live_agent"
+    assert event.content.parts[0].text == "live hello"
+
+    # Session should have been created automatically.
+    session = await session_service.get_session(
+        app_name="live_app", user_id="user", session_id="missing"
+    )
+    assert session is not None
+
+
 @pytest.mark.asyncio
 async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
   project_root = tmp_path / "workspace"

From 8973618b0b0e90c513873e22af272c147efb4904 Mon Sep 17 00:00:00 2001
From: Xuan Yang
Date: Tue, 13 Jan 2026 23:57:12 -0800
Subject: [PATCH 2/3] chore: Add a DebugLoggingPlugin to record human-readable
 debugging logs

Co-authored-by: Xuan Yang
PiperOrigin-RevId: 856067925
---
 .../samples/plugin_debug_logging/__init__.py    |  15 +
 .../samples/plugin_debug_logging/agent.py       | 124 ++++
 src/google/adk/plugins/__init__.py              |   2 +
 .../adk/plugins/debug_logging_plugin.py         | 572 +++++++++++++++++
 .../plugins/test_debug_logging_plugin.py        | 605 ++++++++++++++++++
 5 files changed, 1318 insertions(+)
 create mode 100644 contributing/samples/plugin_debug_logging/__init__.py
 create mode 100644 contributing/samples/plugin_debug_logging/agent.py
 create mode 100644 src/google/adk/plugins/debug_logging_plugin.py
 create mode 100644 tests/unittests/plugins/test_debug_logging_plugin.py

diff --git a/contributing/samples/plugin_debug_logging/__init__.py b/contributing/samples/plugin_debug_logging/__init__.py
new file mode 100644
index 0000000000..c48963cdc7
--- /dev/null
+++ b/contributing/samples/plugin_debug_logging/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/contributing/samples/plugin_debug_logging/agent.py b/contributing/samples/plugin_debug_logging/agent.py
new file mode 100644
index 0000000000..18b345e378
--- /dev/null
+++ b/contributing/samples/plugin_debug_logging/agent.py
@@ -0,0 +1,124 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sample agent demonstrating DebugLoggingPlugin usage.
+
+This sample shows how to use the DebugLoggingPlugin to capture complete
+debug information (LLM requests/responses, tool calls, events, session state)
+to a YAML file for debugging purposes.
+
+Usage:
+  adk run contributing/samples/plugin_debug_logging
+
+After running, check the generated `adk_debug.yaml` file for detailed logs.
+"""
+
+from typing import Any
+
+from google.adk.agents import LlmAgent
+from google.adk.apps import App
+from google.adk.plugins import DebugLoggingPlugin
+
+
+def get_weather(city: str) -> dict[str, Any]:
+  """Get the current weather for a city.
+
+  Args:
+    city: The name of the city to get weather for.
+
+  Returns:
+    A dictionary containing weather information.
+  """
+  # Simulated weather data
+  weather_data = {
+      "new york": {"temperature": 22, "condition": "sunny", "humidity": 45},
+      "london": {"temperature": 15, "condition": "cloudy", "humidity": 70},
+      "tokyo": {"temperature": 28, "condition": "humid", "humidity": 85},
+      "paris": {"temperature": 18, "condition": "rainy", "humidity": 80},
+  }
+
+  city_lower = city.lower()
+  if city_lower in weather_data:
+    data = weather_data[city_lower]
+    return {
+        "city": city,
+        "temperature_celsius": data["temperature"],
+        "condition": data["condition"],
+        "humidity_percent": data["humidity"],
+    }
+  else:
+    return {
+        "city": city,
+        "error": f"Weather data not available for {city}",
+    }
+
+
+def calculate(expression: str) -> dict[str, Any]:
+  """Evaluate a simple mathematical expression.
+
+  Args:
+    expression: A mathematical expression to evaluate (e.g., "2 + 2").
+
+  Returns:
+    A dictionary containing the result or error.
+  """
+  try:
+    # Only allow basic arithmetic characters
+    allowed_chars = set("0123456789+-*/.() ")
+    if not all(c in allowed_chars for c in expression):
+      return {"error": "Invalid characters in expression"}
+
+    # The character whitelist above limits the risk for this demo; do not
+    # use eval() on untrusted input in production code.
+    result = eval(expression)
+    return {"expression": expression, "result": result}
+  except Exception as e:
+    return {"expression": expression, "error": str(e)}
+
+
+# Sample queries to try:
+# - "What's the weather in Tokyo?"
+# - "Calculate 15 * 7 + 3"
+# - "What's the weather in London and calculate 100 / 4"
+root_agent = LlmAgent(
+    name="debug_demo_agent",
+    description="A demo agent that shows DebugLoggingPlugin capabilities",
+    instruction="""You are a helpful assistant that can:
+1. Get weather information for cities (New York, London, Tokyo, Paris)
+2. Perform simple calculations
+
+When asked about weather, use the get_weather tool.
+When asked to calculate, use the calculate tool.
+Be concise in your responses.""",
+    model="gemini-2.0-flash",
+    tools=[get_weather, calculate],
+)
+
+
+# Create the app with DebugLoggingPlugin
+# The plugin will write detailed debug information to adk_debug.yaml
+app = App(
+    name="plugin_debug_logging",
+    root_agent=root_agent,
+    plugins=[
+        # DebugLoggingPlugin captures complete interaction data to a YAML file
+        # Options:
+        #   output_path: Path to output file (default: "adk_debug.yaml")
+        #   include_session_state: Include session state snapshot (default: True)
+        #   include_system_instruction: Include full system instruction (default: True)
+        DebugLoggingPlugin(
+            output_path="adk_debug.yaml",
+            include_session_state=True,
+            include_system_instruction=True,
+        ),
+    ],
+)
diff --git a/src/google/adk/plugins/__init__.py b/src/google/adk/plugins/__init__.py
index c824622091..a680a747f6 100644
--- a/src/google/adk/plugins/__init__.py
+++ b/src/google/adk/plugins/__init__.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 
 from .base_plugin import BasePlugin
+from .debug_logging_plugin import DebugLoggingPlugin
 from .logging_plugin import LoggingPlugin
 from .plugin_manager import PluginManager
 from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin
 
 __all__ = [
     'BasePlugin',
+    'DebugLoggingPlugin',
     'LoggingPlugin',
     'PluginManager',
     'ReflectAndRetryToolPlugin',
diff --git a/src/google/adk/plugins/debug_logging_plugin.py b/src/google/adk/plugins/debug_logging_plugin.py
new file mode 100644
index 0000000000..ef3507a079
--- /dev/null
+++ b/src/google/adk/plugins/debug_logging_plugin.py
@@ -0,0 +1,572 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Debug logging plugin for capturing complete interaction data to a file."""
+
+from __future__ import annotations
+
+from datetime import datetime
+import logging
+from pathlib import Path
+from typing import Any
+from typing import TYPE_CHECKING
+
+from google.genai import types
+from pydantic import BaseModel
+from pydantic import Field
+from typing_extensions import override
+import yaml
+
+from ..agents.base_agent import BaseAgent
+from ..agents.callback_context import CallbackContext
+from ..events.event import Event
+from ..models.llm_request import LlmRequest
+from ..models.llm_response import LlmResponse
+from ..tools.base_tool import BaseTool
+from .base_plugin import BasePlugin
+
+if TYPE_CHECKING:
+  from ..agents.invocation_context import InvocationContext
+  from ..tools.tool_context import ToolContext
+
+logger = logging.getLogger("google_adk." + __name__)
+
+
+class _DebugEntry(BaseModel):
+  """A single debug log entry."""
+
+  timestamp: str
+  entry_type: str
+  invocation_id: str | None = None
+  agent_name: str | None = None
+  data: dict[str, Any] = Field(default_factory=dict)
+
+
+class _InvocationDebugState(BaseModel):
+  """Per-invocation debug state."""
+
+  invocation_id: str
+  session_id: str
+  app_name: str
+  user_id: str | None = None
+  start_time: str
+  entries: list[_DebugEntry] = Field(default_factory=list)
+
+
+class DebugLoggingPlugin(BasePlugin):
+  """A plugin that captures complete debug information to a file.
+
+  This plugin records detailed interaction data including:
+
+  - LLM requests (model, system instruction, contents, tools)
+  - LLM responses (content, usage metadata, errors)
+  - Function calls with arguments
+  - Function responses with results
+  - Events yielded from the runner
+  - Session state at the end of each invocation
+
+  The output is written in YAML format for human readability. Each invocation
+  is appended to the file as a separate YAML document (separated by ---).
+  This format is easy to read and can be shared for debugging purposes.
+
+  Example:
+    >>> debug_plugin = DebugLoggingPlugin(output_path="/tmp/adk_debug.yaml")
+    >>> runner = Runner(
+    ...     agent=my_agent,
+    ...     plugins=[debug_plugin],
+    ... )
+
+  Attributes:
+    output_path: Path to the output file. Defaults to "adk_debug.yaml".
+    include_session_state: Whether to include session state in the output.
+    include_system_instruction: Whether to include system instructions.
+  """
+
+  def __init__(
+      self,
+      *,
+      name: str = "debug_logging_plugin",
+      output_path: str = "adk_debug.yaml",
+      include_session_state: bool = True,
+      include_system_instruction: bool = True,
+  ):
+    """Initialize the debug logging plugin.
+
+    Args:
+      name: The name of the plugin instance.
+      output_path: Path to the output file. Defaults to "adk_debug.yaml".
+      include_session_state: Whether to include session state snapshot.
+      include_system_instruction: Whether to include full system instructions.
+    """
+    super().__init__(name)
+    self._output_path = Path(output_path)
+    self._include_session_state = include_session_state
+    self._include_system_instruction = include_system_instruction
+    self._invocation_states: dict[str, _InvocationDebugState] = {}
+
+  def _get_timestamp(self) -> str:
+    """Get current timestamp in ISO format."""
+    return datetime.now().isoformat()
+
+  def _serialize_content(
+      self, content: types.Content | None
+  ) -> dict[str, Any] | None:
+    """Serialize Content to a dictionary."""
+    if content is None:
+      return None
+
+    parts = []
+    if content.parts:
+      for part in content.parts:
+        part_data: dict[str, Any] = {}
+        if part.text:
+          part_data["text"] = part.text
+        if part.function_call:
+          part_data["function_call"] = {
+              "id": part.function_call.id,
+              "name": part.function_call.name,
+              "args": part.function_call.args,
+          }
+        if part.function_response:
+          part_data["function_response"] = {
+              "id": part.function_response.id,
+              "name": part.function_response.name,
+              "response": self._safe_serialize(part.function_response.response),
+          }
+        if part.inline_data:
+          part_data["inline_data"] = {
+              "mime_type": part.inline_data.mime_type,
+              "display_name": getattr(part.inline_data, "display_name", None),
+              # Omit actual data to keep file size manageable
+              "_data_omitted": True,
+          }
+        if part.file_data:
+          part_data["file_data"] = {
+              "file_uri": part.file_data.file_uri,
+              "mime_type": part.file_data.mime_type,
+          }
+        if part.code_execution_result:
+          part_data["code_execution_result"] = {
+              "outcome": str(part.code_execution_result.outcome),
+              "output": part.code_execution_result.output,
+          }
+        if part.executable_code:
+          part_data["executable_code"] = {
+              "language": str(part.executable_code.language),
+              "code": part.executable_code.code,
+          }
+        if part_data:
+          parts.append(part_data)
+
+    return {"role": content.role, "parts": parts}
+
+  def _safe_serialize(self, obj: Any) -> Any:
+    """Safely serialize an object to JSON-compatible format."""
+    if obj is None:
+      return None
+    if isinstance(obj, (str, int, float, bool)):
+      return obj
+    if isinstance(obj, (list, tuple)):
+      return [self._safe_serialize(item) for item in obj]
+    if isinstance(obj, dict):
+      return {k: self._safe_serialize(v) for k, v in obj.items()}
+    if isinstance(obj, BaseModel):
+      try:
+        return obj.model_dump(mode="json", exclude_none=True)
+      except Exception:
+        return str(obj)
+    if isinstance(obj, bytes):
+      return f"<bytes: {len(obj)} bytes>"
+    try:
+      return str(obj)
+    except Exception:
+      return "<unserializable>"
+
+  def _add_entry(
+      self,
+      invocation_id: str,
+      entry_type: str,
+      agent_name: str | None = None,
+      **data: Any,
+  ) -> None:
+    """Add a debug entry to the current invocation state."""
+    if invocation_id not in self._invocation_states:
+      logger.warning(
+          "No debug state for invocation %s, skipping entry", invocation_id
+      )
+      return
+
+    entry = _DebugEntry(
+        timestamp=self._get_timestamp(),
+        entry_type=entry_type,
+        invocation_id=invocation_id,
+        agent_name=agent_name,
data=self._safe_serialize(data), + ) + self._invocation_states[invocation_id].entries.append(entry) + + @override + async def on_user_message_callback( + self, + *, + invocation_context: InvocationContext, + user_message: types.Content, + ) -> types.Content | None: + """Log user message and invocation start.""" + invocation_id = invocation_context.invocation_id + + self._add_entry( + invocation_id, + "user_message", + content=self._serialize_content(user_message), + ) + return None + + @override + async def before_run_callback( + self, *, invocation_context: InvocationContext + ) -> types.Content | None: + """Initialize debug state for this invocation.""" + invocation_id = invocation_context.invocation_id + session = invocation_context.session + + state = _InvocationDebugState( + invocation_id=invocation_id, + session_id=session.id, + app_name=session.app_name, + user_id=invocation_context.user_id, + start_time=self._get_timestamp(), + ) + self._invocation_states[invocation_id] = state + + self._add_entry( + invocation_id, + "invocation_start", + agent_name=getattr(invocation_context.agent, "name", None), + branch=invocation_context.branch, + ) + return None + + @override + async def on_event_callback( + self, *, invocation_context: InvocationContext, event: Event + ) -> Event | None: + """Log events yielded from the runner.""" + invocation_id = invocation_context.invocation_id + + event_data: dict[str, Any] = { + "event_id": event.id, + "author": event.author, + "content": self._serialize_content(event.content), + "is_final_response": event.is_final_response(), + "partial": event.partial, + "turn_complete": event.turn_complete, + "branch": event.branch, + } + + if event.actions: + actions_data: dict[str, Any] = {} + if event.actions.state_delta: + actions_data["state_delta"] = self._safe_serialize( + event.actions.state_delta + ) + if event.actions.artifact_delta: + # Preserve filename -> version mapping for debugging + actions_data["artifact_delta"] = dict(event.actions.artifact_delta) + if event.actions.transfer_to_agent: + actions_data["transfer_to_agent"] = event.actions.transfer_to_agent + if event.actions.escalate: + actions_data["escalate"] = event.actions.escalate + if event.actions.requested_auth_configs: + actions_data["requested_auth_configs"] = len( + event.actions.requested_auth_configs + ) + if actions_data: + event_data["actions"] = actions_data + + if event.grounding_metadata: + event_data["has_grounding_metadata"] = True + + if event.usage_metadata: + event_data["usage_metadata"] = { + "prompt_token_count": event.usage_metadata.prompt_token_count, + "candidates_token_count": event.usage_metadata.candidates_token_count, + "total_token_count": event.usage_metadata.total_token_count, + } + + if event.error_code: + event_data["error_code"] = event.error_code + event_data["error_message"] = event.error_message + + if event.long_running_tool_ids: + event_data["long_running_tool_ids"] = list(event.long_running_tool_ids) + + self._add_entry( + invocation_id, + "event", + agent_name=event.author, + **event_data, + ) + return None + + @override + async def after_run_callback( + self, *, invocation_context: InvocationContext + ) -> None: + """Finalize and write debug data to file.""" + invocation_id = invocation_context.invocation_id + + if invocation_id not in self._invocation_states: + logger.warning( + "No debug state for invocation %s, skipping write", invocation_id + ) + return + + state = self._invocation_states[invocation_id] + + # Add session state snapshot if enabled + if 
self._include_session_state: + session = invocation_context.session + self._add_entry( + invocation_id, + "session_state_snapshot", + state=self._safe_serialize(session.state), + event_count=len(session.events), + ) + + self._add_entry(invocation_id, "invocation_end") + + # Write to file as YAML + try: + output_data = state.model_dump(mode="json", exclude_none=True) + with self._output_path.open("a", encoding="utf-8") as f: + f.write("---\n") + yaml.dump( + output_data, + f, + default_flow_style=False, + allow_unicode=True, + sort_keys=False, + width=120, + ) + logger.debug( + "Wrote debug data for invocation %s to %s", + invocation_id, + self._output_path, + ) + except Exception as e: + logger.error("Failed to write debug data: %s", e) + finally: + # Cleanup invocation state + self._invocation_states.pop(invocation_id, None) + + @override + async def before_agent_callback( + self, *, agent: BaseAgent, callback_context: CallbackContext + ) -> types.Content | None: + """Log agent execution start.""" + self._add_entry( + callback_context.invocation_id, + "agent_start", + agent_name=callback_context.agent_name, + branch=callback_context._invocation_context.branch, + ) + return None + + @override + async def after_agent_callback( + self, *, agent: BaseAgent, callback_context: CallbackContext + ) -> types.Content | None: + """Log agent execution completion.""" + self._add_entry( + callback_context.invocation_id, + "agent_end", + agent_name=callback_context.agent_name, + ) + return None + + @override + async def before_model_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> LlmResponse | None: + """Log LLM request before sending to model.""" + request_data: dict[str, Any] = { + "model": llm_request.model, + "content_count": len(llm_request.contents), + "contents": [self._serialize_content(c) for c in llm_request.contents], + } + + if llm_request.tools_dict: + request_data["tools"] = list(llm_request.tools_dict.keys()) + + if llm_request.config: + config = llm_request.config + config_data: dict[str, Any] = {} + + if self._include_system_instruction and config.system_instruction: + config_data["system_instruction"] = config.system_instruction + elif config.system_instruction: + # Just indicate presence without full content + si = config.system_instruction + if isinstance(si, str): + config_data["system_instruction_length"] = len(si) + else: + config_data["has_system_instruction"] = True + + if config.temperature is not None: + config_data["temperature"] = config.temperature + if config.top_p is not None: + config_data["top_p"] = config.top_p + if config.top_k is not None: + config_data["top_k"] = config.top_k + if config.max_output_tokens is not None: + config_data["max_output_tokens"] = config.max_output_tokens + if config.response_mime_type: + config_data["response_mime_type"] = config.response_mime_type + if config.response_schema: + config_data["has_response_schema"] = True + + if config_data: + request_data["config"] = config_data + + self._add_entry( + callback_context.invocation_id, + "llm_request", + agent_name=callback_context.agent_name, + **request_data, + ) + return None + + @override + async def after_model_callback( + self, *, callback_context: CallbackContext, llm_response: LlmResponse + ) -> LlmResponse | None: + """Log LLM response after receiving from model.""" + response_data: dict[str, Any] = { + "content": self._serialize_content(llm_response.content), + "partial": llm_response.partial, + "turn_complete": llm_response.turn_complete, + } + + 
if llm_response.error_code: + response_data["error_code"] = llm_response.error_code + response_data["error_message"] = llm_response.error_message + + if llm_response.usage_metadata: + response_data["usage_metadata"] = { + "prompt_token_count": llm_response.usage_metadata.prompt_token_count, + "candidates_token_count": ( + llm_response.usage_metadata.candidates_token_count + ), + "total_token_count": llm_response.usage_metadata.total_token_count, + "cached_content_token_count": ( + llm_response.usage_metadata.cached_content_token_count + ), + } + + if llm_response.grounding_metadata: + response_data["has_grounding_metadata"] = True + + if llm_response.finish_reason: + response_data["finish_reason"] = str(llm_response.finish_reason) + + if llm_response.model_version: + response_data["model_version"] = llm_response.model_version + + self._add_entry( + callback_context.invocation_id, + "llm_response", + agent_name=callback_context.agent_name, + **response_data, + ) + return None + + @override + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> LlmResponse | None: + """Log LLM error.""" + self._add_entry( + callback_context.invocation_id, + "llm_error", + agent_name=callback_context.agent_name, + error_type=type(error).__name__, + error_message=str(error), + model=llm_request.model, + ) + return None + + @override + async def before_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + ) -> dict[str, Any] | None: + """Log tool execution start.""" + self._add_entry( + tool_context.invocation_id, + "tool_call", + agent_name=tool_context.agent_name, + tool_name=tool.name, + function_call_id=tool_context.function_call_id, + args=self._safe_serialize(tool_args), + ) + return None + + @override + async def after_tool_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + result: dict[str, Any], + ) -> dict[str, Any] | None: + """Log tool execution completion.""" + self._add_entry( + tool_context.invocation_id, + "tool_response", + agent_name=tool_context.agent_name, + tool_name=tool.name, + function_call_id=tool_context.function_call_id, + result=self._safe_serialize(result), + ) + return None + + @override + async def on_tool_error_callback( + self, + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> dict[str, Any] | None: + """Log tool error.""" + self._add_entry( + tool_context.invocation_id, + "tool_error", + agent_name=tool_context.agent_name, + tool_name=tool.name, + function_call_id=tool_context.function_call_id, + args=self._safe_serialize(tool_args), + error_type=type(error).__name__, + error_message=str(error), + ) + return None diff --git a/tests/unittests/plugins/test_debug_logging_plugin.py b/tests/unittests/plugins/test_debug_logging_plugin.py new file mode 100644 index 0000000000..a0bb64e948 --- /dev/null +++ b/tests/unittests/plugins/test_debug_logging_plugin.py @@ -0,0 +1,605 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import Mock + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.debug_logging_plugin import DebugLoggingPlugin +from google.adk.sessions.session import Session +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest +import yaml + + +@pytest.fixture +def debug_output_file(tmp_path): + """Fixture to provide a temporary file path for debug output.""" + return tmp_path / "debug_output.yaml" + + +@pytest.fixture +def mock_session(): + """Create a mock session.""" + session = Mock(spec=Session) + session.id = "test-session-id" + session.app_name = "test-app" + session.user_id = "test-user" + session.state = {"key1": "value1", "key2": 123} + session.events = [] + return session + + +@pytest.fixture +def mock_invocation_context(mock_session): + """Create a mock invocation context.""" + ctx = Mock(spec=InvocationContext) + ctx.invocation_id = "test-invocation-id" + ctx.session = mock_session + ctx.user_id = "test-user" + ctx.app_name = "test-app" + ctx.branch = None + ctx.agent = Mock() + ctx.agent.name = "test-agent" + return ctx + + +@pytest.fixture +def mock_callback_context(mock_invocation_context): + """Create a mock callback context.""" + ctx = Mock(spec=CallbackContext) + ctx.invocation_id = mock_invocation_context.invocation_id + ctx.agent_name = "test-agent" + ctx._invocation_context = mock_invocation_context + ctx.state = {} + return ctx + + +@pytest.fixture +def mock_tool_context(mock_invocation_context): + """Create a mock tool context.""" + ctx = Mock(spec=ToolContext) + ctx.invocation_id = mock_invocation_context.invocation_id + ctx.agent_name = "test-agent" + ctx.function_call_id = "test-function-call-id" + return ctx + + +class TestDebugLoggingPluginInitialization: + """Tests for DebugLoggingPlugin initialization.""" + + def test_default_initialization(self): + """Test plugin initialization with default values.""" + plugin = DebugLoggingPlugin() + assert plugin.name == "debug_logging_plugin" + assert plugin._output_path == Path("adk_debug.yaml") + assert plugin._include_session_state is True + assert plugin._include_system_instruction is True + + def test_custom_initialization(self, debug_output_file): + """Test plugin initialization with custom values.""" + plugin = DebugLoggingPlugin( + name="custom_debug", + output_path=str(debug_output_file), + include_session_state=False, + include_system_instruction=False, + ) + assert plugin.name == "custom_debug" + assert plugin._output_path == debug_output_file + assert plugin._include_session_state is False + assert plugin._include_system_instruction is False + + +class TestDebugLoggingPluginCallbacks: + """Tests for DebugLoggingPlugin callback methods.""" + + async def 
test_before_run_callback_initializes_state( + self, debug_output_file, mock_invocation_context + ): + """Test that before_run_callback initializes debug state.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + result = await plugin.before_run_callback( + invocation_context=mock_invocation_context + ) + + assert result is None + assert mock_invocation_context.invocation_id in plugin._invocation_states + state = plugin._invocation_states[mock_invocation_context.invocation_id] + assert state.invocation_id == mock_invocation_context.invocation_id + assert state.session_id == mock_invocation_context.session.id + assert len(state.entries) == 1 + assert state.entries[0].entry_type == "invocation_start" + + async def test_on_user_message_callback_logs_message( + self, debug_output_file, mock_invocation_context + ): + """Test that on_user_message_callback logs user messages.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # Initialize state first + await plugin.before_run_callback(invocation_context=mock_invocation_context) + + user_message = types.Content( + role="user", parts=[types.Part.from_text(text="Hello, world!")] + ) + + result = await plugin.on_user_message_callback( + invocation_context=mock_invocation_context, user_message=user_message + ) + + assert result is None + state = plugin._invocation_states[mock_invocation_context.invocation_id] + user_message_entries = [ + e for e in state.entries if e.entry_type == "user_message" + ] + assert len(user_message_entries) == 1 + assert user_message_entries[0].data["content"]["role"] == "user" + assert user_message_entries[0].data["content"]["parts"][0]["text"] == ( + "Hello, world!" + ) + + async def test_before_model_callback_logs_request( + self, debug_output_file, mock_invocation_context, mock_callback_context + ): + """Test that before_model_callback logs LLM requests.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # Initialize state first + await plugin.before_run_callback(invocation_context=mock_invocation_context) + + llm_request = LlmRequest( + model="gemini-2.0-flash", + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + ) + llm_request.config.system_instruction = "You are a helpful assistant." + + result = await plugin.before_model_callback( + callback_context=mock_callback_context, llm_request=llm_request + ) + + assert result is None + state = plugin._invocation_states[mock_invocation_context.invocation_id] + llm_entries = [e for e in state.entries if e.entry_type == "llm_request"] + assert len(llm_entries) == 1 + assert llm_entries[0].data["model"] == "gemini-2.0-flash" + assert llm_entries[0].data["content_count"] == 1 + assert "config" in llm_entries[0].data + assert ( + llm_entries[0].data["config"]["system_instruction"] + == "You are a helpful assistant." + ) + + async def test_after_model_callback_logs_response( + self, debug_output_file, mock_invocation_context, mock_callback_context + ): + """Test that after_model_callback logs LLM responses.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # Initialize state first + await plugin.before_run_callback(invocation_context=mock_invocation_context) + + llm_response = LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="Hello! 
How can I help?")], + ), + turn_complete=True, + ) + + result = await plugin.after_model_callback( + callback_context=mock_callback_context, llm_response=llm_response + ) + + assert result is None + state = plugin._invocation_states[mock_invocation_context.invocation_id] + llm_entries = [e for e in state.entries if e.entry_type == "llm_response"] + assert len(llm_entries) == 1 + assert llm_entries[0].data["turn_complete"] is True + assert llm_entries[0].data["content"]["role"] == "model" + + async def test_before_tool_callback_logs_tool_call( + self, debug_output_file, mock_invocation_context, mock_tool_context + ): + """Test that before_tool_callback logs tool calls.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # Initialize state first + await plugin.before_run_callback(invocation_context=mock_invocation_context) + + mock_tool = Mock(spec=BaseTool) + mock_tool.name = "test_tool" + tool_args = {"param1": "value1", "param2": 42} + + result = await plugin.before_tool_callback( + tool=mock_tool, tool_args=tool_args, tool_context=mock_tool_context + ) + + assert result is None + state = plugin._invocation_states[mock_invocation_context.invocation_id] + tool_entries = [e for e in state.entries if e.entry_type == "tool_call"] + assert len(tool_entries) == 1 + assert tool_entries[0].data["tool_name"] == "test_tool" + assert tool_entries[0].data["args"]["param1"] == "value1" + assert tool_entries[0].data["args"]["param2"] == 42 + + async def test_after_tool_callback_logs_tool_response( + self, debug_output_file, mock_invocation_context, mock_tool_context + ): + """Test that after_tool_callback logs tool responses.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # Initialize state first + await plugin.before_run_callback(invocation_context=mock_invocation_context) + + mock_tool = Mock(spec=BaseTool) + mock_tool.name = "test_tool" + tool_args = {"param1": "value1"} + result_data = {"output": "success", "data": [1, 2, 3]} + + result = await plugin.after_tool_callback( + tool=mock_tool, + tool_args=tool_args, + tool_context=mock_tool_context, + result=result_data, + ) + + assert result is None + state = plugin._invocation_states[mock_invocation_context.invocation_id] + tool_entries = [e for e in state.entries if e.entry_type == "tool_response"] + assert len(tool_entries) == 1 + assert tool_entries[0].data["tool_name"] == "test_tool" + assert tool_entries[0].data["result"]["output"] == "success" + + async def test_on_event_callback_logs_event( + self, debug_output_file, mock_invocation_context + ): + """Test that on_event_callback logs events.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # Initialize state first + await plugin.before_run_callback(invocation_context=mock_invocation_context) + + event = Event( + author="test-agent", + content=types.Content( + role="model", + parts=[types.Part.from_text(text="Response text")], + ), + ) + + result = await plugin.on_event_callback( + invocation_context=mock_invocation_context, event=event + ) + + assert result is None + state = plugin._invocation_states[mock_invocation_context.invocation_id] + event_entries = [e for e in state.entries if e.entry_type == "event"] + assert len(event_entries) == 1 + assert event_entries[0].data["author"] == "test-agent" + assert event_entries[0].data["event_id"] == event.id + + async def test_on_model_error_callback_logs_error( + self, debug_output_file, mock_invocation_context, mock_callback_context + ): + """Test that on_model_error_callback 
logs LLM errors.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # Initialize state first + await plugin.before_run_callback(invocation_context=mock_invocation_context) + + llm_request = LlmRequest(model="gemini-2.0-flash") + error = ValueError("Test error message") + + result = await plugin.on_model_error_callback( + callback_context=mock_callback_context, + llm_request=llm_request, + error=error, + ) + + assert result is None + state = plugin._invocation_states[mock_invocation_context.invocation_id] + error_entries = [e for e in state.entries if e.entry_type == "llm_error"] + assert len(error_entries) == 1 + assert error_entries[0].data["error_type"] == "ValueError" + assert error_entries[0].data["error_message"] == "Test error message" + + async def test_on_tool_error_callback_logs_error( + self, debug_output_file, mock_invocation_context, mock_tool_context + ): + """Test that on_tool_error_callback logs tool errors.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # Initialize state first + await plugin.before_run_callback(invocation_context=mock_invocation_context) + + mock_tool = Mock(spec=BaseTool) + mock_tool.name = "test_tool" + tool_args = {"param1": "value1"} + error = RuntimeError("Tool execution failed") + + result = await plugin.on_tool_error_callback( + tool=mock_tool, + tool_args=tool_args, + tool_context=mock_tool_context, + error=error, + ) + + assert result is None + state = plugin._invocation_states[mock_invocation_context.invocation_id] + error_entries = [e for e in state.entries if e.entry_type == "tool_error"] + assert len(error_entries) == 1 + assert error_entries[0].data["tool_name"] == "test_tool" + assert error_entries[0].data["error_type"] == "RuntimeError" + + +class TestDebugLoggingPluginFileOutput: + """Tests for DebugLoggingPlugin file output.""" + + async def test_after_run_callback_writes_to_file( + self, debug_output_file, mock_invocation_context + ): + """Test that after_run_callback writes debug data to file.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # Initialize state + await plugin.before_run_callback(invocation_context=mock_invocation_context) + + # Add some entries + user_message = types.Content( + role="user", parts=[types.Part.from_text(text="Test message")] + ) + await plugin.on_user_message_callback( + invocation_context=mock_invocation_context, user_message=user_message + ) + + # Finalize + await plugin.after_run_callback(invocation_context=mock_invocation_context) + + # Verify file was written + assert debug_output_file.exists() + + # Parse and verify content (YAML format with --- separator) + with open(debug_output_file, "r") as f: + documents = list(yaml.safe_load_all(f)) + + assert len(documents) == 1 + data = documents[0] + assert data["invocation_id"] == "test-invocation-id" + assert data["session_id"] == "test-session-id" + assert ( + len(data["entries"]) >= 2 + ) # At least invocation_start and user_message + + async def test_after_run_callback_includes_session_state( + self, debug_output_file, mock_invocation_context + ): + """Test that session state is included when enabled.""" + plugin = DebugLoggingPlugin( + output_path=str(debug_output_file), include_session_state=True + ) + + await plugin.before_run_callback(invocation_context=mock_invocation_context) + await plugin.after_run_callback(invocation_context=mock_invocation_context) + + with open(debug_output_file, "r") as f: + documents = list(yaml.safe_load_all(f)) + + data = documents[0] + 
session_state_entries = [ + e + for e in data["entries"] + if e["entry_type"] == "session_state_snapshot" + ] + assert len(session_state_entries) == 1 + assert session_state_entries[0]["data"]["state"]["key1"] == "value1" + + async def test_after_run_callback_excludes_session_state_when_disabled( + self, debug_output_file, mock_invocation_context + ): + """Test that session state is excluded when disabled.""" + plugin = DebugLoggingPlugin( + output_path=str(debug_output_file), include_session_state=False + ) + + await plugin.before_run_callback(invocation_context=mock_invocation_context) + await plugin.after_run_callback(invocation_context=mock_invocation_context) + + with open(debug_output_file, "r") as f: + documents = list(yaml.safe_load_all(f)) + + data = documents[0] + session_state_entries = [ + e + for e in data["entries"] + if e["entry_type"] == "session_state_snapshot" + ] + assert not session_state_entries + + async def test_multiple_invocations_append_to_file( + self, debug_output_file, mock_session + ): + """Test that multiple invocations append to the same file.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + # First invocation + ctx1 = Mock(spec=InvocationContext) + ctx1.invocation_id = "invocation-1" + ctx1.session = mock_session + ctx1.user_id = "test-user" + ctx1.branch = None + ctx1.agent = Mock() + ctx1.agent.name = "agent-1" + + await plugin.before_run_callback(invocation_context=ctx1) + await plugin.after_run_callback(invocation_context=ctx1) + + # Second invocation + ctx2 = Mock(spec=InvocationContext) + ctx2.invocation_id = "invocation-2" + ctx2.session = mock_session + ctx2.user_id = "test-user" + ctx2.branch = None + ctx2.agent = Mock() + ctx2.agent.name = "agent-2" + + await plugin.before_run_callback(invocation_context=ctx2) + await plugin.after_run_callback(invocation_context=ctx2) + + # Verify both invocations are in the file (as separate YAML documents) + with open(debug_output_file, "r") as f: + documents = list(yaml.safe_load_all(f)) + + assert len(documents) == 2 + assert documents[0]["invocation_id"] == "invocation-1" + assert documents[1]["invocation_id"] == "invocation-2" + + async def test_after_run_callback_cleans_up_state( + self, debug_output_file, mock_invocation_context + ): + """Test that invocation state is cleaned up after writing.""" + plugin = DebugLoggingPlugin(output_path=str(debug_output_file)) + + await plugin.before_run_callback(invocation_context=mock_invocation_context) + assert mock_invocation_context.invocation_id in plugin._invocation_states + + await plugin.after_run_callback(invocation_context=mock_invocation_context) + assert ( + mock_invocation_context.invocation_id not in plugin._invocation_states + ) + + +class TestDebugLoggingPluginSerialization: + """Tests for content serialization.""" + + def test_serialize_content_with_text(self): + """Test serialization of text content.""" + plugin = DebugLoggingPlugin() + content = types.Content( + role="user", parts=[types.Part.from_text(text="Hello")] + ) + + result = plugin._serialize_content(content) + + assert result["role"] == "user" + assert len(result["parts"]) == 1 + assert result["parts"][0]["text"] == "Hello" + + def test_serialize_content_with_function_call(self): + """Test serialization of function call content.""" + plugin = DebugLoggingPlugin() + content = types.Content( + role="model", + parts=[ + types.Part( + function_call=types.FunctionCall( + id="fc-1", name="test_func", args={"arg1": "val1"} + ) + ) + ], + ) + + result = 
plugin._serialize_content(content)
+
+    assert result["parts"][0]["function_call"]["name"] == "test_func"
+    assert result["parts"][0]["function_call"]["args"]["arg1"] == "val1"
+
+  def test_serialize_content_with_none(self):
+    """Test serialization of None content."""
+    plugin = DebugLoggingPlugin()
+    result = plugin._serialize_content(None)
+    assert result is None
+
+  def test_safe_serialize_handles_bytes(self):
+    """Test that bytes are safely serialized."""
+    plugin = DebugLoggingPlugin()
+    result = plugin._safe_serialize(b"binary data")
+    assert result == "<bytes: 11 bytes>"
+
+  def test_safe_serialize_handles_nested_structures(self):
+    """Test that nested structures are serialized."""
+    plugin = DebugLoggingPlugin()
+    data = {
+        "list": [1, 2, {"nested": "value"}],
+        "tuple": (3, 4),
+        "string": "text",
+    }
+
+    result = plugin._safe_serialize(data)
+
+    assert result["list"] == [1, 2, {"nested": "value"}]
+    assert result["tuple"] == [3, 4]  # Tuple becomes list
+    assert result["string"] == "text"
+
+
+class TestDebugLoggingPluginSystemInstructionConfig:
+  """Tests for system instruction configuration."""
+
+  async def test_system_instruction_included_when_enabled(
+      self, debug_output_file, mock_invocation_context, mock_callback_context
+  ):
+    """Test that full system instruction is included when enabled."""
+    plugin = DebugLoggingPlugin(
+        output_path=str(debug_output_file), include_system_instruction=True
+    )
+
+    await plugin.before_run_callback(invocation_context=mock_invocation_context)
+
+    llm_request = LlmRequest(model="gemini-2.0-flash")
+    llm_request.config.system_instruction = "Full system instruction text"
+
+    await plugin.before_model_callback(
+        callback_context=mock_callback_context, llm_request=llm_request
+    )
+
+    state = plugin._invocation_states[mock_invocation_context.invocation_id]
+    llm_entries = [e for e in state.entries if e.entry_type == "llm_request"]
+    assert (
+        llm_entries[0].data["config"]["system_instruction"]
+        == "Full system instruction text"
+    )
+
+  async def test_system_instruction_length_only_when_disabled(
+      self, debug_output_file, mock_invocation_context, mock_callback_context
+  ):
+    """Test that only the length is included when system instruction is disabled."""
+    plugin = DebugLoggingPlugin(
+        output_path=str(debug_output_file), include_system_instruction=False
+    )
+
+    await plugin.before_run_callback(invocation_context=mock_invocation_context)
+
+    llm_request = LlmRequest(model="gemini-2.0-flash")
+    llm_request.config.system_instruction = "Full system instruction text"
+
+    await plugin.before_model_callback(
+        callback_context=mock_callback_context, llm_request=llm_request
+    )
+
+    state = plugin._invocation_states[mock_invocation_context.invocation_id]
+    llm_entries = [e for e in state.entries if e.entry_type == "llm_request"]
+    assert "system_instruction" not in llm_entries[0].data.get("config", {})
+    assert llm_entries[0].data["config"]["system_instruction_length"] == 28

From 79fcddb39f71a4c1342e63b4d67832b3eccb2652 Mon Sep 17 00:00:00 2001
From: Xuan Yang
Date: Tue, 13 Jan 2026 23:57:26 -0800
Subject: [PATCH 3/3] feat: Add `--enable_features` CLI option to ADK CLI

This flag can be used to override the default enabled state of features.
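
For example, to enable two experimental features for a local run (the agents
directory path here is hypothetical):

  adk web --enable_features=JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING path/to/agents

The flag may also be repeated instead of comma-separated:

  adk web --enable_features=JSON_SCHEMA_FOR_FUNC_DECL --enable_features=PROGRESSIVE_SSE_STREAMING path/to/agents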
Co-authored-by: Xuan Yang PiperOrigin-RevId: 856067979 --- src/google/adk/cli/cli_tools_click.py | 56 +++++ .../unittests/cli/test_cli_feature_options.py | 197 ++++++++++++++++++ .../test_cli_tools_click_option_mismatch.py | 15 +- 3 files changed, 264 insertions(+), 4 deletions(-) create mode 100644 tests/unittests/cli/test_cli_feature_options.py diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 4aa39dce9c..5d7611f217 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -36,6 +36,8 @@ from . import cli_deploy from .. import version from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from ..features import FeatureName +from ..features import override_feature_enabled from .cli import run_cli from .fast_api import get_fast_api_app from .utils import envs @@ -48,6 +50,56 @@ ) +def _apply_feature_overrides(enable_features: tuple[str, ...]) -> None: + """Apply feature overrides from CLI flags. + + Args: + enable_features: Tuple of feature names to enable. + """ + for features_str in enable_features: + for feature_name_str in features_str.split(","): + feature_name_str = feature_name_str.strip() + if not feature_name_str: + continue + try: + feature_name = FeatureName(feature_name_str) + override_feature_enabled(feature_name, True) + except ValueError: + valid_names = ", ".join(f.value for f in FeatureName) + click.secho( + f"WARNING: Unknown feature name '{feature_name_str}'. " + f"Valid names are: {valid_names}", + fg="yellow", + err=True, + ) + + +def feature_options(): + """Decorator to add feature override options to click commands.""" + + def decorator(func): + @click.option( + "--enable_features", + help=( + "Optional. Comma-separated list of feature names to enable. " + "This provides an alternative to environment variables for " + "enabling experimental features. Example: " + "--enable_features=JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING" + ), + multiple=True, + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + enable_features = kwargs.pop("enable_features", ()) + if enable_features: + _apply_feature_overrides(enable_features) + return func(*args, **kwargs) + + return wrapper + + return decorator + + class HelpfulCommand(click.Command): """Command that shows full help on error instead of just the error message. @@ -451,6 +503,7 @@ def wrapper(*args, **kwargs): @main.command("run", cls=HelpfulCommand) +@feature_options() @adk_services_options(default_use_local_storage=True) @click.option( "--save_session", @@ -576,6 +629,7 @@ def wrapper(*args, **kwargs): @main.command("eval", cls=HelpfulCommand) +@feature_options() @click.argument( "agent_module_file_path", type=click.Path( @@ -1141,6 +1195,7 @@ def wrapper(ctx, *args, **kwargs): @main.command("web") +@feature_options() @fast_api_common_options() @web_options() @adk_services_options(default_use_local_storage=True) @@ -1243,6 +1298,7 @@ async def _lifespan(app: FastAPI): @main.command("api_server") +@feature_options() # The directory of agents, where each sub-directory is a single agent. 
 # By default, it is the current working directory
 @click.argument(
diff --git a/tests/unittests/cli/test_cli_feature_options.py b/tests/unittests/cli/test_cli_feature_options.py
new file mode 100644
index 0000000000..70bfec2dda
--- /dev/null
+++ b/tests/unittests/cli/test_cli_feature_options.py
@@ -0,0 +1,197 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for --enable_features CLI option."""
+
+from __future__ import annotations
+
+import click
+from click.testing import CliRunner
+from google.adk.cli.cli_tools_click import _apply_feature_overrides
+from google.adk.cli.cli_tools_click import feature_options
+from google.adk.features._feature_registry import _FEATURE_OVERRIDES
+from google.adk.features._feature_registry import _WARNED_FEATURES
+from google.adk.features._feature_registry import FeatureName
+from google.adk.features._feature_registry import is_feature_enabled
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def reset_feature_overrides():
+  """Reset feature overrides and warnings before/after each test."""
+  _FEATURE_OVERRIDES.clear()
+  _WARNED_FEATURES.clear()
+  yield
+  _FEATURE_OVERRIDES.clear()
+  _WARNED_FEATURES.clear()
+
+
+class TestApplyFeatureOverrides:
+  """Tests for _apply_feature_overrides helper function."""
+
+  def test_single_feature(self):
+    """Single feature name is applied correctly."""
+    _apply_feature_overrides(("JSON_SCHEMA_FOR_FUNC_DECL",))
+    assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL)
+
+  def test_comma_separated_features(self):
+    """Comma-separated feature names are applied correctly."""
+    _apply_feature_overrides((
+        "JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING",
+    ))
+    assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL)
+    assert is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING)
+
+  def test_multiple_flag_values(self):
+    """Multiple --enable_features flags are applied correctly."""
+    _apply_feature_overrides((
+        "JSON_SCHEMA_FOR_FUNC_DECL",
+        "PROGRESSIVE_SSE_STREAMING",
+    ))
+    assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL)
+    assert is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING)
+
+  def test_whitespace_handling(self):
+    """Whitespace around feature names is stripped."""
+    _apply_feature_overrides((" JSON_SCHEMA_FOR_FUNC_DECL , COMPUTER_USE ",))
+    assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL)
+    assert is_feature_enabled(FeatureName.COMPUTER_USE)
+
+  def test_empty_string_ignored(self):
+    """Empty strings in the list are ignored."""
+    _apply_feature_overrides(("",))
+    # No error should be raised.
+
+  def test_unknown_feature_warns(self, capsys):
+    """Unknown feature names emit a warning."""
+    _apply_feature_overrides(("UNKNOWN_FEATURE_XYZ",))
+    captured = capsys.readouterr()
+    assert "WARNING" in captured.err
+    assert "UNKNOWN_FEATURE_XYZ" in captured.err
+    assert "Valid names are:" in captured.err
+
+
+class TestFeatureOptionsDecorator:
+  """Tests for feature_options decorator."""
+
+  def test_decorator_adds_enable_features_option(self):
+    """Decorator adds --enable_features option to command."""
+
+    @click.command()
+    @feature_options()
+    def test_cmd():
+      pass
+
+    runner = CliRunner()
+    result = runner.invoke(test_cmd, ["--help"])
+    assert "--enable_features" in result.output
+
+  def test_enable_features_applied_before_command(self):
+    """Features are enabled before the command function runs."""
+    feature_was_enabled = []
+
+    @click.command()
+    @feature_options()
+    def test_cmd():
+      feature_was_enabled.append(
+          is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL)
+      )
+
+    runner = CliRunner()
+    runner.invoke(
+        test_cmd,
+        ["--enable_features=JSON_SCHEMA_FOR_FUNC_DECL"],
+        catch_exceptions=False,
+    )
+    assert feature_was_enabled == [True]
+
+  def test_multiple_enable_features_flags(self):
+    """Multiple --enable_features flags work correctly."""
+    enabled_features = []
+
+    @click.command()
+    @feature_options()
+    def test_cmd():
+      enabled_features.append(
+          is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL)
+      )
+      enabled_features.append(
+          is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING)
+      )
+
+    runner = CliRunner()
+    runner.invoke(
+        test_cmd,
+        [
+            "--enable_features=JSON_SCHEMA_FOR_FUNC_DECL",
+            "--enable_features=PROGRESSIVE_SSE_STREAMING",
+        ],
+        catch_exceptions=False,
+    )
+    assert enabled_features == [True, True]
+
+  def test_comma_separated_enable_features(self):
+    """Comma-separated feature names work correctly."""
+    enabled_features = []
+
+    @click.command()
+    @feature_options()
+    def test_cmd():
+      enabled_features.append(
+          is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL)
+      )
+      enabled_features.append(
+          is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING)
+      )
+
+    runner = CliRunner()
+    runner.invoke(
+        test_cmd,
+        [
+            "--enable_features=JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING"
+        ],
+        catch_exceptions=False,
+    )
+    assert enabled_features == [True, True]
+
+  def test_no_enable_features_flag(self):
+    """Command works without --enable_features flag."""
+    enabled_features = []
+
+    @click.command()
+    @feature_options()
+    def test_cmd():
+      enabled_features.append(
+          is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL)
+      )
+
+    runner = CliRunner()
+    result = runner.invoke(test_cmd, [], catch_exceptions=False)
+    assert result.exit_code == 0
+    assert enabled_features == [False]
+
+  def test_preserves_function_metadata(self):
+    """Decorator preserves the wrapped function's metadata."""
+
+    @click.command()
+    @feature_options()
+    def my_test_command():
+      """My docstring."""
+      pass
+
+    # The callback should have preserved metadata.
+    assert (
+        "my_test_command" in my_test_command.name
+        or my_test_command.callback.__name__ == "my_test_command"
+    )
diff --git a/tests/unittests/cli/test_cli_tools_click_option_mismatch.py b/tests/unittests/cli/test_cli_tools_click_option_mismatch.py
index 346fd421d0..3c67e9ae39 100644
--- a/tests/unittests/cli/test_cli_tools_click_option_mismatch.py
+++ b/tests/unittests/cli/test_cli_tools_click_option_mismatch.py
@@ -94,7 +94,9 @@ def test_adk_run():
   run_command = _get_command_by_name(main.commands, "run")
   assert run_command is not None, "Run command not found"
 
-  _check_options_in_parameters(run_command, cli_run.callback, "run")
+  _check_options_in_parameters(
+      run_command, cli_run.callback, "run", ignore_params={"enable_features"}
+  )
 
 
 def test_adk_eval():
@@ -102,7 +104,9 @@
   eval_command = _get_command_by_name(main.commands, "eval")
   assert eval_command is not None, "Eval command not found"
"Eval command not found" - _check_options_in_parameters(eval_command, cli_eval.callback, "eval") + _check_options_in_parameters( + eval_command, cli_eval.callback, "eval", ignore_params={"enable_features"} + ) def test_adk_web(): @@ -111,7 +115,10 @@ def test_adk_web(): assert web_command is not None, "Web command not found" _check_options_in_parameters( - web_command, cli_web.callback, "web", ignore_params={"verbose"} + web_command, + cli_web.callback, + "web", + ignore_params={"verbose", "enable_features"}, ) @@ -124,7 +131,7 @@ def test_adk_api_server(): api_server_command, cli_api_server.callback, "api_server", - ignore_params={"verbose"}, + ignore_params={"verbose", "enable_features"}, )