From e02b9fb608f645cdcf8300ffe4e64627df10ba28 Mon Sep 17 00:00:00 2001 From: George Weale Date: Wed, 3 Dec 2025 09:46:42 -0800 Subject: [PATCH 1/4] fix: Add a warning when deploying with the ADK Web UI enabled The warning message shows that ADK Web is for development purposes only and should not be used in production, as it has access to all data. This warning is displayed when the `--with-ui` flag is used with `adk deploy` and `adk deploy to-gke` Co-authored-by: George Weale PiperOrigin-RevId: 839795361 --- src/google/adk/cli/cli_tools_click.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 6c3e7b98a9..c4446278b4 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -110,6 +110,18 @@ def parse_args(self, ctx, args): logger = logging.getLogger("google_adk." + __name__) +_ADK_WEB_WARNING = ( + "ADK Web is for development purposes. It has access to all data and" + " should not be used in production." +) + + +def _warn_if_with_ui(with_ui: bool) -> None: + """Warn when deploying with the developer UI enabled.""" + if with_ui: + click.secho(f"WARNING: {_ADK_WEB_WARNING}", fg="yellow", err=True) + + @click.group(context_settings={"max_content_width": 240}) @click.version_option(version.__version__) def main(): @@ -1429,6 +1441,8 @@ def cli_deploy_cloud_run( err=True, ) + _warn_if_with_ui(with_ui) + session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri @@ -1848,6 +1862,7 @@ def cli_deploy_gke( --cluster_name=[cluster_name] path/to/my_agent """ try: + _warn_if_with_ui(with_ui) cli_deploy.to_gke( agent_folder=agent, project=project, From 8c9105bf14f57606a73753654922fe26f584dff6 Mon Sep 17 00:00:00 2001 From: George Weale Date: Wed, 3 Dec 2025 09:55:52 -0800 Subject: [PATCH 2/4] chore: Drop Python 3.9 support, set minimum to Python 3.10 Co-authored-by: George Weale PiperOrigin-RevId: 839799108 --- .github/workflows/python-unit-tests.yml | 16 ++--- .../sample-output/alembic.ini | 2 +- contributing/samples/telemetry/main.py | 37 +++++----- llms-full.txt | 6 +- .../adk/a2a/converters/part_converter.py | 16 +---- .../adk/a2a/converters/request_converter.py | 15 +--- .../adk/a2a/executor/a2a_agent_executor.py | 36 ++++------ .../adk/a2a/utils/agent_card_builder.py | 20 ++---- src/google/adk/a2a/utils/agent_to_a2a.py | 23 ++---- src/google/adk/agents/__init__.py | 18 +---- .../adk/agents/mcp_instruction_provider.py | 18 +---- src/google/adk/agents/parallel_agent.py | 70 ++++++++---------- src/google/adk/agents/remote_a2a_agent.py | 43 +++++------ src/google/adk/cli/fast_api.py | 27 +++---- src/google/adk/tools/crewai_tool.py | 13 +--- src/google/adk/tools/mcp_tool/__init__.py | 12 +--- .../adk/tools/mcp_tool/mcp_session_manager.py | 21 ++---- src/google/adk/tools/mcp_tool/mcp_tool.py | 27 ++----- src/google/adk/tools/mcp_tool/mcp_toolset.py | 19 +---- src/google/adk/utils/context_utils.py | 29 +------- .../a2a/converters/test_event_converter.py | 61 ++++++---------- .../a2a/converters/test_part_converter.py | 37 +++------- .../a2a/converters/test_request_converter.py | 27 ++----- tests/unittests/a2a/converters/test_utils.py | 28 ++------ .../a2a/executor/test_a2a_agent_executor.py | 41 ++++------- .../executor/test_task_result_aggregator.py | 33 +++------ .../a2a/utils/test_agent_card_builder.py | 71 +++++++----------- .../unittests/a2a/utils/test_agent_to_a2a.py | 45 ++++-------- .../agents/test_mcp_instruction_provider.py | 22 +----- .../unittests/agents/test_remote_a2a_agent.py | 72 ++++++------------- tests/unittests/cli/test_fast_api.py | 11 --- .../evaluation/test_local_eval_service.py | 3 - .../plugins/test_reflect_retry_tool_plugin.py | 6 +- tests/unittests/telemetry/test_functional.py | 2 +- .../computer_use/test_computer_use_tool.py | 2 +- .../mcp_tool/test_mcp_session_manager.py | 44 ++---------- .../unittests/tools/mcp_tool/test_mcp_tool.py | 39 ++-------- .../tools/mcp_tool/test_mcp_toolset.py | 46 +++--------- .../tools/retrieval/test_files_retrieval.py | 4 -- tests/unittests/tools/test_mcp_toolset.py | 21 +----- 40 files changed, 279 insertions(+), 804 deletions(-) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 8f8f46e953..3fc6bd943f 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -25,7 +25,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - name: Checkout code @@ -48,14 +48,6 @@ jobs: - name: Run unit tests with pytest run: | source .venv/bin/activate - if [[ "${{ matrix.python-version }}" == "3.9" ]]; then - pytest tests/unittests \ - --ignore=tests/unittests/a2a \ - --ignore=tests/unittests/tools/mcp_tool \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py - else - pytest tests/unittests \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py - fi \ No newline at end of file + pytest tests/unittests \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py \ No newline at end of file diff --git a/contributing/samples/migrate_session_db/sample-output/alembic.ini b/contributing/samples/migrate_session_db/sample-output/alembic.ini index 6405320948..e346ee8ac6 100644 --- a/contributing/samples/migrate_session_db/sample-output/alembic.ini +++ b/contributing/samples/migrate_session_db/sample-output/alembic.ini @@ -21,7 +21,7 @@ prepend_sys_path = . # timezone to use when rendering the date within the migration file # as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# If specified, requires the python>=3.10 and tzdata library. # Any required deps can installed by adding `alembic[tz]` to the pip requirements # string value is passed to ZoneInfo() # leave blank for localtime diff --git a/contributing/samples/telemetry/main.py b/contributing/samples/telemetry/main.py index e580060dc4..c6e05f0f62 100755 --- a/contributing/samples/telemetry/main.py +++ b/contributing/samples/telemetry/main.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from contextlib import aclosing import os import time @@ -46,19 +47,16 @@ async def run_prompt(session: Session, new_message: str): role='user', parts=[types.Part.from_text(text=new_message)] ) print('** User says:', content.model_dump(exclude_none=True)) - # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is - # no longer supported. - agen = runner.run_async( - user_id=user_id_1, - session_id=session.id, - new_message=content, - ) - try: + async with aclosing( + runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ) + ) as agen: async for event in agen: if event.content.parts and event.content.parts[0].text: print(f'** {event.author}: {event.content.parts[0].text}') - finally: - await agen.aclose() async def run_prompt_bytes(session: Session, new_message: str): content = types.Content( @@ -70,20 +68,17 @@ async def run_prompt_bytes(session: Session, new_message: str): ], ) print('** User says:', content.model_dump(exclude_none=True)) - # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is - # no longer supported. - agen = runner.run_async( - user_id=user_id_1, - session_id=session.id, - new_message=content, - run_config=RunConfig(save_input_blobs_as_artifacts=True), - ) - try: + async with aclosing( + runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + run_config=RunConfig(save_input_blobs_as_artifacts=True), + ) + ) as agen: async for event in agen: if event.content.parts and event.content.parts[0].text: print(f'** {event.author}: {event.content.parts[0].text}') - finally: - await agen.aclose() start_time = time.time() print('Start time:', start_time) diff --git a/llms-full.txt b/llms-full.txt index 4c744512e4..b84e9496ee 100644 --- a/llms-full.txt +++ b/llms-full.txt @@ -5620,7 +5620,7 @@ pip install google-cloud-aiplatform[adk,agent_engines] ``` !!!info - Agent Engine only supported Python version >=3.9 and <=3.12. + Agent Engine only supported Python version >=3.10 and <=3.12. ### Initialization @@ -8073,7 +8073,7 @@ setting up a basic agent with multiple tools, and running it locally either in t This quickstart assumes a local IDE (VS Code, PyCharm, IntelliJ IDEA, etc.) -with Python 3.9+ or Java 17+ and terminal access. This method runs the +with Python 3.10+ or Java 17+ and terminal access. This method runs the application entirely on your machine and is recommended for internal development. ## 1. Set up Environment & Install ADK {#venv-install} @@ -16475,7 +16475,7 @@ This guide covers two primary integration patterns: Before you begin, ensure you have the following set up: * **Set up ADK:** Follow the standard ADK [setup instructions](../get-started/quickstart.md/#venv-install) in the quickstart. -* **Install/update Python/Java:** MCP requires Python version of 3.9 or higher for Python or Java 17+. +* **Install/update Python/Java:** MCP requires Python version of 3.10 or higher for Python or Java 17+. * **Setup Node.js and npx:** **(Python only)** Many community MCP servers are distributed as Node.js packages and run using `npx`. Install Node.js (which includes npx) if you haven't already. For details, see [https://nodejs.org/en](https://nodejs.org/en). * **Verify Installations:** **(Python only)** Confirm `adk` and `npx` are in your PATH within the activated virtual environment: diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index a21042cc10..dfe6f4a0a2 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -26,23 +26,11 @@ from typing import Optional from typing import Union -from .utils import _get_adk_metadata_key - -try: - from a2a import types as a2a_types -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' - ) from e - else: - raise e - +from a2a import types as a2a_types from google.genai import types as genai_types from ..experimental import a2a_experimental +from .utils import _get_adk_metadata_key logger = logging.getLogger('google_adk.' + __name__) diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py index 39db41dac6..1746ec0bca 100644 --- a/src/google/adk/a2a/converters/request_converter.py +++ b/src/google/adk/a2a/converters/request_converter.py @@ -15,23 +15,12 @@ from __future__ import annotations from collections.abc import Callable -import sys from typing import Any from typing import Optional -from pydantic import BaseModel - -try: - from a2a.server.agent_execution import RequestContext -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' - ) from e - else: - raise e - +from a2a.server.agent_execution import RequestContext from google.genai import types as genai_types +from pydantic import BaseModel from ...runners import RunConfig from ..experimental import a2a_experimental diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index 608a818864..b6880aaa5c 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -23,34 +23,22 @@ from typing import Optional import uuid -from ...utils.context_utils import Aclosing - -try: - from a2a.server.agent_execution import AgentExecutor - from a2a.server.agent_execution.context import RequestContext - from a2a.server.events.event_queue import EventQueue - from a2a.types import Artifact - from a2a.types import Message - from a2a.types import Role - from a2a.types import TaskArtifactUpdateEvent - from a2a.types import TaskState - from a2a.types import TaskStatus - from a2a.types import TaskStatusUpdateEvent - from a2a.types import TextPart - -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' - ) from e - else: - raise e +from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Artifact +from a2a.types import Message +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart from google.adk.runners import Runner from pydantic import BaseModel from typing_extensions import override +from ...utils.context_utils import Aclosing from ..converters.event_converter import AdkEventToA2AEventsConverter from ..converters.event_converter import convert_event_to_a2a_events from ..converters.part_converter import A2APartToGenAIPartConverter diff --git a/src/google/adk/a2a/utils/agent_card_builder.py b/src/google/adk/a2a/utils/agent_card_builder.py index aa7f657f99..c007870931 100644 --- a/src/google/adk/a2a/utils/agent_card_builder.py +++ b/src/google/adk/a2a/utils/agent_card_builder.py @@ -15,25 +15,15 @@ from __future__ import annotations import re -import sys from typing import Dict from typing import List from typing import Optional -try: - from a2a.types import AgentCapabilities - from a2a.types import AgentCard - from a2a.types import AgentProvider - from a2a.types import AgentSkill - from a2a.types import SecurityScheme -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' - ) from e - else: - raise e - +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentProvider +from a2a.types import AgentSkill +from a2a.types import SecurityScheme from ...agents.base_agent import BaseAgent from ...agents.llm_agent import LlmAgent diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 72a2292fb3..1a1ba35618 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -15,30 +15,18 @@ from __future__ import annotations import logging -import sys - -try: - from a2a.server.apps import A2AStarletteApplication - from a2a.server.request_handlers import DefaultRequestHandler - from a2a.server.tasks import InMemoryTaskStore - from a2a.types import AgentCard -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - "A2A requires Python 3.10 or above. Please upgrade your Python version." - ) from e - else: - raise e - from typing import Optional from typing import Union +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCard from starlette.applications import Starlette from ...agents.base_agent import BaseAgent from ...artifacts.in_memory_artifact_service import InMemoryArtifactService from ...auth.credential_service.in_memory_credential_service import InMemoryCredentialService -from ...cli.utils.logs import setup_adk_logger from ...memory.in_memory_memory_service import InMemoryMemoryService from ...runners import Runner from ...sessions.in_memory_session_service import InMemorySessionService @@ -117,7 +105,8 @@ def to_a2a( app = to_a2a(agent, agent_card=my_custom_agent_card) """ # Set up ADK logging to ensure logs are visible when using uvicorn directly - setup_adk_logger(logging.INFO) + adk_logger = logging.getLogger("google_adk") + adk_logger.setLevel(logging.INFO) async def create_runner() -> Runner: """Create a runner for the agent.""" diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index 5710a21b7f..b5f8e88cde 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - from .base_agent import BaseAgent from .invocation_context import InvocationContext from .live_request_queue import LiveRequest @@ -22,6 +19,7 @@ from .llm_agent import Agent from .llm_agent import LlmAgent from .loop_agent import LoopAgent +from .mcp_instruction_provider import McpInstructionProvider from .parallel_agent import ParallelAgent from .run_config import RunConfig from .sequential_agent import SequentialAgent @@ -31,6 +29,7 @@ 'BaseAgent', 'LlmAgent', 'LoopAgent', + 'McpInstructionProvider', 'ParallelAgent', 'SequentialAgent', 'InvocationContext', @@ -38,16 +37,3 @@ 'LiveRequestQueue', 'RunConfig', ] - -if sys.version_info < (3, 10): - logger = logging.getLogger('google_adk.' + __name__) - logger.warning( - 'MCP requires Python 3.10 or above. Please upgrade your Python' - ' version in order to use it.' - ) -else: - from .mcp_instruction_provider import McpInstructionProvider - - __all__.extend([ - 'McpInstructionProvider', - ]) diff --git a/src/google/adk/agents/mcp_instruction_provider.py b/src/google/adk/agents/mcp_instruction_provider.py index e9f40663c9..20896a7a04 100644 --- a/src/google/adk/agents/mcp_instruction_provider.py +++ b/src/google/adk/agents/mcp_instruction_provider.py @@ -22,24 +22,12 @@ from typing import Dict from typing import TextIO +from mcp import types + +from ..tools.mcp_tool.mcp_session_manager import MCPSessionManager from .llm_agent import InstructionProvider from .readonly_context import ReadonlyContext -# Attempt to import MCP Session Manager from the MCP library, and hints user to -# upgrade their Python version to 3.10 if it fails. -try: - from mcp import types - - from ..tools.mcp_tool.mcp_session_manager import MCPSessionManager -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - "MCP Session Manager requires Python 3.10 or above. Please upgrade" - " your Python version." - ) from e - else: - raise e - class McpInstructionProvider(InstructionProvider): """Fetches agent instructions from an MCP server.""" diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index f7270a75c9..09e65a67a4 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -48,34 +48,13 @@ def _create_branch_ctx_for_sub_agent( return invocation_context -# TODO - remove once Python <3.11 is no longer supported. -async def _merge_agent_run_pre_3_11( +async def _merge_agent_run( agent_runs: list[AsyncGenerator[Event, None]], ) -> AsyncGenerator[Event, None]: - """Merges the agent run event generator. - This version works in Python 3.9 and 3.10 and uses custom replacement for - asyncio.TaskGroup for tasks cancellation and exception handling. - - This implementation guarantees for each agent, it won't move on until the - generated event is processed by upstream runner. - - Args: - agent_runs: A list of async generators that yield events from each agent. - - Yields: - Event: The next event from the merged generator. - """ + """Merges agent runs using asyncio.TaskGroup on Python 3.11+.""" sentinel = object() queue = asyncio.Queue() - def propagate_exceptions(tasks): - # Propagate exceptions and errors from tasks. - for task in tasks: - if task.done(): - # Ignore the result (None) of correctly finished tasks and re-raise - # exceptions and errors. - task.result() - # Agents are processed in parallel. # Events for each agent are put on queue sequentially. async def process_an_agent(events_for_one_agent): @@ -89,39 +68,34 @@ async def process_an_agent(events_for_one_agent): # Mark agent as finished. await queue.put((sentinel, None)) - tasks = [] - try: + async with asyncio.TaskGroup() as tg: for events_for_one_agent in agent_runs: - tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent))) + tg.create_task(process_an_agent(events_for_one_agent)) sentinel_count = 0 # Run until all agents finished processing. while sentinel_count < len(agent_runs): - propagate_exceptions(tasks) event, resume_signal = await queue.get() # Agent finished processing. if event is sentinel: sentinel_count += 1 else: yield event - # Signal to agent that event has been processed by runner and it can - # continue now. + # Signal to agent that it should generate next event. resume_signal.set() - finally: - for task in tasks: - task.cancel() -async def _merge_agent_run( +# TODO - remove once Python <3.11 is no longer supported. +async def _merge_agent_run_pre_3_11( agent_runs: list[AsyncGenerator[Event, None]], ) -> AsyncGenerator[Event, None]: - """Merges the agent run event generator. + """Merges agent runs for Python 3.10 without asyncio.TaskGroup. - This implementation guarantees for each agent, it won't move on until the - generated event is processed by upstream runner. + Uses custom cancellation and exception handling to mirror TaskGroup + semantics. Each agent waits until the runner processes emitted events. Args: - agent_runs: A list of async generators that yield events from each agent. + agent_runs: Async generators that yield events from each agent. Yields: Event: The next event from the merged generator. @@ -129,6 +103,14 @@ async def _merge_agent_run( sentinel = object() queue = asyncio.Queue() + def propagate_exceptions(tasks): + # Propagate exceptions and errors from tasks. + for task in tasks: + if task.done(): + # Ignore the result (None) of correctly finished tasks and re-raise + # exceptions and errors. + task.result() + # Agents are processed in parallel. # Events for each agent are put on queue sequentially. async def process_an_agent(events_for_one_agent): @@ -142,21 +124,27 @@ async def process_an_agent(events_for_one_agent): # Mark agent as finished. await queue.put((sentinel, None)) - async with asyncio.TaskGroup() as tg: + tasks = [] + try: for events_for_one_agent in agent_runs: - tg.create_task(process_an_agent(events_for_one_agent)) + tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent))) sentinel_count = 0 # Run until all agents finished processing. while sentinel_count < len(agent_runs): + propagate_exceptions(tasks) event, resume_signal = await queue.get() # Agent finished processing. if event is sentinel: sentinel_count += 1 else: yield event - # Signal to agent that it should generate next event. + # Signal to agent that event has been processed by runner and it can + # continue now. resume_signal.set() + finally: + for task in tasks: + task.cancel() class ParallelAgent(BaseAgent): @@ -195,13 +183,11 @@ async def _run_async_impl( pause_invocation = False try: - # TODO remove if once Python <3.11 is no longer supported. merge_func = ( _merge_agent_run if sys.version_info >= (3, 11) else _merge_agent_run_pre_3_11 ) - async with Aclosing(merge_func(agent_runs)) as agen: async for event in agen: yield event diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 5d42730937..8d133060ec 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -26,30 +26,22 @@ from urllib.parse import urlparse import uuid -try: - from a2a.client import Client as A2AClient - from a2a.client import ClientEvent as A2AClientEvent - from a2a.client.card_resolver import A2ACardResolver - from a2a.client.client import ClientConfig as A2AClientConfig - from a2a.client.client_factory import ClientFactory as A2AClientFactory - from a2a.client.errors import A2AClientError - from a2a.types import AgentCard - from a2a.types import Message as A2AMessage - from a2a.types import Part as A2APart - from a2a.types import Role - from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent - from a2a.types import TaskState - from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent - from a2a.types import TransportProtocol as A2ATransport -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - "A2A requires Python 3.10 or above. Please upgrade your Python version." - ) from e - else: - raise e +from a2a.client import Client as A2AClient +from a2a.client import ClientEvent as A2AClientEvent +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.client import ClientConfig as A2AClientConfig +from a2a.client.client_factory import ClientFactory as A2AClientFactory +from a2a.client.errors import A2AClientError +from a2a.types import AgentCard +from a2a.types import Message as A2AMessage +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent +from a2a.types import TransportProtocol as A2ATransport +from google.genai import types as genai_types +import httpx try: from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -57,9 +49,6 @@ # Fallback for older versions of a2a-sdk. AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json" -from google.genai import types as genai_types -import httpx - from ..a2a.converters.event_converter import convert_a2a_message_to_event from ..a2a.converters.event_converter import convert_a2a_task_to_event from ..a2a.converters.event_converter import convert_event_to_a2a_message diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index f9170968fd..df06b1cf4c 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -342,25 +342,14 @@ async def get_agent_builder( ) if a2a: - try: - from a2a.server.apps import A2AStarletteApplication - from a2a.server.request_handlers import DefaultRequestHandler - from a2a.server.tasks import InMemoryTaskStore - from a2a.types import AgentCard - from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH - - from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor - - except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - "A2A requires Python 3.10 or above. Please upgrade your Python" - " version." - ) from e - else: - raise e + from a2a.server.apps import A2AStarletteApplication + from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server.tasks import InMemoryTaskStore + from a2a.types import AgentCard + from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH + + from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor + # locate all a2a agent apps in the agents directory base_path = Path.cwd() / agents_dir # the root agents directory should be an existing folder diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py index eaef479274..875b82e5b9 100644 --- a/src/google/adk/tools/crewai_tool.py +++ b/src/google/adk/tools/crewai_tool.py @@ -30,16 +30,9 @@ try: from crewai.tools import BaseTool as CrewaiBaseTool except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - 'Crewai Tools require Python 3.10+. Please upgrade your Python version.' - ) from e - else: - raise ImportError( - "Crewai Tools require pip install 'google-adk[extensions]'." - ) from e + raise ImportError( + "Crewai Tools require pip install 'google-adk[extensions]'." + ) from e class CrewaiTool(FunctionTool): diff --git a/src/google/adk/tools/mcp_tool/__init__.py b/src/google/adk/tools/mcp_tool/__init__.py index f1e56b99c4..1170b2e1af 100644 --- a/src/google/adk/tools/mcp_tool/__init__.py +++ b/src/google/adk/tools/mcp_tool/__init__.py @@ -39,15 +39,7 @@ except ImportError as e: import logging - import sys logger = logging.getLogger('google_adk.' + __name__) - - if sys.version_info < (3, 10): - logger.warning( - 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' - ' version.' - ) - else: - logger.debug('MCP Tool is not installed') - logger.debug(e) + logger.debug('MCP Tool is not installed') + logger.debug(e) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 7d9714aada..c9c4c2ae66 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -29,24 +29,13 @@ from typing import Union import anyio +from mcp import ClientSession +from mcp import StdioServerParameters +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client from pydantic import BaseModel -try: - from mcp import ClientSession - from mcp import StdioServerParameters - from mcp.client.sse import sse_client - from mcp.client.stdio import stdio_client - from mcp.client.streamable_http import streamablehttp_client -except ImportError as e: - - if sys.version_info < (3, 10): - raise ImportError( - 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' - ' version.' - ) from e - else: - raise e - logger = logging.getLogger('google_adk.' + __name__) diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 284aea4105..b15f2c73fe 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -17,7 +17,6 @@ import base64 import inspect import logging -import sys from typing import Any from typing import Callable from typing import Dict @@ -27,35 +26,21 @@ from fastapi.openapi.models import APIKeyIn from google.genai.types import FunctionDeclaration +from mcp.types import Tool as McpBaseTool from typing_extensions import override from ...agents.readonly_context import ReadonlyContext -from ...features import FeatureName -from ...features import is_feature_enabled -from .._gemini_schema_util import _to_gemini_schema -from .mcp_session_manager import MCPSessionManager -from .mcp_session_manager import retry_on_errors - -# Attempt to import MCP Tool from the MCP library, and hints user to upgrade -# their Python version to 3.10 if it fails. -try: - from mcp.types import Tool as McpBaseTool -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - "MCP Tool requires Python 3.10 or above. Please upgrade your Python" - " version." - ) from e - else: - raise e - - from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme from ...auth.auth_tool import AuthConfig +from ...features import FeatureName +from ...features import is_feature_enabled +from .._gemini_schema_util import _to_gemini_schema from ..base_authenticated_tool import BaseAuthenticatedTool # import from ..tool_context import ToolContext +from .mcp_session_manager import MCPSessionManager +from .mcp_session_manager import retry_on_errors logger = logging.getLogger("google_adk." + __name__) diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 3768477e1d..035b75878b 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -25,6 +25,8 @@ from typing import Union import warnings +from mcp import StdioServerParameters +from mcp.types import ListToolsResult from pydantic import model_validator from typing_extensions import override @@ -41,23 +43,6 @@ from .mcp_session_manager import SseConnectionParams from .mcp_session_manager import StdioConnectionParams from .mcp_session_manager import StreamableHTTPConnectionParams - -# Attempt to import MCP Tool from the MCP library, and hints user to upgrade -# their Python version to 3.10 if it fails. -try: - from mcp import StdioServerParameters - from mcp.types import ListToolsResult -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - "MCP Tool requires Python 3.10 or above. Please upgrade your Python" - " version." - ) from e - else: - raise e - from .mcp_tool import MCPTool logger = logging.getLogger("google_adk." + __name__) diff --git a/src/google/adk/utils/context_utils.py b/src/google/adk/utils/context_utils.py index bd8dacb9d8..a75feae3dd 100644 --- a/src/google/adk/utils/context_utils.py +++ b/src/google/adk/utils/context_utils.py @@ -20,30 +20,7 @@ from __future__ import annotations -from contextlib import AbstractAsyncContextManager -from typing import Any -from typing import AsyncGenerator +from contextlib import aclosing - -class Aclosing(AbstractAsyncContextManager): - """Async context manager for safely finalizing an asynchronously cleaned-up - resource such as an async generator, calling its ``aclose()`` method. - Needed to correctly close contexts for OTel spans. - See https://github.com/google/adk-python/issues/1670#issuecomment-3115891100. - - Based on - https://docs.python.org/3/library/contextlib.html#contextlib.aclosing - which is available in Python 3.10+. - - TODO: replace all occurrences with contextlib.aclosing once Python 3.9 is no - longer supported. - """ - - def __init__(self, async_generator: AsyncGenerator[Any, None]): - self.async_generator = async_generator - - async def __aenter__(self): - return self.async_generator - - async def __aexit__(self, *exc_info): - await self.async_generator.aclose() +# Re-export aclosing for backward compatibility +Aclosing = aclosing diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py index cb3f7a6858..49b7d3c2b6 100644 --- a/tests/unittests/a2a/converters/test_event_converter.py +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -12,50 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import Mock from unittest.mock import patch +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Role +from a2a.types import Task +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent +from google.adk.a2a.converters.event_converter import _create_artifact_id +from google.adk.a2a.converters.event_converter import _create_error_status_event +from google.adk.a2a.converters.event_converter import _create_status_update_event +from google.adk.a2a.converters.event_converter import _get_adk_metadata_key +from google.adk.a2a.converters.event_converter import _get_context_metadata +from google.adk.a2a.converters.event_converter import _process_long_running_tool +from google.adk.a2a.converters.event_converter import _serialize_metadata_value +from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR +from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event +from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events +from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message +from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.types import DataPart - from a2a.types import Message - from a2a.types import Role - from a2a.types import Task - from a2a.types import TaskState - from a2a.types import TaskStatusUpdateEvent - from google.adk.a2a.converters.event_converter import _create_artifact_id - from google.adk.a2a.converters.event_converter import _create_error_status_event - from google.adk.a2a.converters.event_converter import _create_status_update_event - from google.adk.a2a.converters.event_converter import _get_adk_metadata_key - from google.adk.a2a.converters.event_converter import _get_context_metadata - from google.adk.a2a.converters.event_converter import _process_long_running_tool - from google.adk.a2a.converters.event_converter import _serialize_metadata_value - from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR - from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event - from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events - from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message - from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE - from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX - from google.adk.agents.invocation_context import InvocationContext - from google.adk.events.event import Event - from google.adk.events.event_actions import EventActions -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestEventConverter: """Test suite for event_converter module.""" diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 5a8bad1096..541ab7709d 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -13,38 +13,21 @@ # limitations under the License. import json -import sys from unittest.mock import Mock from unittest.mock import patch +from a2a import types as a2a_types +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part +from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.genai import types as genai_types import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a import types as a2a_types - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY - from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part - from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part - from google.adk.a2a.converters.utils import _get_adk_metadata_key - from google.genai import types as genai_types -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestConvertA2aPartToGenaiPart: """Test cases for convert_a2a_part_to_genai_part function.""" diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py index a7c21e4dbc..173b122d7c 100644 --- a/tests/unittests/a2a/converters/test_request_converter.py +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -12,33 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import Mock from unittest.mock import patch +from a2a.server.agent_execution import RequestContext +from google.adk.a2a.converters.request_converter import _get_user_id +from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request +from google.adk.runners import RunConfig +from google.genai import types as genai_types import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.server.agent_execution import RequestContext - from google.adk.a2a.converters.request_converter import _get_user_id - from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request - from google.adk.runners import RunConfig - from google.genai import types as genai_types -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestGetUserId: """Test cases for _get_user_id function.""" diff --git a/tests/unittests/a2a/converters/test_utils.py b/tests/unittests/a2a/converters/test_utils.py index 6c8511161a..0d896852aa 100644 --- a/tests/unittests/a2a/converters/test_utils.py +++ b/tests/unittests/a2a/converters/test_utils.py @@ -12,31 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys - +from google.adk.a2a.converters.utils import _from_a2a_context_id +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.converters.utils import _to_a2a_context_id +from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.a2a.converters.utils import _from_a2a_context_id - from google.adk.a2a.converters.utils import _get_adk_metadata_key - from google.adk.a2a.converters.utils import _to_a2a_context_id - from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX - from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestUtilsFunctions: """Test suite for utils module functions.""" diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 4bcc7a91d7..58d7521f7d 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -12,41 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Message +from a2a.types import TaskState +from a2a.types import TextPart +from google.adk.a2a.converters.request_converter import AgentRunRequest +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig +from google.adk.events.event import Event +from google.adk.runners import RunConfig +from google.adk.runners import Runner +from google.genai.types import Content import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.server.agent_execution.context import RequestContext - from a2a.server.events.event_queue import EventQueue - from a2a.types import Message - from a2a.types import TaskState - from a2a.types import TextPart - from google.adk.a2a.converters.request_converter import AgentRunRequest - from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor - from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig - from google.adk.events.event import Event - from google.adk.runners import RunConfig - from google.adk.runners import Runner - from google.genai.types import Content -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestA2aAgentExecutor: """Test suite for A2aAgentExecutor class.""" diff --git a/tests/unittests/a2a/executor/test_task_result_aggregator.py b/tests/unittests/a2a/executor/test_task_result_aggregator.py index 9d03db9dc8..b809b62728 100644 --- a/tests/unittests/a2a/executor/test_task_result_aggregator.py +++ b/tests/unittests/a2a/executor/test_task_result_aggregator.py @@ -12,35 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import Mock +from a2a.types import Message +from a2a.types import Part +from a2a.types import Role +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.types import Message - from a2a.types import Part - from a2a.types import Role - from a2a.types import TaskState - from a2a.types import TaskStatus - from a2a.types import TaskStatusUpdateEvent - from a2a.types import TextPart - from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - def create_test_message(text: str): """Helper function to create a test Message object.""" diff --git a/tests/unittests/a2a/utils/test_agent_card_builder.py b/tests/unittests/a2a/utils/test_agent_card_builder.py index e0b62468e5..3bf3202897 100644 --- a/tests/unittests/a2a/utils/test_agent_card_builder.py +++ b/tests/unittests/a2a/utils/test_agent_card_builder.py @@ -12,55 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import Mock from unittest.mock import patch +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentProvider +from a2a.types import AgentSkill +from a2a.types import SecurityScheme +from google.adk.a2a.utils.agent_card_builder import _build_agent_description +from google.adk.a2a.utils.agent_card_builder import _build_llm_agent_description_with_instructions +from google.adk.a2a.utils.agent_card_builder import _build_loop_description +from google.adk.a2a.utils.agent_card_builder import _build_orchestration_skill +from google.adk.a2a.utils.agent_card_builder import _build_parallel_description +from google.adk.a2a.utils.agent_card_builder import _build_sequential_description +from google.adk.a2a.utils.agent_card_builder import _convert_example_tool_examples +from google.adk.a2a.utils.agent_card_builder import _extract_examples_from_instruction +from google.adk.a2a.utils.agent_card_builder import _get_agent_skill_name +from google.adk.a2a.utils.agent_card_builder import _get_agent_type +from google.adk.a2a.utils.agent_card_builder import _get_default_description +from google.adk.a2a.utils.agent_card_builder import _get_input_modes +from google.adk.a2a.utils.agent_card_builder import _get_output_modes +from google.adk.a2a.utils.agent_card_builder import _get_workflow_description +from google.adk.a2a.utils.agent_card_builder import _replace_pronouns +from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.loop_agent import LoopAgent +from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.tools.example_tool import ExampleTool import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.types import AgentCapabilities - from a2a.types import AgentCard - from a2a.types import AgentProvider - from a2a.types import AgentSkill - from a2a.types import SecurityScheme - from google.adk.a2a.utils.agent_card_builder import _build_agent_description - from google.adk.a2a.utils.agent_card_builder import _build_llm_agent_description_with_instructions - from google.adk.a2a.utils.agent_card_builder import _build_loop_description - from google.adk.a2a.utils.agent_card_builder import _build_orchestration_skill - from google.adk.a2a.utils.agent_card_builder import _build_parallel_description - from google.adk.a2a.utils.agent_card_builder import _build_sequential_description - from google.adk.a2a.utils.agent_card_builder import _convert_example_tool_examples - from google.adk.a2a.utils.agent_card_builder import _extract_examples_from_instruction - from google.adk.a2a.utils.agent_card_builder import _get_agent_skill_name - from google.adk.a2a.utils.agent_card_builder import _get_agent_type - from google.adk.a2a.utils.agent_card_builder import _get_default_description - from google.adk.a2a.utils.agent_card_builder import _get_input_modes - from google.adk.a2a.utils.agent_card_builder import _get_output_modes - from google.adk.a2a.utils.agent_card_builder import _get_workflow_description - from google.adk.a2a.utils.agent_card_builder import _replace_pronouns - from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder - from google.adk.agents.base_agent import BaseAgent - from google.adk.agents.llm_agent import LlmAgent - from google.adk.agents.loop_agent import LoopAgent - from google.adk.agents.parallel_agent import ParallelAgent - from google.adk.agents.sequential_agent import SequentialAgent - from google.adk.tools.example_tool import ExampleTool -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestAgentCardBuilder: """Test suite for AgentCardBuilder class.""" diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index ee80b0233b..503e572f2f 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -12,42 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCard +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor +from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder +from google.adk.a2a.utils.agent_to_a2a import to_a2a +from google.adk.agents.base_agent import BaseAgent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService import pytest - -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.server.apps import A2AStarletteApplication - from a2a.server.request_handlers import DefaultRequestHandler - from a2a.server.tasks import InMemoryTaskStore - from a2a.types import AgentCard - from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor - from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder - from google.adk.a2a.utils.agent_to_a2a import to_a2a - from google.adk.agents.base_agent import BaseAgent - from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService - from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService - from google.adk.memory.in_memory_memory_service import InMemoryMemoryService - from google.adk.runners import Runner - from google.adk.sessions.in_memory_session_service import InMemorySessionService - from starlette.applications import Starlette -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e +from starlette.applications import Starlette class TestToA2A: diff --git a/tests/unittests/agents/test_mcp_instruction_provider.py b/tests/unittests/agents/test_mcp_instruction_provider.py index 1f2d098c2a..256d812630 100644 --- a/tests/unittests/agents/test_mcp_instruction_provider.py +++ b/tests/unittests/agents/test_mcp_instruction_provider.py @@ -13,34 +13,14 @@ # limitations under the License. """Unit tests for McpInstructionProvider.""" -import sys from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch +from google.adk.agents.mcp_instruction_provider import McpInstructionProvider from google.adk.agents.readonly_context import ReadonlyContext import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), - reason="MCP instruction provider requires Python 3.10+", -) - -# Import dependencies with version checking -try: - from google.adk.agents.mcp_instruction_provider import McpInstructionProvider -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - McpInstructionProvider = DummyClass - else: - raise e - class TestMcpInstructionProvider: """Unit tests for McpInstructionProvider.""" diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index fd722abf3f..e7865f39ba 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -14,70 +14,38 @@ import json from pathlib import Path -import sys import tempfile from unittest.mock import AsyncMock from unittest.mock import create_autospec from unittest.mock import Mock from unittest.mock import patch +from a2a.client.client import ClientConfig +from a2a.client.client import Consumer +from a2a.client.client_factory import ClientFactory +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentSkill +from a2a.types import Artifact +from a2a.types import Message as A2AMessage +from a2a.types import Part as A2ATaskStatus +from a2a.types import SendMessageSuccessResponse +from a2a.types import Task as A2ATask +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX +from google.adk.agents.remote_a2a_agent import AgentCardResolutionError +from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.events.event import Event from google.adk.sessions.session import Session from google.genai import types as genai_types import httpx import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.client.client import ClientConfig - from a2a.client.client import Consumer - from a2a.client.client_factory import ClientFactory - from a2a.types import AgentCapabilities - from a2a.types import AgentCard - from a2a.types import AgentSkill - from a2a.types import Artifact - from a2a.types import Message as A2AMessage - from a2a.types import Part as A2ATaskStatus - from a2a.types import SendMessageSuccessResponse - from a2a.types import Task as A2ATask - from a2a.types import TaskArtifactUpdateEvent - from a2a.types import TaskState - from a2a.types import TaskStatus - from a2a.types import TaskStatusUpdateEvent - from a2a.types import TextPart - from google.adk.agents.invocation_context import InvocationContext - from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX - from google.adk.agents.remote_a2a_agent import AgentCardResolutionError - from google.adk.agents.remote_a2a_agent import RemoteA2aAgent -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during module compilation. - # These are needed because the module has type annotations and module-level - # helper functions that reference imported types. - class DummyTypes: - pass - - AgentCapabilities = DummyTypes() - AgentCard = DummyTypes() - AgentSkill = DummyTypes() - A2AMessage = DummyTypes() - SendMessageSuccessResponse = DummyTypes() - A2ATask = DummyTypes() - TaskStatusUpdateEvent = DummyTypes() - Artifact = DummyTypes() - TaskArtifactUpdateEvent = DummyTypes() - InvocationContext = DummyTypes() - RemoteA2aAgent = DummyTypes() - AgentCardResolutionError = Exception - A2A_METADATA_PREFIX = "" - else: - raise e - # Helper function to create a proper AgentCard for testing def create_test_agent_card( diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 1fe04732f5..75d5679084 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -509,8 +509,6 @@ async def create_test_eval_set( @pytest.fixture def temp_agents_dir_with_a2a(): """Create a temporary agents directory with A2A agent configurations for testing.""" - if sys.version_info < (3, 10): - pytest.skip("A2A requires Python 3.10+") with tempfile.TemporaryDirectory() as temp_dir: # Create test agent directory agent_dir = Path(temp_dir) / "test_a2a_agent" @@ -554,9 +552,6 @@ def test_app_with_a2a( temp_agents_dir_with_a2a, ): """Create a TestClient for the FastAPI app with A2A enabled.""" - if sys.version_info < (3, 10): - pytest.skip("A2A requires Python 3.10+") - # Mock A2A related classes with ( patch("signal.signal", return_value=None), @@ -1150,9 +1145,6 @@ def list_agents(self): assert "dotSrc" in response.json() -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) def test_a2a_agent_discovery(test_app_with_a2a): """Test that A2A agents are properly discovered and configured.""" # This test mainly verifies that the A2A setup doesn't break the app @@ -1161,9 +1153,6 @@ def test_a2a_agent_discovery(test_app_with_a2a): logger.info("A2A agent discovery test passed") -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) def test_a2a_disabled_by_default(test_app): """Test that A2A functionality is disabled by default.""" # The regular test_app fixture has a2a=False diff --git a/tests/unittests/evaluation/test_local_eval_service.py b/tests/unittests/evaluation/test_local_eval_service.py index cf2ca342f3..66080828d8 100644 --- a/tests/unittests/evaluation/test_local_eval_service.py +++ b/tests/unittests/evaluation/test_local_eval_service.py @@ -536,9 +536,6 @@ def test_generate_final_eval_status_doesn_t_throw_on(eval_service): @pytest.mark.asyncio -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) async def test_mcp_stdio_agent_no_runtime_error(mocker): """Test that LocalEvalService can handle MCP stdio agents without RuntimeError. diff --git a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py index 1e15f33899..2cf52e99cb 100644 --- a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py +++ b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py @@ -57,10 +57,8 @@ async def extract_error_from_result( return None -# Inheriting from IsolatedAsyncioTestCase ensures these tests works in Python -# 3.9. See https://github.com/pytest-dev/pytest-asyncio/issues/1039 -# Without this, the tests will fail with a "RuntimeError: There is no current -# event loop in thread 'MainThread'." +# Inheriting from IsolatedAsyncioTestCase ensures consistent behavior. +# See https://github.com/pytest-dev/pytest-asyncio/issues/1039 class TestReflectAndRetryToolPlugin(IsolatedAsyncioTestCase): """Comprehensive tests for ReflectAndRetryToolPlugin focusing on behavior.""" diff --git a/tests/unittests/telemetry/test_functional.py b/tests/unittests/telemetry/test_functional.py index 409571ad1f..43fe672333 100644 --- a/tests/unittests/telemetry/test_functional.py +++ b/tests/unittests/telemetry/test_functional.py @@ -103,7 +103,7 @@ def wrapped_firstiter(coro): isinstance(referrer, Aclosing) or isinstance(indirect_referrer, Aclosing) for referrer in gc.get_referrers(coro) - # Some coroutines have a layer of indirection in python 3.9 and 3.10 + # Some coroutines have a layer of indirection in Python 3.10 for indirect_referrer in gc.get_referrers(referrer) ), f'Coro `{coro.__name__}` is not wrapped with Aclosing' firstiter(coro) diff --git a/tests/unittests/tools/computer_use/test_computer_use_tool.py b/tests/unittests/tools/computer_use/test_computer_use_tool.py index 4dbdfbb5c0..f3843b87a6 100644 --- a/tests/unittests/tools/computer_use/test_computer_use_tool.py +++ b/tests/unittests/tools/computer_use/test_computer_use_tool.py @@ -47,7 +47,7 @@ async def tool_context(self): @pytest.fixture def mock_computer_function(self): """Fixture providing a mock computer function.""" - # Create a real async function instead of AsyncMock for Python 3.9 compatibility + # Create a real async function instead of AsyncMock for better test control calls = [] async def mock_func(*args, **kwargs): diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index b2d6b1cb88..74eabe9d4d 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -22,46 +22,14 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from mcp import StdioServerParameters import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors - from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - MCPSessionManager = DummyClass - retry_on_errors = lambda x: x - SseConnectionParams = DummyClass - StdioConnectionParams = DummyClass - StreamableHTTPConnectionParams = DummyClass - else: - raise e - -# Import real MCP classes -try: - from mcp import StdioServerParameters -except ImportError: - # Create a mock if MCP is not available - class StdioServerParameters: - - def __init__(self, command="test_command", args=None): - self.command = command - self.args = args or [] - class MockClientSession: """Mock ClientSession for testing.""" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 17b1d8e54e..1284e73bce 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch @@ -23,39 +22,15 @@ from google.adk.auth.auth_credential import HttpCredentials from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_credential import ServiceAccount +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.tool_context import ToolContext +from google.genai.types import FunctionDeclaration +from google.genai.types import Type +from mcp.types import CallToolResult +from mcp.types import TextContent import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_tool import MCPTool - from google.adk.tools.tool_context import ToolContext - from google.genai.types import FunctionDeclaration - from google.genai.types import Type - from mcp.types import CallToolResult - from mcp.types import TextContent -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - MCPSessionManager = DummyClass - MCPTool = DummyClass - ToolContext = DummyClass - FunctionDeclaration = DummyClass - Type = DummyClass - CallToolResult = DummyClass - TextContent = DummyClass - else: - raise e - # Mock MCP Tool from mcp.types class MockMCPTool: diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 82a5c9a3e7..5809efe56f 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -20,47 +20,17 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.agents.readonly_context import ReadonlyContext from google.adk.auth.auth_credential import AuthCredential +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +from mcp import StdioServerParameters import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.agents.readonly_context import ReadonlyContext - from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams - from google.adk.tools.mcp_tool.mcp_tool import MCPTool - from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset - from mcp import StdioServerParameters -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - class StdioServerParameters: - - def __init__(self, command="test_command", args=None): - self.command = command - self.args = args or [] - - MCPSessionManager = DummyClass - SseConnectionParams = DummyClass - StdioConnectionParams = DummyClass - StreamableHTTPConnectionParams = DummyClass - MCPTool = DummyClass - MCPToolset = DummyClass - ReadonlyContext = DummyClass - else: - raise e - class MockMCPTool: """Mock MCP Tool for testing.""" diff --git a/tests/unittests/tools/retrieval/test_files_retrieval.py b/tests/unittests/tools/retrieval/test_files_retrieval.py index ea4b99cd98..dfb7215dce 100644 --- a/tests/unittests/tools/retrieval/test_files_retrieval.py +++ b/tests/unittests/tools/retrieval/test_files_retrieval.py @@ -14,7 +14,6 @@ """Tests for FilesRetrieval tool.""" -import sys import unittest.mock as mock from google.adk.tools.retrieval.files_retrieval import _get_default_embedding_model @@ -111,9 +110,6 @@ def mock_import(name, *args, **kwargs): def test_get_default_embedding_model_success(self): """Test _get_default_embedding_model returns Google embedding when available.""" - # Skip this test in Python 3.9 where llama_index.embeddings.google_genai may not be available - if sys.version_info < (3, 10): - pytest.skip("llama_index.embeddings.google_genai requires Python 3.10+") # Mock the module creation to avoid import issues mock_module = mock.MagicMock() diff --git a/tests/unittests/tools/test_mcp_toolset.py b/tests/unittests/tools/test_mcp_toolset.py index a3a6598e35..7bfd912669 100644 --- a/tests/unittests/tools/test_mcp_toolset.py +++ b/tests/unittests/tools/test_mcp_toolset.py @@ -14,31 +14,12 @@ """Unit tests for McpToolset.""" -import sys from unittest.mock import AsyncMock from unittest.mock import MagicMock +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.tools.mcp_tool.mcp_toolset import McpToolset -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - McpToolset = DummyClass - else: - raise e - @pytest.mark.asyncio async def test_mcp_toolset_with_prefix(): From 9d918d45df4275b5b464e46817d2daaa03859fe3 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 3 Dec 2025 10:39:17 -0800 Subject: [PATCH 3/4] feat!: Rollback the DB migration as it is breaking Co-authored-by: Shangjie Chen PiperOrigin-RevId: 839818479 --- src/google/adk/cli/cli_tools_click.py | 36 -- .../adk/sessions/database_session_service.py | 222 +++++--- .../migrate_from_sqlalchemy_sqlite.py | 0 .../adk/sessions/migration/_schema_check.py | 114 ---- .../migrate_from_sqlalchemy_pickle.py | 492 ------------------ .../sessions/migration/migration_runner.py | 128 ----- .../adk/sessions/sqlite_session_service.py | 2 +- .../sessions/migration/test_migrations.py | 106 ---- .../sessions/test_dynamic_pickle_type.py | 181 +++++++ 9 files changed, 342 insertions(+), 939 deletions(-) rename src/google/adk/sessions/{migration => }/migrate_from_sqlalchemy_sqlite.py (100%) delete mode 100644 src/google/adk/sessions/migration/_schema_check.py delete mode 100644 src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py delete mode 100644 src/google/adk/sessions/migration/migration_runner.py delete mode 100644 tests/unittests/sessions/migration/test_migrations.py create mode 100644 tests/unittests/sessions/test_dynamic_pickle_type.py diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index c4446278b4..5d228f72f3 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -36,7 +36,6 @@ from . import cli_deploy from .. import version from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE -from ..sessions.migration import migration_runner from .cli import run_cli from .fast_api import get_fast_api_app from .utils import envs @@ -1500,41 +1499,6 @@ def cli_deploy_cloud_run( click.secho(f"Deploy failed: {e}", fg="red", err=True) -@main.group() -def migrate(): - """Migrate ADK database schemas.""" - pass - - -@migrate.command("session", cls=HelpfulCommand) -@click.option( - "--source_db_url", - required=True, - help="SQLAlchemy URL of source database.", -) -@click.option( - "--dest_db_url", - required=True, - help="SQLAlchemy URL of destination database.", -) -@click.option( - "--log_level", - type=LOG_LEVELS, - default="INFO", - help="Optional. Set the logging level", -) -def cli_migrate_session( - *, source_db_url: str, dest_db_url: str, log_level: str -): - """Migrates a session database to the latest schema version.""" - logs.setup_adk_logger(getattr(logging, log_level.upper())) - try: - migration_runner.upgrade(source_db_url, dest_db_url) - click.secho("Migration check and upgrade process finished.", fg="green") - except Exception as e: - click.secho(f"Migration failed: {e}", fg="red", err=True) - - @deploy.command("agent_engine") @click.option( "--api_key", diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 1576151f23..a352918211 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -19,16 +19,18 @@ from datetime import timezone import json import logging +import pickle from typing import Any from typing import Optional import uuid +from google.genai import types +from sqlalchemy import Boolean from sqlalchemy import delete from sqlalchemy import Dialect from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func -from sqlalchemy import inspect from sqlalchemy import select from sqlalchemy import Text from sqlalchemy.dialects import mysql @@ -39,11 +41,14 @@ from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.inspection import inspect from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship +from sqlalchemy.schema import MetaData from sqlalchemy.types import DateTime +from sqlalchemy.types import PickleType from sqlalchemy.types import String from sqlalchemy.types import TypeDecorator from typing_extensions import override @@ -52,10 +57,10 @@ from . import _session_util from ..errors.already_exists_error import AlreadyExistsError from ..events.event import Event +from ..events.event_actions import EventActions from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse -from .migration import _schema_check from .session import Session from .state import State @@ -106,20 +111,39 @@ def load_dialect_impl(self, dialect): return self.impl -class Base(DeclarativeBase): - """Base class for database tables.""" +class DynamicPickleType(TypeDecorator): + """Represents a type that can be pickled.""" - pass + impl = PickleType + def load_dialect_impl(self, dialect): + if dialect.name == "mysql": + return dialect.type_descriptor(mysql.LONGBLOB) + if dialect.name == "spanner+spanner": + from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType -class StorageMetadata(Base): - """Represents internal metadata stored in the database.""" + return dialect.type_descriptor(SpannerPickleType) + return self.impl + + def process_bind_param(self, value, dialect): + """Ensures the pickled value is a bytes object before passing it to the database dialect.""" + if value is not None: + if dialect.name in ("spanner+spanner", "mysql"): + return pickle.dumps(value) + return value + + def process_result_value(self, value, dialect): + """Ensures the raw bytes from the database are unpickled back into a Python object.""" + if value is not None: + if dialect.name in ("spanner+spanner", "mysql"): + return pickle.loads(value) + return value - __tablename__ = "adk_internal_metadata" - key: Mapped[str] = mapped_column( - String(DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + +class Base(DeclarativeBase): + """Base class for database tables.""" + + pass class StorageSession(Base): @@ -213,10 +237,46 @@ class StorageEvent(Base): ) invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) + long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( + Text, nullable=True + ) + branch: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) timestamp: Mapped[PreciseTimestamp] = mapped_column( PreciseTimestamp, default=func.now() ) - event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON) + + # === Fields from llm_response.py === + content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) + grounding_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + custom_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + usage_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + citation_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + + partial: Mapped[bool] = mapped_column(Boolean, nullable=True) + turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) + error_code: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + error_message: Mapped[str] = mapped_column(String(1024), nullable=True) + interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) + input_transcription: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + output_transcription: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) storage_session: Mapped[StorageSession] = relationship( "StorageSession", @@ -231,27 +291,102 @@ class StorageEvent(Base): ), ) + @property + def long_running_tool_ids(self) -> set[str]: + return ( + set(json.loads(self.long_running_tool_ids_json)) + if self.long_running_tool_ids_json + else set() + ) + + @long_running_tool_ids.setter + def long_running_tool_ids(self, value: set[str]): + if value is None: + self.long_running_tool_ids_json = None + else: + self.long_running_tool_ids_json = json.dumps(list(value)) + @classmethod def from_event(cls, session: Session, event: Event) -> StorageEvent: - """Creates a StorageEvent from an Event.""" - return StorageEvent( + storage_event = StorageEvent( id=event.id, invocation_id=event.invocation_id, + author=event.author, + branch=event.branch, + actions=event.actions, session_id=session.id, app_name=session.app_name, user_id=session.user_id, timestamp=datetime.fromtimestamp(event.timestamp), - event_data=event.model_dump(exclude_none=True, mode="json"), + long_running_tool_ids=event.long_running_tool_ids, + partial=event.partial, + turn_complete=event.turn_complete, + error_code=event.error_code, + error_message=event.error_message, + interrupted=event.interrupted, ) + if event.content: + storage_event.content = event.content.model_dump( + exclude_none=True, mode="json" + ) + if event.grounding_metadata: + storage_event.grounding_metadata = event.grounding_metadata.model_dump( + exclude_none=True, mode="json" + ) + if event.custom_metadata: + storage_event.custom_metadata = event.custom_metadata + if event.usage_metadata: + storage_event.usage_metadata = event.usage_metadata.model_dump( + exclude_none=True, mode="json" + ) + if event.citation_metadata: + storage_event.citation_metadata = event.citation_metadata.model_dump( + exclude_none=True, mode="json" + ) + if event.input_transcription: + storage_event.input_transcription = event.input_transcription.model_dump( + exclude_none=True, mode="json" + ) + if event.output_transcription: + storage_event.output_transcription = ( + event.output_transcription.model_dump(exclude_none=True, mode="json") + ) + return storage_event def to_event(self) -> Event: - """Converts the StorageEvent to an Event.""" - return Event.model_validate({ - **self.event_data, - "id": self.id, - "invocation_id": self.invocation_id, - "timestamp": self.timestamp.timestamp(), - }) + return Event( + id=self.id, + invocation_id=self.invocation_id, + author=self.author, + branch=self.branch, + # This is needed as previous ADK version pickled actions might not have + # value defined in the current version of the EventActions model. + actions=EventActions().model_copy(update=self.actions.model_dump()), + timestamp=self.timestamp.timestamp(), + long_running_tool_ids=self.long_running_tool_ids, + partial=self.partial, + turn_complete=self.turn_complete, + error_code=self.error_code, + error_message=self.error_message, + interrupted=self.interrupted, + custom_metadata=self.custom_metadata, + content=_session_util.decode_model(self.content, types.Content), + grounding_metadata=_session_util.decode_model( + self.grounding_metadata, types.GroundingMetadata + ), + usage_metadata=_session_util.decode_model( + self.usage_metadata, types.GenerateContentResponseUsageMetadata + ), + citation_metadata=_session_util.decode_model( + self.citation_metadata, types.CitationMetadata + ), + input_transcription=_session_util.decode_model( + self.input_transcription, types.Transcription + ), + output_transcription=_session_util.decode_model( + self.output_transcription, types.Transcription + ), + ) class StorageAppState(Base): @@ -328,6 +463,7 @@ def __init__(self, db_url: str, **kwargs: Any): logger.info("Local timezone: %s", local_timezone) self.db_engine: AsyncEngine = db_engine + self.metadata: MetaData = MetaData() # DB session factory method self.database_session_factory: async_sessionmaker[ @@ -347,46 +483,10 @@ async def _ensure_tables_created(self): async with self._table_creation_lock: # Double-check after acquiring the lock if not self._tables_created: - # Check schema version BEFORE creating tables. - # This prevents creating metadata table on a v0.1 DB. - async with self.database_session_factory() as sql_session: - version, is_v01 = await sql_session.run_sync( - _schema_check.get_version_and_v01_status_sync - ) - - if is_v01: - raise RuntimeError( - "Database schema appears to be v0.1, but" - f" {_schema_check.CURRENT_SCHEMA_VERSION} is required. Please" - " migrate the database using 'adk migrate session'." - ) - elif version and version < _schema_check.CURRENT_SCHEMA_VERSION: - raise RuntimeError( - f"Database schema version is {version}, but current version is" - f" {_schema_check.CURRENT_SCHEMA_VERSION}. Please migrate" - " the database to the latest version using 'adk migrate" - " session'." - ) - async with self.db_engine.begin() as conn: # Uncomment to recreate DB every time # await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) - - # If we are here, DB is either new or >= current version. - # If new or without metadata row, stamp it as current version. - async with self.database_session_factory() as sql_session: - metadata = await sql_session.get( - StorageMetadata, _schema_check.SCHEMA_VERSION_KEY - ) - if not metadata: - sql_session.add( - StorageMetadata( - key=_schema_check.SCHEMA_VERSION_KEY, - value=_schema_check.CURRENT_SCHEMA_VERSION, - ) - ) - await sql_session.commit() self._tables_created = True @override @@ -623,9 +723,7 @@ async def append_event(self, session: Session, event: Event) -> Event: storage_session.state = storage_session.state | session_state_delta if storage_session._dialect_name == "sqlite": - update_time = datetime.fromtimestamp( - event.timestamp, timezone.utc - ).replace(tzinfo=None) + update_time = datetime.utcfromtimestamp(event.timestamp) else: update_time = datetime.fromtimestamp(event.timestamp) storage_session.update_time = update_time diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py b/src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py similarity index 100% rename from src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py rename to src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py diff --git a/src/google/adk/sessions/migration/_schema_check.py b/src/google/adk/sessions/migration/_schema_check.py deleted file mode 100644 index f6fdc59956..0000000000 --- a/src/google/adk/sessions/migration/_schema_check.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Database schema version check utility.""" - -from __future__ import annotations - -import logging - -import sqlalchemy -from sqlalchemy import create_engine as create_sync_engine -from sqlalchemy import inspect -from sqlalchemy import text - -logger = logging.getLogger("google_adk." + __name__) - -SCHEMA_VERSION_KEY = "schema_version" -SCHEMA_VERSION_0_1_PICKLE = "0.1" -SCHEMA_VERSION_1_0_JSON = "1.0" -CURRENT_SCHEMA_VERSION = "1.0" - - -def _to_sync_url(db_url: str) -> str: - """Removes +driver from SQLAlchemy URL.""" - if "://" in db_url: - scheme, _, rest = db_url.partition("://") - if "+" in scheme: - dialect = scheme.split("+", 1)[0] - return f"{dialect}://{rest}" - return db_url - - -def get_version_and_v01_status_sync( - sess: sqlalchemy.orm.Session, -) -> tuple[str | None, bool]: - """Returns (version, is_v01) inspecting the database.""" - inspector = sqlalchemy.inspect(sess.get_bind()) - if inspector.has_table("adk_internal_metadata"): - try: - result = sess.execute( - text("SELECT value FROM adk_internal_metadata WHERE key = :key"), - {"key": SCHEMA_VERSION_KEY}, - ).fetchone() - # If table exists, with or without key, it's 1.0 or newer. - return (result[0] if result else SCHEMA_VERSION_1_0_JSON), False - except Exception as e: - logger.warning( - "Could not read from adk_internal_metadata: %s. Assuming v1.0.", - e, - ) - return SCHEMA_VERSION_1_0_JSON, False - - if inspector.has_table("events"): - try: - cols = {c["name"] for c in inspector.get_columns("events")} - if "actions" in cols and "event_data" not in cols: - return None, True # 0.1 schema - except Exception as e: - logger.warning("Could not inspect 'events' table columns: %s", e) - return None, False # New DB - - -def get_db_schema_version(db_url: str) -> str | None: - """Reads schema version from DB. - - Checks metadata table first, falls back to table structure for 0.1 vs 1.0. - """ - engine = None - try: - engine = create_sync_engine(_to_sync_url(db_url)) - inspector = inspect(engine) - - if inspector.has_table("adk_internal_metadata"): - with engine.connect() as connection: - result = connection.execute( - text("SELECT value FROM adk_internal_metadata WHERE key = :key"), - parameters={"key": SCHEMA_VERSION_KEY}, - ).fetchone() - # If table exists, with or without key, it's 1.0 or newer. - return result[0] if result else SCHEMA_VERSION_1_0_JSON - - # Metadata table doesn't exist, check for 0.1 schema. - # 0.1 schema has an 'events' table with an 'actions' column. - if inspector.has_table("events"): - try: - cols = {c["name"] for c in inspector.get_columns("events")} - if "actions" in cols and "event_data" not in cols: - return SCHEMA_VERSION_0_1_PICKLE - except Exception as e: - logger.warning("Could not inspect 'events' table columns: %s", e) - - # If no metadata table and not identifiable as 0.1, - # assume it is a new/empty DB requiring schema 1.0. - return SCHEMA_VERSION_1_0_JSON - except Exception as e: - logger.info( - "Could not determine schema version by inspecting database: %s." - " Assuming v1.0.", - e, - ) - return SCHEMA_VERSION_1_0_JSON - finally: - if engine: - engine.dispose() diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py deleted file mode 100644 index f33ef3f5cf..0000000000 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py +++ /dev/null @@ -1,492 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Migration script from SQLAlchemy DB with Pickle Events to JSON schema.""" - -from __future__ import annotations - -import argparse -from datetime import datetime -from datetime import timezone -import json -import logging -import pickle -import sys -from typing import Any -from typing import Optional - -from google.adk.events.event import Event -from google.adk.events.event_actions import EventActions -from google.adk.sessions import _session_util -from google.adk.sessions import database_session_service as dss -from google.adk.sessions.migration import _schema_check -from google.genai import types -import sqlalchemy -from sqlalchemy import Boolean -from sqlalchemy import create_engine -from sqlalchemy import ForeignKeyConstraint -from sqlalchemy import func -from sqlalchemy import text -from sqlalchemy import Text -from sqlalchemy.dialects import mysql -from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.orm import DeclarativeBase -from sqlalchemy.orm import Mapped -from sqlalchemy.orm import mapped_column -from sqlalchemy.orm import sessionmaker -from sqlalchemy.types import PickleType -from sqlalchemy.types import String -from sqlalchemy.types import TypeDecorator - -logger = logging.getLogger("google_adk." + __name__) - - -# --- Old Schema Definitions --- -class DynamicPickleType(TypeDecorator): - """Represents a type that can be pickled.""" - - impl = PickleType - - def load_dialect_impl(self, dialect): - if dialect.name == "mysql": - return dialect.type_descriptor(mysql.LONGBLOB) - if dialect.name == "spanner+spanner": - from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType - - return dialect.type_descriptor(SpannerPickleType) - return self.impl - - def process_bind_param(self, value, dialect): - """Ensures the pickled value is a bytes object before passing it to the database dialect.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.dumps(value) - return value - - def process_result_value(self, value, dialect): - """Ensures the raw bytes from the database are unpickled back into a Python object.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.loads(value) - return value - - -class OldBase(DeclarativeBase): - pass - - -class OldStorageSession(OldBase): - __tablename__ = "sessions" - app_name: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - user_id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - state: Mapped[MutableDict[str, Any]] = mapped_column( - MutableDict.as_mutable(dss.DynamicJSON), default={} - ) - create_time: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now() - ) - update_time: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now(), onupdate=func.now() - ) - - -class OldStorageEvent(OldBase): - """Old storage event with pickle.""" - - __tablename__ = "events" - id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - app_name: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - user_id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - session_id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - invocation_id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_VARCHAR_LENGTH) - ) - author: Mapped[str] = mapped_column(String(dss.DEFAULT_MAX_VARCHAR_LENGTH)) - actions: Mapped[Any] = mapped_column(DynamicPickleType) - long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( - Text, nullable=True - ) - branch: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True - ) - timestamp: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now() - ) - content: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - grounding_metadata: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - custom_metadata: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - usage_metadata: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - citation_metadata: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - partial: Mapped[bool] = mapped_column(Boolean, nullable=True) - turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) - error_code: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True - ) - error_message: Mapped[str] = mapped_column(String(1024), nullable=True) - interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) - input_transcription: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - output_transcription: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - __table_args__ = ( - ForeignKeyConstraint( - ["app_name", "user_id", "session_id"], - ["sessions.app_name", "sessions.user_id", "sessions.id"], - ondelete="CASCADE", - ), - ) - - @property - def long_running_tool_ids(self) -> set[str]: - return ( - set(json.loads(self.long_running_tool_ids_json)) - if self.long_running_tool_ids_json - else set() - ) - - -def _to_datetime_obj(val: Any) -> datetime | Any: - """Converts string to datetime if needed.""" - if isinstance(val, str): - try: - return datetime.strptime(val, "%Y-%m-%d %H:%M:%S.%f") - except ValueError: - try: - return datetime.strptime(val, "%Y-%m-%d %H:%M:%S") - except ValueError: - pass # return as is if not matching format - return val - - -def _row_to_event(row: dict) -> Event: - """Converts event row (dict) to event object, handling missing columns and deserializing.""" - - actions_val = row.get("actions") - actions = None - if actions_val is not None: - try: - if isinstance(actions_val, bytes): - actions = pickle.loads(actions_val) - else: # for spanner - it might return object directly - actions = actions_val - except Exception as e: - logger.warning( - f"Failed to unpickle actions for event {row.get('id')}: {e}" - ) - actions = None - - if actions and hasattr(actions, "model_dump"): - actions = EventActions().model_copy(update=actions.model_dump()) - elif isinstance(actions, dict): - actions = EventActions(**actions) - else: - actions = EventActions() - - def _safe_json_load(val): - data = None - if isinstance(val, str): - try: - data = json.loads(val) - except json.JSONDecodeError: - logger.warning(f"Failed to decode JSON for event {row.get('id')}") - return None - elif isinstance(val, dict): - data = val # for postgres JSONB - return data - - content_dict = _safe_json_load(row.get("content")) - grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata")) - custom_metadata_dict = _safe_json_load(row.get("custom_metadata")) - usage_metadata_dict = _safe_json_load(row.get("usage_metadata")) - citation_metadata_dict = _safe_json_load(row.get("citation_metadata")) - input_transcription_dict = _safe_json_load(row.get("input_transcription")) - output_transcription_dict = _safe_json_load(row.get("output_transcription")) - - long_running_tool_ids_json = row.get("long_running_tool_ids_json") - long_running_tool_ids = set() - if long_running_tool_ids_json: - try: - long_running_tool_ids = set(json.loads(long_running_tool_ids_json)) - except json.JSONDecodeError: - logger.warning( - "Failed to decode long_running_tool_ids_json for event" - f" {row.get('id')}" - ) - long_running_tool_ids = set() - - event_id = row.get("id") - if not event_id: - raise ValueError("Event must have an id.") - timestamp = _to_datetime_obj(row.get("timestamp")) - if not timestamp: - raise ValueError(f"Event {event_id} must have a timestamp.") - - return Event( - id=event_id, - invocation_id=row.get("invocation_id", ""), - author=row.get("author", "agent"), - branch=row.get("branch"), - actions=actions, - timestamp=timestamp.replace(tzinfo=timezone.utc).timestamp(), - long_running_tool_ids=long_running_tool_ids, - partial=row.get("partial"), - turn_complete=row.get("turn_complete"), - error_code=row.get("error_code"), - error_message=row.get("error_message"), - interrupted=row.get("interrupted"), - custom_metadata=custom_metadata_dict, - content=_session_util.decode_model(content_dict, types.Content), - grounding_metadata=_session_util.decode_model( - grounding_metadata_dict, types.GroundingMetadata - ), - usage_metadata=_session_util.decode_model( - usage_metadata_dict, types.GenerateContentResponseUsageMetadata - ), - citation_metadata=_session_util.decode_model( - citation_metadata_dict, types.CitationMetadata - ), - input_transcription=_session_util.decode_model( - input_transcription_dict, types.Transcription - ), - output_transcription=_session_util.decode_model( - output_transcription_dict, types.Transcription - ), - ) - - -def _get_state_dict(state_val: Any) -> dict: - """Safely load dict from JSON string or return dict if already dict.""" - if isinstance(state_val, dict): - return state_val - if isinstance(state_val, str): - try: - return json.loads(state_val) - except json.JSONDecodeError: - logger.warning( - "Failed to parse state JSON string, defaulting to empty dict." - ) - return {} - return {} - - -class OldStorageAppState(OldBase): - __tablename__ = "app_states" - app_name: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - state: Mapped[MutableDict[str, Any]] = mapped_column( - MutableDict.as_mutable(dss.DynamicJSON), default={} - ) - update_time: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now(), onupdate=func.now() - ) - - -class OldStorageUserState(OldBase): - __tablename__ = "user_states" - app_name: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - user_id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - state: Mapped[MutableDict[str, Any]] = mapped_column( - MutableDict.as_mutable(dss.DynamicJSON), default={} - ) - update_time: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now(), onupdate=func.now() - ) - - -# --- Migration Logic --- -def migrate(source_db_url: str, dest_db_url: str): - """Migrates data from old pickle schema to new JSON schema.""" - logger.info(f"Connecting to source database: {source_db_url}") - try: - source_engine = create_engine(source_db_url) - SourceSession = sessionmaker(bind=source_engine) - except Exception as e: - logger.error(f"Failed to connect to source database: {e}") - raise RuntimeError(f"Failed to connect to source database: {e}") from e - - logger.info(f"Connecting to destination database: {dest_db_url}") - try: - dest_engine = create_engine(dest_db_url) - dss.Base.metadata.create_all(dest_engine) - DestSession = sessionmaker(bind=dest_engine) - except Exception as e: - logger.error(f"Failed to connect to destination database: {e}") - raise RuntimeError(f"Failed to connect to destination database: {e}") from e - - with SourceSession() as source_session, DestSession() as dest_session: - dest_session.merge( - dss.StorageMetadata( - key=_schema_check.SCHEMA_VERSION_KEY, - value=_schema_check.SCHEMA_VERSION_1_0_JSON, - ) - ) - dest_session.commit() - try: - inspector = sqlalchemy.inspect(source_engine) - - logger.info("Migrating app_states...") - if inspector.has_table("app_states"): - rows = ( - source_session.execute(text("SELECT * FROM app_states")) - .mappings() - .all() - ) - for row in rows: - dest_session.merge( - dss.StorageAppState( - app_name=row["app_name"], - state=_get_state_dict(row.get("state")), - update_time=_to_datetime_obj(row["update_time"]), - ) - ) - dest_session.commit() - logger.info(f"Migrated {len(rows)} app_states.") - else: - logger.info("No 'app_states' table found in source db.") - - logger.info("Migrating user_states...") - if inspector.has_table("user_states"): - rows = ( - source_session.execute(text("SELECT * FROM user_states")) - .mappings() - .all() - ) - for row in rows: - dest_session.merge( - dss.StorageUserState( - app_name=row["app_name"], - user_id=row["user_id"], - state=_get_state_dict(row.get("state")), - update_time=_to_datetime_obj(row["update_time"]), - ) - ) - dest_session.commit() - logger.info(f"Migrated {len(rows)} user_states.") - else: - logger.info("No 'user_states' table found in source db.") - - logger.info("Migrating sessions...") - if inspector.has_table("sessions"): - rows = ( - source_session.execute(text("SELECT * FROM sessions")) - .mappings() - .all() - ) - for row in rows: - dest_session.merge( - dss.StorageSession( - app_name=row["app_name"], - user_id=row["user_id"], - id=row["id"], - state=_get_state_dict(row.get("state")), - create_time=_to_datetime_obj(row["create_time"]), - update_time=_to_datetime_obj(row["update_time"]), - ) - ) - dest_session.commit() - logger.info(f"Migrated {len(rows)} sessions.") - else: - logger.info("No 'sessions' table found in source db.") - - logger.info("Migrating events...") - events = [] - if inspector.has_table("events"): - rows = ( - source_session.execute(text("SELECT * FROM events")) - .mappings() - .all() - ) - for row in rows: - try: - event_obj = _row_to_event(dict(row)) - new_event = dss.StorageEvent( - id=event_obj.id, - app_name=row["app_name"], - user_id=row["user_id"], - session_id=row["session_id"], - invocation_id=event_obj.invocation_id, - timestamp=datetime.fromtimestamp( - event_obj.timestamp, timezone.utc - ).replace(tzinfo=None), - event_data=event_obj.model_dump(mode="json", exclude_none=True), - ) - dest_session.merge(new_event) - events.append(new_event) - except Exception as e: - logger.warning( - f"Failed to migrate event row {row.get('id', 'N/A')}: {e}" - ) - dest_session.commit() - logger.info(f"Migrated {len(events)} events.") - else: - logger.info("No 'events' table found in source database.") - - logger.info("Migration completed successfully.") - except Exception as e: - logger.error(f"An error occurred during migration: {e}", exc_info=True) - dest_session.rollback() - raise RuntimeError(f"An error occurred during migration: {e}") from e - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=( - "Migrate ADK sessions from SQLAlchemy Pickle format to JSON format." - ) - ) - parser.add_argument( - "--source_db_url", required=True, help="SQLAlchemy URL of source database" - ) - parser.add_argument( - "--dest_db_url", - required=True, - help="SQLAlchemy URL of destination database", - ) - args = parser.parse_args() - try: - migrate(args.source_db_url, args.dest_db_url) - except Exception as e: - logger.error(f"Migration failed: {e}") - sys.exit(1) diff --git a/src/google/adk/sessions/migration/migration_runner.py b/src/google/adk/sessions/migration/migration_runner.py deleted file mode 100644 index d7abbe41f9..0000000000 --- a/src/google/adk/sessions/migration/migration_runner.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Migration runner to upgrade schemas to the latest version.""" - -from __future__ import annotations - -import logging -import os -import tempfile - -from google.adk.sessions.migration import _schema_check -from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle - -logger = logging.getLogger("google_adk." + __name__) - -# Migration map where key is start_version and value is -# (end_version, migration_function). -# Each key is a schema version, and its value is a tuple containing: -# (the schema version AFTER this migration step, the migration function to run). -# The migration function should accept (source_db_url, dest_db_url) as -# arguments. -MIGRATIONS = { - _schema_check.SCHEMA_VERSION_0_1_PICKLE: ( - _schema_check.SCHEMA_VERSION_1_0_JSON, - migrate_from_sqlalchemy_pickle.migrate, - ), -} -# The most recent schema version. The migration process stops once this version -# is reached. -LATEST_VERSION = _schema_check.CURRENT_SCHEMA_VERSION - - -def upgrade(source_db_url: str, dest_db_url: str): - """Migrates a database from its current version to the latest version. - - If the source database schema is older than the latest version, this - function applies migration scripts sequentially until the schema reaches the - LATEST_VERSION. - - If multiple migration steps are required, intermediate results are stored in - temporary SQLite database files. This means a multi-step migration - between other database types (e.g. PostgreSQL to PostgreSQL) will use - SQLite for intermediate steps. - - In-place migration (source_db_url == dest_db_url) is not supported, - as migrations always read from a source and write to a destination. - - Args: - source_db_url: The SQLAlchemy URL of the database to migrate from. - dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be - different from source_db_url. - - Raises: - RuntimeError: If source_db_url and dest_db_url are the same, or if no - migration path is found. - """ - current_version = _schema_check.get_db_schema_version(source_db_url) - - if current_version == LATEST_VERSION: - logger.info( - f"Database {source_db_url} is already at latest version" - f" {LATEST_VERSION}. No migration needed." - ) - return - - if source_db_url == dest_db_url: - raise RuntimeError( - "In-place migration is not supported. " - "Please provide a different file for dest_db_url." - ) - - # Build the list of migration steps required to reach LATEST_VERSION. - migrations_to_run = [] - ver = current_version - while ver in MIGRATIONS and ver != LATEST_VERSION: - migrations_to_run.append(MIGRATIONS[ver]) - ver = MIGRATIONS[ver][0] - - if not migrations_to_run: - raise RuntimeError( - "Could not find migration path for schema version" - f" {current_version} to {LATEST_VERSION}." - ) - - temp_files = [] - in_url = source_db_url - try: - for i, (end_version, migrate_func) in enumerate(migrations_to_run): - is_last_step = i == len(migrations_to_run) - 1 - - if is_last_step: - out_url = dest_db_url - else: - # For intermediate steps, create a temporary SQLite DB to store the - # result. - fd, temp_path = tempfile.mkstemp(suffix=".db") - os.close(fd) - out_url = f"sqlite:///{temp_path}" - temp_files.append(temp_path) - logger.debug(f"Created temp db {out_url} for step {i+1}") - - logger.info( - f"Migrating from {in_url} to {out_url} (schema {end_version})..." - ) - migrate_func(in_url, out_url) - logger.info(f"Finished migration step to schema {end_version}.") - # The output of this step becomes the input for the next step. - in_url = out_url - finally: - # Ensure temporary files are cleaned up even if migration fails. - # Cleanup temp files - for path in temp_files: - try: - os.remove(path) - logger.debug(f"Removed temp db {path}") - except OSError as e: - logger.warning(f"Failed to remove temp db file {path}: {e}") diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index e0d44b3872..8ba6531f52 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -107,7 +107,7 @@ def __init__(self, db_path: str): f"Database {db_path} seems to use an old schema." " Please run the migration command to" " migrate it to the new schema. Example: `python -m" - " google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite" + " google.adk.sessions.migrate_from_sqlalchemy_sqlite" f" --source_db_path {db_path} --dest_db_path" f" {db_path}.new` then backup {db_path} and rename" f" {db_path}.new to {db_path}." diff --git a/tests/unittests/sessions/migration/test_migrations.py b/tests/unittests/sessions/migration/test_migrations.py deleted file mode 100644 index 938387d29b..0000000000 --- a/tests/unittests/sessions/migration/test_migrations.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for migration scripts.""" - -from __future__ import annotations - -from datetime import datetime -from datetime import timezone - -from google.adk.events.event_actions import EventActions -from google.adk.sessions import database_session_service as dss -from google.adk.sessions.migration import _schema_check -from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - - -def test_migrate_from_sqlalchemy_pickle(tmp_path): - """Tests for migrate_from_sqlalchemy_pickle.""" - source_db_path = tmp_path / "source_pickle.db" - dest_db_path = tmp_path / "dest_json.db" - source_db_url = f"sqlite:///{source_db_path}" - dest_db_url = f"sqlite:///{dest_db_path}" - - # Setup source DB with old pickle schema - source_engine = create_engine(source_db_url) - mfsp.OldBase.metadata.create_all(source_engine) - SourceSession = sessionmaker(bind=source_engine) - source_session = SourceSession() - - # Populate source data - now = datetime.now(timezone.utc) - app_state = mfsp.OldStorageAppState( - app_name="app1", state={"akey": 1}, update_time=now - ) - user_state = mfsp.OldStorageUserState( - app_name="app1", user_id="user1", state={"ukey": 2}, update_time=now - ) - session = mfsp.OldStorageSession( - app_name="app1", - user_id="user1", - id="session1", - state={"skey": 3}, - create_time=now, - update_time=now, - ) - event = mfsp.OldStorageEvent( - id="event1", - app_name="app1", - user_id="user1", - session_id="session1", - invocation_id="invoke1", - author="user", - actions=EventActions(state_delta={"skey": 4}), - timestamp=now, - ) - source_session.add_all([app_state, user_state, session, event]) - source_session.commit() - source_session.close() - - mfsp.migrate(source_db_url, dest_db_url) - - # Verify destination DB - dest_engine = create_engine(dest_db_url) - DestSession = sessionmaker(bind=dest_engine) - dest_session = DestSession() - - metadata = dest_session.query(dss.StorageMetadata).first() - assert metadata is not None - assert metadata.key == _schema_check.SCHEMA_VERSION_KEY - assert metadata.value == _schema_check.SCHEMA_VERSION_1_0_JSON - - app_state_res = dest_session.query(dss.StorageAppState).first() - assert app_state_res is not None - assert app_state_res.app_name == "app1" - assert app_state_res.state == {"akey": 1} - - user_state_res = dest_session.query(dss.StorageUserState).first() - assert user_state_res is not None - assert user_state_res.user_id == "user1" - assert user_state_res.state == {"ukey": 2} - - session_res = dest_session.query(dss.StorageSession).first() - assert session_res is not None - assert session_res.id == "session1" - assert session_res.state == {"skey": 3} - - event_res = dest_session.query(dss.StorageEvent).first() - assert event_res is not None - assert event_res.id == "event1" - assert "state_delta" in event_res.event_data["actions"] - assert event_res.event_data["actions"]["state_delta"] == {"skey": 4} - - dest_session.close() diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py new file mode 100644 index 0000000000..e4eb084f88 --- /dev/null +++ b/tests/unittests/sessions/test_dynamic_pickle_type.py @@ -0,0 +1,181 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pickle +from unittest import mock + +from google.adk.sessions.database_session_service import DynamicPickleType +import pytest +from sqlalchemy import create_engine +from sqlalchemy.dialects import mysql + + +@pytest.fixture +def pickle_type(): + """Fixture for DynamicPickleType instance.""" + return DynamicPickleType() + + +def test_load_dialect_impl_mysql(pickle_type): + """Test that MySQL dialect uses LONGBLOB.""" + # Mock the MySQL dialect + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + # Mock the return value of type_descriptor + mock_longblob_type = mock.Mock() + mock_dialect.type_descriptor.return_value = mock_longblob_type + + impl = pickle_type.load_dialect_impl(mock_dialect) + + # Verify type_descriptor was called once with mysql.LONGBLOB + mock_dialect.type_descriptor.assert_called_once_with(mysql.LONGBLOB) + # Verify the return value is what we expect + assert impl == mock_longblob_type + + +def test_load_dialect_impl_spanner(pickle_type): + """Test that Spanner dialect uses SpannerPickleType.""" + # Mock the spanner dialect + mock_dialect = mock.Mock() + mock_dialect.name = "spanner+spanner" + + with mock.patch( + "google.cloud.sqlalchemy_spanner.sqlalchemy_spanner.SpannerPickleType" + ) as mock_spanner_type: + pickle_type.load_dialect_impl(mock_dialect) + mock_dialect.type_descriptor.assert_called_once_with(mock_spanner_type) + + +def test_load_dialect_impl_default(pickle_type): + """Test that other dialects use default PickleType.""" + engine = create_engine("sqlite:///:memory:") + dialect = engine.dialect + impl = pickle_type.load_dialect_impl(dialect) + # Should return the default impl (PickleType) + assert impl == pickle_type.impl + + +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_process_bind_param_pickle_dialects(pickle_type, dialect_name): + """Test that MySQL and Spanner dialects pickle the value.""" + mock_dialect = mock.Mock() + mock_dialect.name = dialect_name + + test_data = {"key": "value", "nested": [1, 2, 3]} + result = pickle_type.process_bind_param(test_data, mock_dialect) + + # Should be pickled bytes + assert isinstance(result, bytes) + # Should be able to unpickle back to original + assert pickle.loads(result) == test_data + + +def test_process_bind_param_default(pickle_type): + """Test that other dialects return value as-is.""" + mock_dialect = mock.Mock() + mock_dialect.name = "sqlite" + + test_data = {"key": "value"} + result = pickle_type.process_bind_param(test_data, mock_dialect) + + # Should return value unchanged (SQLAlchemy's PickleType handles it) + assert result == test_data + + +def test_process_bind_param_none(pickle_type): + """Test that None values are handled correctly.""" + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + result = pickle_type.process_bind_param(None, mock_dialect) + assert result is None + + +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_process_result_value_pickle_dialects(pickle_type, dialect_name): + """Test that MySQL and Spanner dialects unpickle the value.""" + mock_dialect = mock.Mock() + mock_dialect.name = dialect_name + + test_data = {"key": "value", "nested": [1, 2, 3]} + pickled_data = pickle.dumps(test_data) + + result = pickle_type.process_result_value(pickled_data, mock_dialect) + + # Should be unpickled back to original + assert result == test_data + + +def test_process_result_value_default(pickle_type): + """Test that other dialects return value as-is.""" + mock_dialect = mock.Mock() + mock_dialect.name = "sqlite" + + test_data = {"key": "value"} + result = pickle_type.process_result_value(test_data, mock_dialect) + + # Should return value unchanged (SQLAlchemy's PickleType handles it) + assert result == test_data + + +def test_process_result_value_none(pickle_type): + """Test that None values are handled correctly.""" + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + result = pickle_type.process_result_value(None, mock_dialect) + assert result is None + + +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_roundtrip_pickle_dialects(pickle_type, dialect_name): + """Test full roundtrip for MySQL and Spanner: bind -> result.""" + mock_dialect = mock.Mock() + mock_dialect.name = dialect_name + + original_data = { + "string": "test", + "number": 42, + "list": [1, 2, 3], + "nested": {"a": 1, "b": 2}, + } + + # Simulate bind (Python -> DB) + bound_value = pickle_type.process_bind_param(original_data, mock_dialect) + assert isinstance(bound_value, bytes) + + # Simulate result (DB -> Python) + result_value = pickle_type.process_result_value(bound_value, mock_dialect) + assert result_value == original_data From 5947c41b554aca905e795b49aefc60b6c85be05f Mon Sep 17 00:00:00 2001 From: Bo Yang Date: Wed, 3 Dec 2025 13:40:21 -0800 Subject: [PATCH 4/4] chore: Update component owners Co-authored-by: Bo Yang PiperOrigin-RevId: 839896507 --- .../samples/adk_triaging_agent/agent.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/contributing/samples/adk_triaging_agent/agent.py b/contributing/samples/adk_triaging_agent/agent.py index d3e653f1d0..19096ce8eb 100644 --- a/contributing/samples/adk_triaging_agent/agent.py +++ b/contributing/samples/adk_triaging_agent/agent.py @@ -26,20 +26,21 @@ import requests LABEL_TO_OWNER = { + "a2a": "seanzhou1023", "agent engine": "yeesian", - "documentation": "polong-lin", - "services": "DeanChensj", - "question": "", - "mcp": "seanzhou1023", - "tools": "seanzhou1023", + "auth": "seanzhou1023", + "bq": "shobsi", + "core": "Jacksunwei", + "documentation": "joefernandez", "eval": "ankursharmas", - "live": "hangfei", - "models": "genquan9", + "live": "seanzhou1023", + "mcp": "seanzhou1023", + "models": "xuanyang15", + "services": "DeanChensj", + "tools": "xuanyang15", "tracing": "jawoszek", - "core": "Jacksunwei", "web": "wyf7107", - "a2a": "seanzhou1023", - "bq": "shobsi", + "workflow": "DeanChensj", } LABEL_GUIDELINES = """ @@ -65,6 +66,8 @@ Agent Engine concepts, do not use this label—choose "core" instead. - "a2a": Agent-to-agent workflows, coordination logic, or A2A protocol. - "bq": BigQuery integration or general issues related to BigQuery. + - "workflow": Workflow agents and workflow execution. + - "auth": Authentication or authorization issues. When unsure between labels, prefer the most specific match. If a label cannot be assigned confidently, do not call the labeling tool. @@ -265,6 +268,8 @@ def change_issue_type(issue_number: int, issue_type: str) -> dict[str, Any]: - If it's about Model Context Protocol (e.g. MCP tool, MCP toolset, MCP session management etc.), label it with both "mcp" and "tools". - If it's about A2A integrations or workflows, label it with "a2a". - If it's about BigQuery integrations, label it with "bq". + - If it's about workflow agents or workflow execution, label it with "workflow". + - If it's about authentication, label it with "auth". - If you can't find an appropriate labels for the issue, follow the previous instruction that starts with "IMPORTANT:". ## Triaging Workflow