From d29261a3dc9c5a603feef27ea657c4a03bb8a089 Mon Sep 17 00:00:00 2001 From: Virtuoso633 Date: Tue, 25 Nov 2025 09:46:18 -0800 Subject: [PATCH 1/5] feat(models): Enable multi-provider support for Claude and LiteLLM Merges: https://github.com/google/adk-python/pull/2810 Co-authored-by: Xuan Yang PiperOrigin-RevId: 836706608 --- src/google/adk/models/__init__.py | 20 +++++++ src/google/adk/models/lite_llm.py | 14 ++++- src/google/adk/models/registry.py | 24 +++++++- .../unittests/agents/test_llm_agent_fields.py | 47 +++++++++++++++ tests/unittests/models/test_models.py | 58 ++++++++++++++++++- 5 files changed, 156 insertions(+), 7 deletions(-) diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py index 9f3c2a2c48..1be0cc698e 100644 --- a/src/google/adk/models/__init__.py +++ b/src/google/adk/models/__init__.py @@ -33,3 +33,23 @@ LLMRegistry.register(Gemini) LLMRegistry.register(Gemma) LLMRegistry.register(ApigeeLlm) + +# Optionally register Claude if anthropic package is installed +try: + from .anthropic_llm import Claude + + LLMRegistry.register(Claude) + __all__.append('Claude') +except Exception: + # Claude support requires: pip install google-adk[extensions] + pass + +# Optionally register LiteLlm if litellm package is installed +try: + from .lite_llm import LiteLlm + + LLMRegistry.register(LiteLlm) + __all__.append('LiteLlm') +except Exception: + # LiteLLM support requires: pip install google-adk[extensions] + pass diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 9e3698b190..162db05945 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -1388,11 +1388,19 @@ async def generate_content_async( def supported_models(cls) -> list[str]: """Provides the list of supported models. - LiteLlm supports all models supported by litellm. We do not keep track of - these models here. So we return an empty list. + This registers common provider prefixes. LiteLlm can handle many more, + but these patterns activate the integration for the most common use cases. + See https://docs.litellm.ai/docs/providers for a full list. Returns: A list of supported models. """ - return [] + return [ + # For OpenAI models (e.g., "openai/gpt-4o") + r"openai/.*", + # For Groq models via Groq API (e.g., "groq/llama3-70b-8192") + r"groq/.*", + # For Anthropic models (e.g., "anthropic/claude-3-opus-20240229") + r"anthropic/.*", + ] diff --git a/src/google/adk/models/registry.py b/src/google/adk/models/registry.py index 22e24d4c18..852996ff40 100644 --- a/src/google/adk/models/registry.py +++ b/src/google/adk/models/registry.py @@ -99,4 +99,26 @@ def resolve(model: str) -> type[BaseLlm]: if re.compile(regex).fullmatch(model): return llm_class - raise ValueError(f'Model {model} not found.') + # Provide helpful error messages for known patterns + error_msg = f'Model {model} not found.' + + # Check if it matches known patterns that require optional dependencies + if re.match(r'^claude-', model): + error_msg += ( + '\n\nClaude models require the anthropic package.' + '\nInstall it with: pip install google-adk[extensions]' + '\nOr: pip install anthropic>=0.43.0' + ) + elif '/' in model: + # Any model with provider/model format likely needs LiteLLM + error_msg += ( + '\n\nProvider-style models (e.g., "provider/model-name") require' + ' the litellm package.' + '\nInstall it with: pip install google-adk[extensions]' + '\nOr: pip install litellm>=1.75.5' + '\n\nSupported providers include: openai, groq, anthropic, and 100+' + ' others.' + '\nSee https://docs.litellm.ai/docs/providers for a full list.' + ) + + raise ValueError(error_msg) diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index c57254dbc8..577923f7bf 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -22,6 +22,9 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.models.anthropic_llm import Claude +from google.adk.models.google_llm import Gemini +from google.adk.models.lite_llm import LiteLlm from google.adk.models.llm_request import LlmRequest from google.adk.models.registry import LLMRegistry from google.adk.sessions.in_memory_session_service import InMemorySessionService @@ -411,3 +414,47 @@ async def test_handle_vais_only(self): assert len(tools) == 1 assert tools[0].name == 'vertex_ai_search' assert tools[0].__class__.__name__ == 'VertexAiSearchTool' + + +# Tests for multi-provider model support via string model names +@pytest.mark.parametrize( + 'model_name', + [ + 'gemini-1.5-flash', + 'gemini-2.0-flash-exp', + ], +) +def test_agent_with_gemini_string_model(model_name): + """Test that Agent accepts Gemini model strings and resolves to Gemini.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, Gemini) + assert agent.canonical_model.model == model_name + + +@pytest.mark.parametrize( + 'model_name', + [ + 'claude-3-5-sonnet-v2@20241022', + 'claude-sonnet-4@20250514', + ], +) +def test_agent_with_claude_string_model(model_name): + """Test that Agent accepts Claude model strings and resolves to Claude.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, Claude) + assert agent.canonical_model.model == model_name + + +@pytest.mark.parametrize( + 'model_name', + [ + 'openai/gpt-4o', + 'groq/llama3-70b-8192', + 'anthropic/claude-3-opus-20240229', + ], +) +def test_agent_with_litellm_string_model(model_name): + """Test that Agent accepts LiteLLM provider strings.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, LiteLlm) + assert agent.canonical_model.model == model_name diff --git a/tests/unittests/models/test_models.py b/tests/unittests/models/test_models.py index 70246c7bc1..8575064baa 100644 --- a/tests/unittests/models/test_models.py +++ b/tests/unittests/models/test_models.py @@ -15,7 +15,7 @@ from google.adk import models from google.adk.models.anthropic_llm import Claude from google.adk.models.google_llm import Gemini -from google.adk.models.registry import LLMRegistry +from google.adk.models.lite_llm import LiteLlm import pytest @@ -34,6 +34,7 @@ ], ) def test_match_gemini_family(model_name): + """Test that Gemini models are resolved correctly.""" assert models.LLMRegistry.resolve(model_name) is Gemini @@ -51,12 +52,63 @@ def test_match_gemini_family(model_name): ], ) def test_match_claude_family(model_name): - LLMRegistry.register(Claude) - + """Test that Claude models are resolved correctly.""" assert models.LLMRegistry.resolve(model_name) is Claude +@pytest.mark.parametrize( + 'model_name', + [ + 'openai/gpt-4o', + 'openai/gpt-4o-mini', + 'groq/llama3-70b-8192', + 'groq/mixtral-8x7b-32768', + 'anthropic/claude-3-opus-20240229', + 'anthropic/claude-3-5-sonnet-20241022', + ], +) +def test_match_litellm_family(model_name): + """Test that LiteLLM models are resolved correctly.""" + assert models.LLMRegistry.resolve(model_name) is LiteLlm + + def test_non_exist_model(): with pytest.raises(ValueError) as e_info: models.LLMRegistry.resolve('non-exist-model') assert 'Model non-exist-model not found.' in str(e_info.value) + + +def test_helpful_error_for_claude_without_extensions(): + """Test that missing Claude models show helpful install instructions. + + Note: This test may pass even when anthropic IS installed, because it + only checks the error message format when a model is not found. + """ + # Use a non-existent Claude model variant to trigger error + with pytest.raises(ValueError) as e_info: + models.LLMRegistry.resolve('claude-nonexistent-model-xyz') + + error_msg = str(e_info.value) + # The error should mention anthropic package and installation instructions + # These checks work whether or not anthropic is actually installed + assert 'Model claude-nonexistent-model-xyz not found' in error_msg + assert 'anthropic package' in error_msg + assert 'pip install' in error_msg + + +def test_helpful_error_for_litellm_without_extensions(): + """Test that missing LiteLLM models show helpful install instructions. + + Note: This test may pass even when litellm IS installed, because it + only checks the error message format when a model is not found. + """ + # Use a non-existent provider to trigger error + with pytest.raises(ValueError) as e_info: + models.LLMRegistry.resolve('unknown-provider/gpt-4o') + + error_msg = str(e_info.value) + # The error should mention litellm package for provider-style models + assert 'Model unknown-provider/gpt-4o not found' in error_msg + assert 'litellm package' in error_msg + assert 'pip install' in error_msg + assert 'Provider-style models' in error_msg From 5cad8a7f58b36ca8ae0e5db2d0a8fb8718d330fd Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Tue, 25 Nov 2025 10:04:23 -0800 Subject: [PATCH 2/5] fix: Throw warning when using transparent session resumption in ADK Live for Gemini API key transparent session resumption is only supported in Vertex AI APIs Co-authored-by: Hangfei Lin PiperOrigin-RevId: 836715170 --- src/google/adk/models/gemini_llm_connection.py | 2 +- src/google/adk/models/google_llm.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 0b72c79f83..15e6ed9599 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -244,7 +244,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ] yield LlmResponse(content=types.Content(role='model', parts=parts)) if message.session_resumption_update: - logger.info('Received session resumption message: %s', message) + logger.debug('Received session resumption message: %s', message) yield ( LlmResponse( live_session_resumption_update=message.session_resumption_update diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 1bdd311104..90c2fece76 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -325,6 +325,23 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: types.Part.from_text(text=llm_request.config.system_instruction) ], ) + if ( + llm_request.live_connect_config.session_resumption + and llm_request.live_connect_config.session_resumption.transparent + ): + logger.debug( + 'session resumption config: %s', + llm_request.live_connect_config.session_resumption, + ) + logger.debug( + 'self._api_backend: %s', + self._api_backend, + ) + if self._api_backend == GoogleLLMVariant.GEMINI_API: + raise ValueError( + 'Transparent session resumption is only supported for Vertex AI' + ' backend. Please use Vertex AI backend.' + ) llm_request.live_connect_config.tools = llm_request.config.tools logger.info('Connecting to live for model: %s', llm_request.model) logger.debug('Connecting to live with llm_request:%s', llm_request) From 5453b5bfdedc91d9d668c9eac39e3bb009a7bbbf Mon Sep 17 00:00:00 2001 From: happyryan Date: Tue, 25 Nov 2025 10:29:38 -0800 Subject: [PATCH 3/5] fix: Allow image parts in user messages for Anthropic Claude Previously, image parts were always filtered out when converting content to Anthropic message parameters. This change updates the logic to only filter out image parts and log a warning when the content role is not "user". This enables sending image data as part of user prompts to Claude models Merges: https://github.com/google/adk-python/pull/3286 Co-authored-by: George Weale PiperOrigin-RevId: 836725196 --- src/google/adk/models/anthropic_llm.py | 8 +- tests/unittests/models/test_anthropic_llm.py | 79 ++++++++++++++++++++ 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index 6f343367a3..f965a9906d 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -155,9 +155,11 @@ def content_to_message_param( ) -> anthropic_types.MessageParam: message_block = [] for part in content.parts or []: - # Image data is not supported in Claude for model turns. - if _is_image_part(part): - logger.warning("Image data is not supported in Claude for model turns.") + # Image data is not supported in Claude for assistant turns. + if content.role != "user" and _is_image_part(part): + logger.warning( + "Image data is not supported in Claude for assistant turns." + ) continue message_block.append(part_to_message_block(part)) diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py index e5ac8cc051..13d615bc32 100644 --- a/tests/unittests/models/test_anthropic_llm.py +++ b/tests/unittests/models/test_anthropic_llm.py @@ -20,6 +20,7 @@ from google.adk import version as adk_version from google.adk.models import anthropic_llm from google.adk.models.anthropic_llm import Claude +from google.adk.models.anthropic_llm import content_to_message_param from google.adk.models.anthropic_llm import function_declaration_to_tool_param from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse @@ -462,3 +463,81 @@ def test_part_to_message_block_with_multiple_content_items(): assert isinstance(result, dict) # Multiple text items should be joined with newlines assert result["content"] == "First part\nSecond part" + + +content_to_message_param_test_cases = [ + ( + "user_role_with_text_and_image", + Content( + role="user", + parts=[ + Part.from_text(text="What's in this image?"), + Part( + inline_data=types.Blob( + mime_type="image/jpeg", data=b"fake_image_data" + ) + ), + ], + ), + "user", + 2, # Expected content length + False, # Should not log warning + ), + ( + "model_role_with_text_and_image", + Content( + role="model", + parts=[ + Part.from_text(text="I see a cat."), + Part( + inline_data=types.Blob( + mime_type="image/png", data=b"fake_image_data" + ) + ), + ], + ), + "assistant", + 1, # Image filtered out, only text remains + True, # Should log warning + ), + ( + "assistant_role_with_text_and_image", + Content( + role="assistant", + parts=[ + Part.from_text(text="Here's what I found."), + Part( + inline_data=types.Blob( + mime_type="image/webp", data=b"fake_image_data" + ) + ), + ], + ), + "assistant", + 1, # Image filtered out, only text remains + True, # Should log warning + ), +] + + +@pytest.mark.parametrize( + "_, content, expected_role, expected_content_length, should_log_warning", + content_to_message_param_test_cases, + ids=[case[0] for case in content_to_message_param_test_cases], +) +def test_content_to_message_param_with_images( + _, content, expected_role, expected_content_length, should_log_warning +): + """Test content_to_message_param handles images correctly based on role.""" + with mock.patch("google.adk.models.anthropic_llm.logger") as mock_logger: + result = content_to_message_param(content) + + assert result["role"] == expected_role + assert len(result["content"]) == expected_content_length + + if should_log_warning: + mock_logger.warning.assert_called_once_with( + "Image data is not supported in Claude for assistant turns." + ) + else: + mock_logger.warning.assert_not_called() From 06e6fc91327a8bcea1bdc72f8eee94ee05cbbb91 Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 25 Nov 2025 10:47:17 -0800 Subject: [PATCH 4/5] feat: wire runtime entrypoints to service factory defaults This change routes adk run and the FastAPI server through the new session/artifact service factory, keeps the default experience backed by per-agent .adk storage Co-authored-by: George Weale PiperOrigin-RevId: 836733234 --- src/google/adk/cli/cli.py | 79 ++++++--- src/google/adk/cli/fast_api.py | 63 +++---- src/google/adk/cli/utils/service_factory.py | 138 +++++++++++++++ tests/unittests/cli/test_fast_api.py | 12 +- tests/unittests/cli/utils/test_cli.py | 131 ++++++++++++-- .../cli/utils/test_service_factory.py | 162 ++++++++++++++++++ 6 files changed, 508 insertions(+), 77 deletions(-) create mode 100644 src/google/adk/cli/utils/service_factory.py create mode 100644 tests/unittests/cli/utils/test_service_factory.py diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 5ae18aac0a..af57a687fb 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -15,6 +15,7 @@ from __future__ import annotations from datetime import datetime +from pathlib import Path from typing import Optional from typing import Union @@ -22,7 +23,6 @@ from google.genai import types from pydantic import BaseModel -from ..agents.base_agent import BaseAgent from ..agents.llm_agent import LlmAgent from ..apps.app import App from ..artifacts.base_artifact_service import BaseArtifactService @@ -35,8 +35,11 @@ from ..sessions.session import Session from ..utils.context_utils import Aclosing from ..utils.env_utils import is_env_enabled +from .service_registry import load_services_module from .utils import envs from .utils.agent_loader import AgentLoader +from .utils.service_factory import create_artifact_service_from_options +from .utils.service_factory import create_session_service_from_options class InputFile(BaseModel): @@ -66,7 +69,7 @@ async def run_input_file( ) with open(input_path, 'r', encoding='utf-8') as f: input_file = InputFile.model_validate_json(f.read()) - input_file.state['_time'] = datetime.now() + input_file.state['_time'] = datetime.now().isoformat() session = await session_service.create_session( app_name=app_name, user_id=user_id, state=input_file.state @@ -134,6 +137,8 @@ async def run_cli( saved_session_file: Optional[str] = None, save_session: bool, session_id: Optional[str] = None, + session_service_uri: Optional[str] = None, + artifact_service_uri: Optional[str] = None, ) -> None: """Runs an interactive CLI for a certain agent. @@ -148,24 +153,47 @@ async def run_cli( contains a previously saved session, exclusive with input_file. save_session: bool, whether to save the session on exit. session_id: Optional[str], the session ID to save the session to on exit. + session_service_uri: Optional[str], custom session service URI. + artifact_service_uri: Optional[str], custom artifact service URI. """ + agent_parent_path = Path(agent_parent_dir).resolve() + agent_root = agent_parent_path / agent_folder_name + load_services_module(str(agent_root)) + user_id = 'test_user' - artifact_service = InMemoryArtifactService() - session_service = InMemorySessionService() - credential_service = InMemoryCredentialService() + # Create session and artifact services using factory functions + session_service = create_session_service_from_options( + base_dir=agent_root, + session_service_uri=session_service_uri, + ) - user_id = 'test_user' - agent_or_app = AgentLoader(agents_dir=agent_parent_dir).load_agent( + artifact_service = create_artifact_service_from_options( + base_dir=agent_root, + artifact_service_uri=artifact_service_uri, + ) + + credential_service = InMemoryCredentialService() + agents_dir = str(agent_parent_path) + agent_or_app = AgentLoader(agents_dir=agents_dir).load_agent( agent_folder_name ) session_app_name = ( agent_or_app.name if isinstance(agent_or_app, App) else agent_folder_name ) - session = await session_service.create_session( - app_name=session_app_name, user_id=user_id - ) if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'): - envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) + envs.load_dotenv_for_agent(agent_folder_name, agents_dir) + + # Helper function for printing events + def _print_event(event) -> None: + content = event.content + if not content or not content.parts: + return + text_parts = [part.text for part in content.parts if part.text] + if not text_parts: + return + author = event.author or 'system' + click.echo(f'[{author}]: {"".join(text_parts)}') + if input_file: session = await run_input_file( app_name=session_app_name, @@ -177,16 +205,22 @@ async def run_cli( input_path=input_file, ) elif saved_session_file: + # Load the saved session from file with open(saved_session_file, 'r', encoding='utf-8') as f: loaded_session = Session.model_validate_json(f.read()) + # Create a new session in the service, copying state from the file + session = await session_service.create_session( + app_name=session_app_name, + user_id=user_id, + state=loaded_session.state if loaded_session else None, + ) + + # Append events from the file to the new session and display them if loaded_session: for event in loaded_session.events: await session_service.append_event(session, event) - content = event.content - if not content or not content.parts or not content.parts[0].text: - continue - click.echo(f'[{event.author}]: {content.parts[0].text}') + _print_event(event) await run_interactively( agent_or_app, @@ -196,6 +230,9 @@ async def run_cli( credential_service, ) else: + session = await session_service.create_session( + app_name=session_app_name, user_id=user_id + ) click.echo(f'Running agent {agent_or_app.name}, type exit to exit.') await run_interactively( agent_or_app, @@ -207,9 +244,7 @@ async def run_cli( if save_session: session_id = session_id or input('Session ID to save: ') - session_path = ( - f'{agent_parent_dir}/{agent_folder_name}/{session_id}.session.json' - ) + session_path = agent_root / f'{session_id}.session.json' # Fetch the session again to get all the details. session = await session_service.get_session( @@ -217,9 +252,9 @@ async def run_cli( user_id=session.user_id, session_id=session.id, ) - with open(session_path, 'w', encoding='utf-8') as f: - f.write( - session.model_dump_json(indent=2, exclude_none=True, by_alias=True) - ) + session_path.write_text( + session.model_dump_json(indent=2, exclude_none=True, by_alias=True), + encoding='utf-8', + ) print('Session saved to', session_path) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index eec6bb646b..86c7ca55c6 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -34,20 +34,19 @@ from starlette.types import Lifespan from watchdog.observers import Observer -from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager -from ..memory.in_memory_memory_service import InMemoryMemoryService from ..runners import Runner -from ..sessions.in_memory_session_service import InMemorySessionService from .adk_web_server import AdkWebServer -from .service_registry import get_service_registry from .service_registry import load_services_module from .utils import envs from .utils import evals from .utils.agent_change_handler import AgentChangeEventHandler from .utils.agent_loader import AgentLoader +from .utils.service_factory import create_artifact_service_from_options +from .utils.service_factory import create_memory_service_from_options +from .utils.service_factory import create_session_service_from_options logger = logging.getLogger("google_adk." + __name__) @@ -74,6 +73,8 @@ def get_fast_api_app( logo_text: Optional[str] = None, logo_image_url: Optional[str] = None, ) -> FastAPI: + # Convert to absolute path for consistency + agents_dir = str(Path(agents_dir).resolve()) # Set up eval managers. if eval_storage_uri: @@ -91,48 +92,32 @@ def get_fast_api_app( # Load services.py from agents_dir for custom service registration. load_services_module(agents_dir) - service_registry = get_service_registry() - # Build the Memory service - if memory_service_uri: - memory_service = service_registry.create_memory_service( - memory_service_uri, agents_dir=agents_dir + try: + memory_service = create_memory_service_from_options( + base_dir=agents_dir, + memory_service_uri=memory_service_uri, ) - if not memory_service: - raise click.ClickException( - "Unsupported memory service URI: %s" % memory_service_uri - ) - else: - memory_service = InMemoryMemoryService() + except ValueError as exc: + raise click.ClickException(str(exc)) from exc # Build the Session service - if session_service_uri: - session_kwargs = session_db_kwargs or {} - session_service = service_registry.create_session_service( - session_service_uri, agents_dir=agents_dir, **session_kwargs - ) - if not session_service: - # Fallback to DatabaseSessionService if the service registry doesn't - # support the session service URI scheme. - from ..sessions.database_session_service import DatabaseSessionService - - session_service = DatabaseSessionService( - db_url=session_service_uri, **session_kwargs - ) - else: - session_service = InMemorySessionService() + session_service = create_session_service_from_options( + base_dir=agents_dir, + session_service_uri=session_service_uri, + session_db_kwargs=session_db_kwargs, + per_agent=True, # Multi-agent mode + ) # Build the Artifact service - if artifact_service_uri: - artifact_service = service_registry.create_artifact_service( - artifact_service_uri, agents_dir=agents_dir + try: + artifact_service = create_artifact_service_from_options( + base_dir=agents_dir, + artifact_service_uri=artifact_service_uri, + per_agent=True, # Multi-agent mode ) - if not artifact_service: - raise click.ClickException( - "Unsupported artifact service URI: %s" % artifact_service_uri - ) - else: - artifact_service = InMemoryArtifactService() + except ValueError as exc: + raise click.ClickException(str(exc)) from exc # Build the Credential service credential_service = InMemoryCredentialService() diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py new file mode 100644 index 0000000000..50064f4b8f --- /dev/null +++ b/src/google/adk/cli/utils/service_factory.py @@ -0,0 +1,138 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any +from typing import Optional + +from ...artifacts.base_artifact_service import BaseArtifactService +from ...memory.base_memory_service import BaseMemoryService +from ...sessions.base_session_service import BaseSessionService +from ..service_registry import get_service_registry +from .local_storage import create_local_artifact_service + +logger = logging.getLogger("google_adk." + __name__) + + +def create_session_service_from_options( + *, + base_dir: Path | str, + session_service_uri: Optional[str] = None, + session_db_kwargs: Optional[dict[str, Any]] = None, + per_agent: bool = False, +) -> BaseSessionService: + """Creates a session service based on CLI/web options.""" + base_path = Path(base_dir) + registry = get_service_registry() + + kwargs: dict[str, Any] = { + "agents_dir": str(base_path), + "per_agent": per_agent, + } + if session_db_kwargs: + kwargs.update(session_db_kwargs) + + if session_service_uri: + if per_agent: + logger.warning( + "per_agent is not supported with remote session service URIs," + " ignoring" + ) + logger.info("Using session service URI: %s", session_service_uri) + service = registry.create_session_service(session_service_uri, **kwargs) + if service is not None: + return service + + # Fallback to DatabaseSessionService if the registry doesn't support the + # session service URI scheme. This keeps support for SQLAlchemy-compatible + # databases like AlloyDB or Cloud Spanner without explicit registration. + from ...sessions.database_session_service import DatabaseSessionService + + fallback_kwargs = dict(kwargs) + fallback_kwargs.pop("agents_dir", None) + fallback_kwargs.pop("per_agent", None) + logger.info( + "Falling back to DatabaseSessionService for URI: %s", + session_service_uri, + ) + return DatabaseSessionService(db_url=session_service_uri, **fallback_kwargs) + + logger.info("Using in-memory session service") + from ...sessions.in_memory_session_service import InMemorySessionService + + return InMemorySessionService() + + +def create_memory_service_from_options( + *, + base_dir: Path | str, + memory_service_uri: Optional[str] = None, +) -> BaseMemoryService: + """Creates a memory service based on CLI/web options.""" + base_path = Path(base_dir) + registry = get_service_registry() + + if memory_service_uri: + logger.info("Using memory service URI: %s", memory_service_uri) + service = registry.create_memory_service( + memory_service_uri, + agents_dir=str(base_path), + ) + if service is None: + raise ValueError(f"Unsupported memory service URI: {memory_service_uri}") + return service + + logger.info("Using in-memory memory service") + from ...memory.in_memory_memory_service import InMemoryMemoryService + + return InMemoryMemoryService() + + +def create_artifact_service_from_options( + *, + base_dir: Path | str, + artifact_service_uri: Optional[str] = None, + per_agent: bool = False, +) -> BaseArtifactService: + """Creates an artifact service based on CLI/web options.""" + base_path = Path(base_dir) + registry = get_service_registry() + + if artifact_service_uri: + if per_agent: + logger.warning( + "per_agent is not supported with remote artifact service URIs," + " ignoring" + ) + logger.info("Using artifact service URI: %s", artifact_service_uri) + service = registry.create_artifact_service( + artifact_service_uri, + agents_dir=str(base_path), + per_agent=per_agent, + ) + if service is None: + logger.warning( + "Unsupported artifact service URI: %s, falling back to in-memory", + artifact_service_uri, + ) + from ...artifacts.in_memory_artifact_service import InMemoryArtifactService + + return InMemoryArtifactService() + return service + + if per_agent: + logger.info("Using shared file artifact service rooted at %s", base_dir) + return create_local_artifact_service(base_dir=base_path, per_agent=per_agent) diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index d50bfcd8e5..a8b1ef2f2f 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -327,15 +327,15 @@ def test_app( with ( patch("signal.signal", return_value=None), patch( - "google.adk.cli.fast_api.InMemorySessionService", + "google.adk.cli.fast_api.create_session_service_from_options", return_value=mock_session_service, ), patch( - "google.adk.cli.fast_api.InMemoryArtifactService", + "google.adk.cli.fast_api.create_artifact_service_from_options", return_value=mock_artifact_service, ), patch( - "google.adk.cli.fast_api.InMemoryMemoryService", + "google.adk.cli.fast_api.create_memory_service_from_options", return_value=mock_memory_service, ), patch( @@ -472,15 +472,15 @@ def test_app_with_a2a( with ( patch("signal.signal", return_value=None), patch( - "google.adk.cli.fast_api.InMemorySessionService", + "google.adk.cli.fast_api.create_session_service_from_options", return_value=mock_session_service, ), patch( - "google.adk.cli.fast_api.InMemoryArtifactService", + "google.adk.cli.fast_api.create_artifact_service_from_options", return_value=mock_artifact_service, ), patch( - "google.adk.cli.fast_api.InMemoryMemoryService", + "google.adk.cli.fast_api.create_memory_service_from_options", return_value=mock_memory_service, ), patch( diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 0de59598b3..33ddbf495c 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -28,7 +28,12 @@ import click from google.adk.agents.base_agent import BaseAgent from google.adk.apps.app import App +from google.adk.artifacts.file_artifact_service import FileArtifactService +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService import google.adk.cli.cli as cli +from google.adk.cli.utils.service_factory import create_artifact_service_from_options +from google.adk.sessions.in_memory_session_service import InMemorySessionService import pytest @@ -151,9 +156,9 @@ def _echo(msg: str) -> None: input_path = tmp_path / "input.json" input_path.write_text(json.dumps(input_json)) - artifact_service = cli.InMemoryArtifactService() - session_service = cli.InMemorySessionService() - credential_service = cli.InMemoryCredentialService() + artifact_service = InMemoryArtifactService() + session_service = InMemorySessionService() + credential_service = InMemoryCredentialService() dummy_root = BaseAgent(name="root") session = await cli.run_input_file( @@ -189,6 +194,34 @@ async def test_run_cli_with_input_file(fake_agent, tmp_path: Path) -> None: ) +@pytest.mark.asyncio +async def test_run_cli_loads_services_module( + fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_cli should load custom services from the agents directory.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": ["ping"]} + input_path = tmp_path / "input.json" + input_path.write_text(json.dumps(input_json)) + + loaded_dirs: list[str] = [] + monkeypatch.setattr( + cli, "load_services_module", lambda path: loaded_dirs.append(path) + ) + + agent_root = parent_dir / folder_name + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + ) + + assert loaded_dirs == [str(agent_root.resolve())] + + @pytest.mark.asyncio async def test_run_cli_app_uses_app_name_for_sessions( fake_app_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch @@ -197,15 +230,20 @@ async def test_run_cli_app_uses_app_name_for_sessions( parent_dir, folder_name, app_name = fake_app_agent created_app_names: List[str] = [] - original_session_cls = cli.InMemorySessionService - - class _SpySessionService(original_session_cls): + class _SpySessionService(InMemorySessionService): async def create_session(self, *, app_name: str, **kwargs: Any) -> Any: created_app_names.append(app_name) return await super().create_session(app_name=app_name, **kwargs) - monkeypatch.setattr(cli, "InMemorySessionService", _SpySessionService) + spy_session_service = _SpySessionService() + + def _session_factory(**_: Any) -> InMemorySessionService: + return spy_session_service + + monkeypatch.setattr( + cli, "create_session_service_from_options", _session_factory + ) input_json = {"state": {}, "queries": ["ping"]} input_path = tmp_path / "input_app.json" @@ -253,16 +291,89 @@ async def test_run_cli_save_session( assert "id" in data and "events" in data +def test_create_artifact_service_defaults_to_file(tmp_path: Path) -> None: + """Service factory should default to FileArtifactService when URI is unset.""" + service = create_artifact_service_from_options(base_dir=tmp_path) + assert isinstance(service, FileArtifactService) + expected_root = Path(tmp_path) / ".adk" / "artifacts" + assert service.root_dir == expected_root + assert expected_root.exists() + + +def test_create_artifact_service_per_agent_uses_shared_root( + tmp_path: Path, +) -> None: + """Multi-agent mode should still use a single file artifact service.""" + service = create_artifact_service_from_options( + base_dir=tmp_path, per_agent=True + ) + assert isinstance(service, FileArtifactService) + expected_root = Path(tmp_path) / ".adk" / "artifacts" + assert service.root_dir == expected_root + assert expected_root.exists() + + +def test_create_artifact_service_respects_memory_uri(tmp_path: Path) -> None: + """Service factory should honor memory:// URIs.""" + service = create_artifact_service_from_options( + base_dir=tmp_path, artifact_service_uri="memory://" + ) + assert isinstance(service, InMemoryArtifactService) + + +def test_create_artifact_service_accepts_file_uri(tmp_path: Path) -> None: + """Service factory should allow custom local roots via file:// URIs.""" + custom_root = tmp_path / "custom_artifacts" + service = create_artifact_service_from_options( + base_dir=tmp_path, artifact_service_uri=custom_root.as_uri() + ) + assert isinstance(service, FileArtifactService) + assert service.root_dir == custom_root + assert custom_root.exists() + + +def test_create_artifact_service_file_uri_rejects_per_agent(tmp_path: Path): + """file:// URIs are incompatible with per-agent mode.""" + custom_root = tmp_path / "custom" + with pytest.raises(ValueError, match="multi-agent"): + create_artifact_service_from_options( + base_dir=tmp_path, + artifact_service_uri=custom_root.as_uri(), + per_agent=True, + ) + + +@pytest.mark.asyncio +async def test_run_cli_accepts_memory_scheme( + fake_agent, tmp_path: Path +) -> None: + """run_cli should allow configuring in-memory services via memory:// URIs.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": []} + input_path = tmp_path / "noop.json" + input_path.write_text(json.dumps(input_json)) + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + session_service_uri="memory://", + artifact_service_uri="memory://", + ) + + @pytest.mark.asyncio async def test_run_interactively_whitespace_and_exit( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: """run_interactively should skip blank input, echo once, then exit.""" # make a session that belongs to dummy agent - session_service = cli.InMemorySessionService() + session_service = InMemorySessionService() sess = await session_service.create_session(app_name="dummy", user_id="u") - artifact_service = cli.InMemoryArtifactService() - credential_service = cli.InMemoryCredentialService() + artifact_service = InMemoryArtifactService() + credential_service = InMemoryCredentialService() root_agent = BaseAgent(name="root") # fake user input: blank -> 'hello' -> 'exit' diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py new file mode 100644 index 0000000000..5ff92a076b --- /dev/null +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -0,0 +1,162 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for service factory helpers.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import Mock + +import google.adk.cli.utils.service_factory as service_factory +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.sessions.database_session_service import DatabaseSessionService +from google.adk.sessions.in_memory_session_service import InMemorySessionService +import pytest + + +def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): + registry = Mock() + expected = object() + registry.create_session_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri="sqlite:///test.db", + ) + + assert result is expected + registry.create_session_service.assert_called_once_with( + "sqlite:///test.db", + agents_dir=str(tmp_path), + per_agent=False, + ) + + +def test_create_session_service_per_agent_uri(tmp_path: Path, monkeypatch): + registry = Mock() + expected = object() + registry.create_session_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri="memory://", + per_agent=True, + ) + + assert result is expected + registry.create_session_service.assert_called_once_with( + "memory://", agents_dir=str(tmp_path), per_agent=True + ) + + +@pytest.mark.parametrize("per_agent", [True, False]) +def test_create_session_service_defaults_to_memory( + tmp_path: Path, per_agent: bool +): + service = service_factory.create_session_service_from_options( + base_dir=tmp_path, + per_agent=per_agent, + ) + + assert isinstance(service, InMemorySessionService) + + +def test_create_session_service_fallbacks_to_database( + tmp_path: Path, monkeypatch +): + registry = Mock() + registry.create_session_service.return_value = None + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + service = service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri="sqlite+aiosqlite:///:memory:", + session_db_kwargs={"echo": True}, + ) + + assert isinstance(service, DatabaseSessionService) + assert service.db_engine.url.drivername == "sqlite+aiosqlite" + assert service.db_engine.echo is True + registry.create_session_service.assert_called_once_with( + "sqlite+aiosqlite:///:memory:", + agents_dir=str(tmp_path), + per_agent=False, + echo=True, + ) + + +@pytest.mark.parametrize("per_agent", [True, False]) +def test_create_artifact_service_uses_registry( + tmp_path: Path, monkeypatch, per_agent: bool +): + registry = Mock() + expected = object() + registry.create_artifact_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_artifact_service_from_options( + base_dir=tmp_path, + artifact_service_uri="gs://bucket/path", + per_agent=per_agent, + ) + + assert result is expected + registry.create_artifact_service.assert_called_once_with( + "gs://bucket/path", + agents_dir=str(tmp_path), + per_agent=per_agent, + ) + + +def test_create_memory_service_uses_registry(tmp_path: Path, monkeypatch): + registry = Mock() + expected = object() + registry.create_memory_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_memory_service_from_options( + base_dir=tmp_path, + memory_service_uri="rag://my-corpus", + ) + + assert result is expected + registry.create_memory_service.assert_called_once_with( + "rag://my-corpus", + agents_dir=str(tmp_path), + ) + + +def test_create_memory_service_defaults_to_in_memory(tmp_path: Path): + service = service_factory.create_memory_service_from_options( + base_dir=tmp_path + ) + + assert isinstance(service, InMemoryMemoryService) + + +def test_create_memory_service_raises_on_unknown_scheme( + tmp_path: Path, monkeypatch +): + registry = Mock() + registry.create_memory_service.return_value = None + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + with pytest.raises(ValueError): + service_factory.create_memory_service_from_options( + base_dir=tmp_path, + memory_service_uri="unknown://foo", + ) From f283027e9215fc64e4293074dd97584aef3b8c0b Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 25 Nov 2025 11:12:34 -0800 Subject: [PATCH 5/5] feat: expose service URI flags Adds the shared adk_services_options decorator to adk run and other commands so developers can pass session/artifact URIs from the CLI Has new warning for the unsupported memory service on adk run, and removes the legacy --session_db_url/--artifact_storage_uri flags with tests Co-authored-by: George Weale PiperOrigin-RevId: 836743358 --- src/google/adk/cli/cli_tools_click.py | 128 ++++++++++-------- src/google/adk/cli/utils/__init__.py | 2 + .../cli/utils/test_cli_tools_click.py | 87 +++++++++--- 3 files changed, 143 insertions(+), 74 deletions(-) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 529ee7319c..c4a13dd15f 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -24,6 +24,7 @@ import os from pathlib import Path import tempfile +import textwrap from typing import Optional import click @@ -354,7 +355,62 @@ def validate_exclusive(ctx, param, value): return value +def adk_services_options(): + """Decorator to add ADK services options to click commands.""" + + def decorator(func): + @click.option( + "--session_service_uri", + help=textwrap.dedent( + """\ + Optional. The URI of the session service. + - Leave unset to use the in-memory session service (default). + - Use 'agentengine://' to connect to Agent Engine + sessions. can either be the full qualified resource + name 'projects/abc/locations/us-central1/reasoningEngines/123' or + the resource id '123'. + - Use 'memory://' to run with the in-memory session service. + - Use 'sqlite://' to connect to a SQLite DB. + - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported database URIs.""" + ), + ) + @click.option( + "--artifact_service_uri", + type=str, + help=textwrap.dedent( + """\ + Optional. The URI of the artifact service. + - Leave unset to store artifacts under '.adk/artifacts' locally. + - Use 'gs://' to connect to the GCS artifact service. + - Use 'memory://' to force the in-memory artifact service. + - Use 'file://' to store artifacts in a custom local directory.""" + ), + default=None, + ) + @click.option( + "--memory_service_uri", + type=str, + help=textwrap.dedent("""\ + Optional. The URI of the memory service. + - Use 'rag://' to connect to Vertex AI Rag Memory Service. + - Use 'agentengine://' to connect to Agent Engine + sessions. can either be the full qualified resource + name 'projects/abc/locations/us-central1/reasoningEngines/123' or + the resource id '123'. + - Use 'memory://' to force the in-memory memory service."""), + default=None, + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + return decorator + + @main.command("run", cls=HelpfulCommand) +@adk_services_options() @click.option( "--save_session", type=bool, @@ -409,6 +465,9 @@ def cli_run( session_id: Optional[str], replay: Optional[str], resume: Optional[str], + session_service_uri: Optional[str] = None, + artifact_service_uri: Optional[str] = None, + memory_service_uri: Optional[str] = None, ): """Runs an interactive CLI for a certain agent. @@ -420,6 +479,14 @@ def cli_run( """ logs.log_to_tmp_folder() + # Validation warning for memory_service_uri (not supported for adk run) + if memory_service_uri: + click.secho( + "WARNING: --memory_service_uri is not supported for adk run.", + fg="yellow", + err=True, + ) + agent_parent_folder = os.path.dirname(agent) agent_folder_name = os.path.basename(agent) @@ -431,6 +498,8 @@ def cli_run( saved_session_file=resume, save_session=save_session, session_id=session_id, + session_service_uri=session_service_uri, + artifact_service_uri=artifact_service_uri, ) ) @@ -865,55 +934,6 @@ def wrapper(*args, **kwargs): return decorator -def adk_services_options(): - """Decorator to add ADK services options to click commands.""" - - def decorator(func): - @click.option( - "--session_service_uri", - help=( - """Optional. The URI of the session service. - - Use 'agentengine://' to connect to Agent Engine - sessions. can either be the full qualified resource - name 'projects/abc/locations/us-central1/reasoningEngines/123' or - the resource id '123'. - - Use 'sqlite://' to connect to an aio-sqlite - based session service, which is good for local development. - - Use 'postgresql://:@:/' - to connect to a PostgreSQL DB. - - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls - for more details on other database URIs supported by SQLAlchemy.""" - ), - ) - @click.option( - "--artifact_service_uri", - type=str, - help=( - "Optional. The URI of the artifact service," - " supported URIs: gs:// for GCS artifact service." - ), - default=None, - ) - @click.option( - "--memory_service_uri", - type=str, - help=("""Optional. The URI of the memory service. - - Use 'rag://' to connect to Vertex AI Rag Memory Service. - - Use 'agentengine://' to connect to Agent Engine - sessions. can either be the full qualified resource - name 'projects/abc/locations/us-central1/reasoningEngines/123' or - the resource id '123'."""), - default=None, - ) - @functools.wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return wrapper - - return decorator - - def deprecated_adk_services_options(): """Deprecated ADK services options.""" @@ -921,7 +941,7 @@ def warn(alternative_param, ctx, param, value): if value: click.echo( click.style( - f"WARNING: Deprecated option {param.name} is used. Please use" + f"WARNING: Deprecated option --{param.name} is used. Please use" f" {alternative_param} instead.", fg="yellow", ), @@ -1116,6 +1136,8 @@ def cli_web( adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ + session_service_uri = session_service_uri or session_db_url + artifact_service_uri = artifact_service_uri or artifact_storage_uri logs.setup_adk_logger(getattr(logging, log_level.upper())) @asynccontextmanager @@ -1140,8 +1162,6 @@ async def _lifespan(app: FastAPI): fg="green", ) - session_service_uri = session_service_uri or session_db_url - artifact_service_uri = artifact_service_uri or artifact_storage_uri app = get_fast_api_app( agents_dir=agents_dir, session_service_uri=session_service_uri, @@ -1215,10 +1235,10 @@ def cli_api_server( adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir """ - logs.setup_adk_logger(getattr(logging, log_level.upper())) - session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri + logs.setup_adk_logger(getattr(logging, log_level.upper())) + config = uvicorn.Config( get_fast_api_app( agents_dir=agents_dir, diff --git a/src/google/adk/cli/utils/__init__.py b/src/google/adk/cli/utils/__init__.py index 8aa11b252b..1800f5d04c 100644 --- a/src/google/adk/cli/utils/__init__.py +++ b/src/google/adk/cli/utils/__init__.py @@ -18,8 +18,10 @@ from ...agents.base_agent import BaseAgent from ...agents.llm_agent import LlmAgent +from .dot_adk_folder import DotAdkFolder from .state import create_empty_state __all__ = [ 'create_empty_state', + 'DotAdkFolder', ] diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index be9015ca87..95b561e57b 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -76,8 +76,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401 # Fixtures @pytest.fixture(autouse=True) -def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None: +def _mute_click(request, monkeypatch: pytest.MonkeyPatch) -> None: """Suppress click output during tests.""" + # Allow tests to opt-out of muting by using the 'unmute_click' marker + if "unmute_click" in request.keywords: + return monkeypatch.setattr(click, "echo", lambda *a, **k: None) # Keep secho for error messages # monkeypatch.setattr(click, "secho", lambda *a, **k: None) @@ -121,32 +124,70 @@ def test_cli_create_cmd_invokes_run_cmd( cli_tools_click.main, ["create", "--model", "gemini", "--api_key", "key123", str(app_dir)], ) - assert result.exit_code == 0 + assert result.exit_code == 0, (result.output, repr(result.exception)) assert rec.calls, "cli_create.run_cmd must be called" # cli run -@pytest.mark.asyncio -async def test_cli_run_invokes_run_cli( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch +@pytest.mark.parametrize( + "cli_args,expected_session_uri,expected_artifact_uri", + [ + pytest.param( + [ + "--session_service_uri", + "memory://", + "--artifact_service_uri", + "memory://", + ], + "memory://", + "memory://", + id="memory_scheme_uris", + ), + pytest.param( + [], + None, + None, + id="default_uris_none", + ), + ], +) +def test_cli_run_service_uris( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + cli_args: list, + expected_session_uri: str, + expected_artifact_uri: str, ) -> None: - """`adk run` should call run_cli via asyncio.run with correct parameters.""" - rec = _Recorder() - monkeypatch.setattr(cli_tools_click, "run_cli", lambda **kwargs: rec(kwargs)) - monkeypatch.setattr( - cli_tools_click.asyncio, "run", lambda coro: coro - ) # pass-through - - # create dummy agent directory + """`adk run` should forward service URIs correctly to run_cli.""" agent_dir = tmp_path / "agent" agent_dir.mkdir() (agent_dir / "__init__.py").touch() (agent_dir / "agent.py").touch() + # Capture the coroutine's locals before closing it + captured_locals = [] + + def capture_asyncio_run(coro): + # Extract the locals before closing the coroutine + if coro.cr_frame is not None: + captured_locals.append(dict(coro.cr_frame.f_locals)) + coro.close() # Properly close the coroutine to avoid warnings + + monkeypatch.setattr(cli_tools_click.asyncio, "run", capture_asyncio_run) + runner = CliRunner() - result = runner.invoke(cli_tools_click.main, ["run", str(agent_dir)]) - assert result.exit_code == 0 - assert rec.calls and rec.calls[0][0][0]["agent_folder_name"] == "agent" + result = runner.invoke( + cli_tools_click.main, + ["run", *cli_args, str(agent_dir)], + ) + assert result.exit_code == 0, (result.output, repr(result.exception)) + assert len(captured_locals) == 1, "Expected asyncio.run to be called once" + + # Verify the kwargs passed to run_cli + coro_locals = captured_locals[0] + assert coro_locals.get("session_service_uri") == expected_session_uri + assert coro_locals.get("artifact_service_uri") == expected_artifact_uri + assert coro_locals["agent_folder_name"] == "agent" # cli deploy cloud_run @@ -520,10 +561,13 @@ def test_cli_web_passes_service_uris( assert called_kwargs.get("memory_service_uri") == "rag://mycorpus" -def test_cli_web_passes_deprecated_uris( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch, _patch_uvicorn: _Recorder +@pytest.mark.unmute_click +def test_cli_web_warns_and_maps_deprecated_uris( + tmp_path: Path, + _patch_uvicorn: _Recorder, + monkeypatch: pytest.MonkeyPatch, ) -> None: - """`adk web` should use deprecated URIs if new ones are not provided.""" + """`adk web` should accept deprecated URI flags with warnings.""" agents_dir = tmp_path / "agents" agents_dir.mkdir() @@ -542,11 +586,14 @@ def test_cli_web_passes_deprecated_uris( "gs://deprecated", ], ) + assert result.exit_code == 0 - assert mock_get_app.calls called_kwargs = mock_get_app.calls[0][1] assert called_kwargs.get("session_service_uri") == "sqlite:///deprecated.db" assert called_kwargs.get("artifact_service_uri") == "gs://deprecated" + # Check output for deprecation warnings (CliRunner captures both stdout and stderr) + assert "--session_db_url" in result.output + assert "--artifact_storage_uri" in result.output def test_cli_eval_with_eval_set_file_path(