diff --git a/pyproject.toml b/pyproject.toml index 1af967046f..ad1da6fff1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "PyYAML>=6.0.2, <7.0.0", # For APIHubToolset. "aiosqlite>=0.21.0", # For SQLite database "anyio>=4.9.0, <5.0.0", # For MCP Session Manager - "authlib>=1.5.1, <2.0.0", # For RestAPI Tool + "authlib>=1.6.6, <2.0.0", # For RestAPI Tool "click>=8.1.8, <9.0.0", # For CLI tools "fastapi>=0.115.0, <0.124.0", # FastAPI framework "google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery @@ -46,7 +46,7 @@ dependencies = [ "google-genai>=1.56.0, <2.0.0", # Google GenAI SDK "graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering "jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation - "mcp>=1.10.0, <2.0.0", # For MCP Toolset + "mcp>=1.23.0, <2.0.0", # For MCP Toolset "opentelemetry-api>=1.37.0, <=1.37.0", # OpenTelemetry - limit upper version for sdk and api to not risk breaking changes from unstable _logs package. "opentelemetry-exporter-gcp-logging>=1.9.0a0, <2.0.0", "opentelemetry-exporter-gcp-monitoring>=1.9.0a0, <2.0.0", diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index c80115be8f..a4b4418f55 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -50,28 +50,44 @@ ) -def _apply_feature_overrides(enable_features: tuple[str, ...]) -> None: +def _apply_feature_overrides( + *, + enable_features: tuple[str, ...] = (), + disable_features: tuple[str, ...] = (), +) -> None: """Apply feature overrides from CLI flags. Args: enable_features: Tuple of feature names to enable. + disable_features: Tuple of feature names to disable. """ + feature_overrides: dict[str, bool] = {} + for features_str in enable_features: for feature_name_str in features_str.split(","): feature_name_str = feature_name_str.strip() - if not feature_name_str: - continue - try: - feature_name = FeatureName(feature_name_str) - override_feature_enabled(feature_name, True) - except ValueError: - valid_names = ", ".join(f.value for f in FeatureName) - click.secho( - f"WARNING: Unknown feature name '{feature_name_str}'. " - f"Valid names are: {valid_names}", - fg="yellow", - err=True, - ) + if feature_name_str: + feature_overrides[feature_name_str] = True + + for features_str in disable_features: + for feature_name_str in features_str.split(","): + feature_name_str = feature_name_str.strip() + if feature_name_str: + feature_overrides[feature_name_str] = False + + # Apply all overrides + for feature_name_str, enabled in feature_overrides.items(): + try: + feature_name = FeatureName(feature_name_str) + override_feature_enabled(feature_name, enabled) + except ValueError: + valid_names = ", ".join(f.value for f in FeatureName) + click.secho( + f"WARNING: Unknown feature name '{feature_name_str}'. " + f"Valid names are: {valid_names}", + fg="yellow", + err=True, + ) def feature_options(): @@ -88,11 +104,25 @@ def decorator(func): ), multiple=True, ) + @click.option( + "--disable_features", + help=( + "Optional. Comma-separated list of feature names to disable. " + "This provides an alternative to environment variables for " + "disabling features. Example: " + "--disable_features=JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING" + ), + multiple=True, + ) @functools.wraps(func) def wrapper(*args, **kwargs): enable_features = kwargs.pop("enable_features", ()) - if enable_features: - _apply_feature_overrides(enable_features) + disable_features = kwargs.pop("disable_features", ()) + if enable_features or disable_features: + _apply_feature_overrides( + enable_features=enable_features, + disable_features=disable_features, + ) return func(*args, **kwargs) return wrapper diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py index c03ac10b85..d8903ece14 100644 --- a/src/google/adk/cli/utils/service_factory.py +++ b/src/google/adk/cli/utils/service_factory.py @@ -19,6 +19,9 @@ from pathlib import Path from typing import Any from typing import Optional +from urllib.parse import parse_qsl +from urllib.parse import urlsplit +from urllib.parse import urlunsplit from ...artifacts.base_artifact_service import BaseArtifactService from ...memory.base_memory_service import BaseMemoryService @@ -42,6 +45,41 @@ _KUBERNETES_HOST_ENV = "KUBERNETES_SERVICE_HOST" +def _redact_uri_for_log(uri: str) -> str: + """Returns a safe-to-log representation of a URI. + + Redacts user info (username/password) and query parameter values. + """ + if not uri or not uri.strip(): + return "" + sanitized = uri.replace("\r", "\\r").replace("\n", "\\n") + if "://" not in sanitized: + return "" + try: + parsed = urlsplit(sanitized) + except ValueError: + return "" + + if not parsed.scheme: + return "" + + netloc = parsed.netloc + if "@" in netloc: + _, netloc = netloc.rsplit("@", 1) + + if parsed.query: + try: + redacted_pairs = parse_qsl(parsed.query, keep_blank_values=True) + except ValueError: + query = "" + else: + query = "&".join(f"{key}=" for key, _ in redacted_pairs) + else: + query = "" + + return urlunsplit((parsed.scheme, netloc, parsed.path, query, "")) + + def _is_cloud_run() -> bool: """Returns True when running in Cloud Run.""" return bool(os.environ.get(_CLOUD_RUN_SERVICE_ENV)) @@ -148,7 +186,10 @@ def create_session_service_from_options( kwargs.update(session_db_kwargs) if session_service_uri: - logger.info("Using session service URI: %s", session_service_uri) + logger.info( + "Using session service URI: %s", + _redact_uri_for_log(session_service_uri), + ) service = registry.create_session_service(session_service_uri, **kwargs) if service is not None: return service @@ -162,7 +203,7 @@ def create_session_service_from_options( fallback_kwargs.pop("agents_dir", None) logger.info( "Falling back to DatabaseSessionService for URI: %s", - session_service_uri, + _redact_uri_for_log(session_service_uri), ) return DatabaseSessionService(db_url=session_service_uri, **fallback_kwargs) @@ -208,13 +249,18 @@ def create_memory_service_from_options( registry = get_service_registry() if memory_service_uri: - logger.info("Using memory service URI: %s", memory_service_uri) + logger.info( + "Using memory service URI: %s", _redact_uri_for_log(memory_service_uri) + ) service = registry.create_memory_service( memory_service_uri, agents_dir=str(base_path), ) if service is None: - raise ValueError(f"Unsupported memory service URI: {memory_service_uri}") + raise ValueError( + "Unsupported memory service URI: %s" + % _redact_uri_for_log(memory_service_uri) + ) return service logger.info("Using in-memory memory service") @@ -235,7 +281,10 @@ def create_artifact_service_from_options( registry = get_service_registry() if artifact_service_uri: - logger.info("Using artifact service URI: %s", artifact_service_uri) + logger.info( + "Using artifact service URI: %s", + _redact_uri_for_log(artifact_service_uri), + ) service = registry.create_artifact_service( artifact_service_uri, agents_dir=str(base_path), @@ -243,11 +292,12 @@ def create_artifact_service_from_options( if service is None: if strict_uri: raise ValueError( - f"Unsupported artifact service URI: {artifact_service_uri}" + "Unsupported artifact service URI: %s" + % _redact_uri_for_log(artifact_service_uri) ) return _create_in_memory_artifact_service( "Unsupported artifact service URI: %s, falling back to in-memory", - artifact_service_uri, + _redact_uri_for_log(artifact_service_uri), ) return service diff --git a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py index 2e3e978bc7..bfe3497c86 100644 --- a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py +++ b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py @@ -23,7 +23,6 @@ from typing_extensions import override from ..agents.invocation_context import InvocationContext -from ..utils.feature_decorator import experimental from .base_code_executor import BaseCodeExecutor from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult @@ -32,7 +31,6 @@ logger = logging.getLogger('google_adk.' + __name__) -@experimental class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): """A code executor that uses Agent Engine Code Execution Sandbox to execute code. diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index f6705c1de9..a7899e788d 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -16,6 +16,7 @@ import base64 import copy +import importlib.util import json import logging import mimetypes @@ -32,6 +33,7 @@ from typing import Literal from typing import Optional from typing import Tuple +from typing import TYPE_CHECKING from typing import TypedDict from typing import Union from urllib.parse import urlparse @@ -39,20 +41,12 @@ import warnings from google.genai import types -import litellm -from litellm import acompletion -from litellm import ChatCompletionAssistantMessage -from litellm import ChatCompletionAssistantToolCall -from litellm import ChatCompletionMessageToolCall -from litellm import ChatCompletionSystemMessage -from litellm import ChatCompletionToolMessage -from litellm import ChatCompletionUserMessage -from litellm import completion -from litellm import CustomStreamWrapper -from litellm import Function -from litellm import Message -from litellm import ModelResponse -from litellm import OpenAIMessageContent + +if not TYPE_CHECKING and importlib.util.find_spec("litellm") is None: + raise ImportError( + "LiteLLM support requires: pip install google-adk[extensions]" + ) + from pydantic import BaseModel from pydantic import Field from typing_extensions import override @@ -61,8 +55,36 @@ from .llm_request import LlmRequest from .llm_response import LlmResponse -# This will add functions to prompts if functions are provided. -litellm.add_function_to_prompt = True +if TYPE_CHECKING: + import litellm + from litellm import acompletion + from litellm import ChatCompletionAssistantMessage + from litellm import ChatCompletionAssistantToolCall + from litellm import ChatCompletionMessageToolCall + from litellm import ChatCompletionSystemMessage + from litellm import ChatCompletionToolMessage + from litellm import ChatCompletionUserMessage + from litellm import completion + from litellm import CustomStreamWrapper + from litellm import Function + from litellm import Message + from litellm import ModelResponse + from litellm import OpenAIMessageContent +else: + litellm = None + acompletion = None + ChatCompletionAssistantMessage = None + ChatCompletionAssistantToolCall = None + ChatCompletionMessageToolCall = None + ChatCompletionSystemMessage = None + ChatCompletionToolMessage = None + ChatCompletionUserMessage = None + completion = None + CustomStreamWrapper = None + Function = None + Message = None + ModelResponse = None + OpenAIMessageContent = None logger = logging.getLogger("google_adk." + __name__) @@ -109,6 +131,50 @@ "before a response was recorded)." ) +_LITELLM_IMPORTED = False +_LITELLM_GLOBAL_SYMBOLS = ( + "ChatCompletionAssistantMessage", + "ChatCompletionAssistantToolCall", + "ChatCompletionMessageToolCall", + "ChatCompletionSystemMessage", + "ChatCompletionToolMessage", + "ChatCompletionUserMessage", + "CustomStreamWrapper", + "Function", + "Message", + "ModelResponse", + "OpenAIMessageContent", + "acompletion", + "completion", +) + + +def _ensure_litellm_imported() -> None: + """Imports LiteLLM with safe defaults. + + LiteLLM defaults to DEV mode, which auto-loads a local `.env` at import time. + ADK should not implicitly load `.env` just because LiteLLM is installed. + + Users can opt into LiteLLM's default behavior by setting LITELLM_MODE=DEV. + """ + global _LITELLM_IMPORTED + if _LITELLM_IMPORTED: + return + + # https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py#L80-L82 + os.environ.setdefault("LITELLM_MODE", "PRODUCTION") + + import litellm as litellm_module + + litellm_module.add_function_to_prompt = True + + globals()["litellm"] = litellm_module + for symbol in _LITELLM_GLOBAL_SYMBOLS: + globals()[symbol] = getattr(litellm_module, symbol) + + _redirect_litellm_loggers_to_stdout() + _LITELLM_IMPORTED = True + def _map_finish_reason( finish_reason: Any, @@ -344,6 +410,7 @@ async def acompletion( Returns: The model response as a message. """ + _ensure_litellm_imported() return await acompletion( model=model, @@ -367,6 +434,7 @@ def completion( Returns: The response from the model. """ + _ensure_litellm_imported() return completion( model=model, @@ -513,6 +581,7 @@ async def _content_to_message_param( Returns: A litellm Message, a list of litellm Messages. """ + _ensure_litellm_imported() tool_messages: list[Message] = [] non_tool_parts: list[types.Part] = [] @@ -622,6 +691,8 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]: if not messages: return messages + _ensure_litellm_imported() + healed_messages: List[Message] = [] pending_tool_call_ids: List[str] = [] @@ -691,6 +762,7 @@ async def _get_content( Returns: The litellm content. """ + _ensure_litellm_imported() parts_list = list(parts) if len(parts_list) == 1: @@ -925,6 +997,7 @@ def _build_tool_call_from_json_dict( candidate: Any, *, index: int ) -> Optional[ChatCompletionMessageToolCall]: """Creates a tool call object from JSON content embedded in text.""" + _ensure_litellm_imported() if not isinstance(candidate, dict): return None @@ -972,11 +1045,12 @@ def _parse_tool_calls_from_text( text_block: str, ) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]: """Extracts inline JSON tool calls from LiteLLM text responses.""" - tool_calls = [] if not text_block: return tool_calls, None + _ensure_litellm_imported() + remainder_segments = [] cursor = 0 text_length = len(text_block) @@ -1014,7 +1088,6 @@ def _split_message_content_and_tool_calls( message: Message, ) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]: """Returns message content and tool calls, parsing inline JSON when needed.""" - existing_tool_calls = message.get("tool_calls") or [] normalized_tool_calls = ( list(existing_tool_calls) if existing_tool_calls else [] @@ -1180,6 +1253,7 @@ def _model_response_to_chunk( Yields: A tuple of text or function or usage metadata chunk and finish reason. """ + _ensure_litellm_imported() message = None if response.get("choices", None): @@ -1255,6 +1329,7 @@ def _model_response_to_generate_content_response( Returns: The LlmResponse. """ + _ensure_litellm_imported() message = None finish_reason = None @@ -1313,6 +1388,7 @@ def _message_to_generate_content_response( Returns: The LlmResponse. """ + _ensure_litellm_imported() parts: List[types.Part] = [] if not thought_parts: @@ -1440,6 +1516,8 @@ async def _get_completion_inputs( The litellm inputs (message list, tool dictionary, response format and generation params). """ + _ensure_litellm_imported() + # Determine provider for file handling provider = _get_provider_from_model(model) @@ -1665,11 +1743,6 @@ def _redirect_litellm_loggers_to_stdout() -> None: handler.stream = sys.stdout -# Redirect LiteLLM loggers to stdout immediately after import to ensure -# INFO-level logs are not incorrectly treated as errors in cloud environments. -_redirect_litellm_loggers_to_stdout() - - class LiteLlm(BaseLlm): """Wrapper around litellm. @@ -1732,6 +1805,7 @@ async def generate_content_async( Yields: LlmResponse: The model response. """ + _ensure_litellm_imported() self._maybe_append_user_content(llm_request) _append_fallback_user_content_if_missing(llm_request) diff --git a/tests/unittests/cli/test_cli_feature_options.py b/tests/unittests/cli/test_cli_feature_options.py index 70bfec2dda..8a507f5c75 100644 --- a/tests/unittests/cli/test_cli_feature_options.py +++ b/tests/unittests/cli/test_cli_feature_options.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for --enable_features CLI option.""" - from __future__ import annotations import click @@ -42,45 +40,96 @@ class TestApplyFeatureOverrides: def test_single_feature(self): """Single feature name is applied correctly.""" - _apply_feature_overrides(("JSON_SCHEMA_FOR_FUNC_DECL",)) + _apply_feature_overrides(enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",)) assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) def test_comma_separated_features(self): """Comma-separated feature names are applied correctly.""" - _apply_feature_overrides(( - "JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING", - )) + _apply_feature_overrides( + enable_features=("JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING",) + ) assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) assert is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) def test_multiple_flag_values(self): """Multiple --enable_features flags are applied correctly.""" - _apply_feature_overrides(( - "JSON_SCHEMA_FOR_FUNC_DECL", - "PROGRESSIVE_SSE_STREAMING", - )) + _apply_feature_overrides( + enable_features=( + "JSON_SCHEMA_FOR_FUNC_DECL", + "PROGRESSIVE_SSE_STREAMING", + ) + ) assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) assert is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) def test_whitespace_handling(self): """Whitespace around feature names is stripped.""" - _apply_feature_overrides((" JSON_SCHEMA_FOR_FUNC_DECL , COMPUTER_USE ",)) + _apply_feature_overrides( + enable_features=(" JSON_SCHEMA_FOR_FUNC_DECL , COMPUTER_USE ",) + ) assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) assert is_feature_enabled(FeatureName.COMPUTER_USE) def test_empty_string_ignored(self): """Empty strings in the list are ignored.""" - _apply_feature_overrides(("",)) + _apply_feature_overrides(enable_features=("",)) # No error should be raised def test_unknown_feature_warns(self, capsys): """Unknown feature names emit a warning.""" - _apply_feature_overrides(("UNKNOWN_FEATURE_XYZ",)) + _apply_feature_overrides(enable_features=("UNKNOWN_FEATURE_XYZ",)) captured = capsys.readouterr() assert "WARNING" in captured.err assert "UNKNOWN_FEATURE_XYZ" in captured.err assert "Valid names are:" in captured.err + def test_single_disable_feature(self): + """Single feature name is disabled correctly.""" + # First enable a feature + _apply_feature_overrides(enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",)) + assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + + # Then disable it + _apply_feature_overrides(disable_features=("JSON_SCHEMA_FOR_FUNC_DECL",)) + assert not is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + + def test_comma_separated_disable_features(self): + """Comma-separated feature names are disabled correctly.""" + # First enable features + _apply_feature_overrides( + enable_features=("JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING",) + ) + + # Then disable them + _apply_feature_overrides( + disable_features=( + "JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING", + ) + ) + assert not is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + assert not is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) + + def test_disable_overrides_enable(self): + """Disable is applied after enable, so disable wins for same feature.""" + _apply_feature_overrides( + enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",), + disable_features=("JSON_SCHEMA_FOR_FUNC_DECL",), + ) + # disable_features is processed after enable_features + assert not is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + + def test_enable_and_disable_different_features(self): + """Enable and disable can be used together for different features.""" + # First enable a feature that we'll disable + _apply_feature_overrides(enable_features=("PROGRESSIVE_SSE_STREAMING",)) + + _apply_feature_overrides( + enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",), + disable_features=("PROGRESSIVE_SSE_STREAMING",), + ) + assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + assert not is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) + class TestFeatureOptionsDecorator: """Tests for feature_options decorator.""" @@ -195,3 +244,64 @@ def my_test_command(): "my_test_command" in my_test_command.name or my_test_command.callback.__name__ == "my_test_command" ) + + def test_decorator_adds_disable_features_option(self): + """Decorator adds --disable_features option to command.""" + + @click.command() + @feature_options() + def test_cmd(): + pass + + runner = CliRunner() + result = runner.invoke(test_cmd, ["--help"]) + assert "--disable_features" in result.output + + def test_disable_features_applied_before_command(self): + """Features are disabled before the command function runs.""" + # First enable the feature via override + _apply_feature_overrides(enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",)) + + feature_was_disabled = [] + + @click.command() + @feature_options() + def test_cmd(): + feature_was_disabled.append( + not is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + ) + + runner = CliRunner() + runner.invoke( + test_cmd, + ["--disable_features=JSON_SCHEMA_FOR_FUNC_DECL"], + catch_exceptions=False, + ) + assert feature_was_disabled == [True] + + def test_enable_and_disable_together(self): + """Both --enable_features and --disable_features work together.""" + feature_states = [] + + @click.command() + @feature_options() + def test_cmd(): + feature_states.append( + is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + ) + feature_states.append( + is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) + ) + + runner = CliRunner() + runner.invoke( + test_cmd, + [ + "--enable_features=JSON_SCHEMA_FOR_FUNC_DECL", + "--disable_features=PROGRESSIVE_SSE_STREAMING", + ], + catch_exceptions=False, + ) + # JSON_SCHEMA_FOR_FUNC_DECL should be enabled + # PROGRESSIVE_SSE_STREAMING should be disabled + assert feature_states == [True, False] diff --git a/tests/unittests/cli/test_cli_tools_click_option_mismatch.py b/tests/unittests/cli/test_cli_tools_click_option_mismatch.py index 3c67e9ae39..9b81345312 100644 --- a/tests/unittests/cli/test_cli_tools_click_option_mismatch.py +++ b/tests/unittests/cli/test_cli_tools_click_option_mismatch.py @@ -95,7 +95,10 @@ def test_adk_run(): assert run_command is not None, "Run command not found" _check_options_in_parameters( - run_command, cli_run.callback, "run", ignore_params={"enable_features"} + run_command, + cli_run.callback, + "run", + ignore_params={"enable_features", "disable_features"}, ) @@ -105,7 +108,10 @@ def test_adk_eval(): assert eval_command is not None, "Eval command not found" _check_options_in_parameters( - eval_command, cli_eval.callback, "eval", ignore_params={"enable_features"} + eval_command, + cli_eval.callback, + "eval", + ignore_params={"enable_features", "disable_features"}, ) @@ -118,7 +124,7 @@ def test_adk_web(): web_command, cli_web.callback, "web", - ignore_params={"verbose", "enable_features"}, + ignore_params={"verbose", "enable_features", "disable_features"}, ) @@ -131,7 +137,7 @@ def test_adk_api_server(): api_server_command, cli_api_server.callback, "api_server", - ignore_params={"verbose", "enable_features"}, + ignore_params={"verbose", "enable_features", "disable_features"}, ) diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py index 87b567be73..ad9a238985 100644 --- a/tests/unittests/cli/utils/test_service_factory.py +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -16,12 +16,14 @@ from __future__ import annotations +import logging import os from pathlib import Path -from unittest.mock import Mock +from unittest import mock from google.adk.artifacts.file_artifact_service import FileArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.cli.service_registry import ServiceRegistry from google.adk.cli.utils.local_storage import PerAgentDatabaseSessionService import google.adk.cli.utils.service_factory as service_factory from google.adk.memory.in_memory_memory_service import InMemoryMemoryService @@ -31,7 +33,7 @@ def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) expected = object() registry.create_session_service.return_value = expected monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -48,6 +50,87 @@ def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): ) +def test_create_session_service_logs_redacted_uri( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) + registry.create_session_service.return_value = object() + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + session_service_uri = ( + "postgresql://user:supersecret@localhost:5432/dbname?sslmode=require" + ) + caplog.set_level(logging.INFO, logger=service_factory.logger.name) + + service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri=session_service_uri, + ) + + assert "supersecret" not in caplog.text + assert "sslmode=require" not in caplog.text + assert "localhost:5432" in caplog.text + + +def test_redact_uri_for_log_removes_credentials_with_at_in_password() -> None: + uri = "postgresql://user:super@secret@localhost:5432/dbname" + + assert ( + service_factory._redact_uri_for_log(uri) + == "postgresql://localhost:5432/dbname" + ) + + +def test_redact_uri_for_log_preserves_host_when_no_credentials() -> None: + uri = "postgresql://localhost:5432/dbname?sslmode=require&password=secret" + + redacted = service_factory._redact_uri_for_log(uri) + + assert redacted.startswith("postgresql://localhost:5432/dbname?") + assert "require" not in redacted + assert "secret" not in redacted + assert "sslmode=" in redacted + assert "password=" in redacted + + +def test_redact_uri_for_log_redacts_when_parse_qsl_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def _raise_value_error(*_args, **_kwargs): + raise ValueError("bad query") + + monkeypatch.setattr(service_factory, "parse_qsl", _raise_value_error) + + uri = "postgresql://user:pass@localhost:5432/dbname?sslmode=require" + redacted = service_factory._redact_uri_for_log(uri) + + assert "pass" not in redacted + assert "require" not in redacted + assert redacted.endswith("?") + + +def test_redact_uri_for_log_escapes_crlf() -> None: + uri = ( + "postgresql://user:pass@localhost:5432/dbname\rINJECT\nINJECT" + "?sslmode=require" + ) + + redacted = service_factory._redact_uri_for_log(uri) + + assert "\r" not in redacted + assert "\n" not in redacted + assert "\\rINJECT\\nINJECT" in redacted + + +def test_redact_uri_for_log_returns_scheme_missing_without_separator() -> None: + assert ( + service_factory._redact_uri_for_log("user:pass@localhost:5432/dbname") + == "" + ) + + @pytest.mark.asyncio async def test_create_session_service_defaults_to_per_agent_sqlite( tmp_path: Path, @@ -88,7 +171,7 @@ async def test_create_session_service_respects_app_name_mapping( def test_create_session_service_fallbacks_to_database( tmp_path: Path, monkeypatch ): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) registry.create_session_service.return_value = None monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -109,7 +192,7 @@ def test_create_session_service_fallbacks_to_database( def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) expected = object() registry.create_artifact_service.return_value = expected monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -129,7 +212,7 @@ def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch): def test_create_artifact_service_raises_on_unknown_scheme_when_strict( tmp_path: Path, monkeypatch ): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) registry.create_artifact_service.return_value = None monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -142,7 +225,7 @@ def test_create_artifact_service_raises_on_unknown_scheme_when_strict( def test_create_memory_service_uses_registry(tmp_path: Path, monkeypatch): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) expected = object() registry.create_memory_service.return_value = expected monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -170,7 +253,7 @@ def test_create_memory_service_defaults_to_in_memory(tmp_path: Path): def test_create_memory_service_raises_on_unknown_scheme( tmp_path: Path, monkeypatch ): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) registry.create_memory_service.return_value = None monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) diff --git a/tests/unittests/models/test_litellm_import.py b/tests/unittests/models/test_litellm_import.py new file mode 100644 index 0000000000..179dd4703e --- /dev/null +++ b/tests/unittests/models/test_litellm_import.py @@ -0,0 +1,108 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import os +import subprocess +import sys + +import pytest + + +def _subprocess_env() -> dict[str, str]: + env = dict(os.environ) + src_path = os.path.join(os.getcwd(), "src") + pythonpath = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = ( + f"{src_path}{os.pathsep}{pythonpath}" if pythonpath else src_path + ) + return env + + +def test_importing_models_does_not_import_litellm_or_set_mode(): + env = _subprocess_env() + env.pop("LITELLM_MODE", None) + + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import os, sys\n" + "import google.adk.models\n" + "print('litellm' in sys.modules)\n" + "print(os.environ.get('LITELLM_MODE'))\n" + ), + ], + check=True, + capture_output=True, + text=True, + env=env, + ) + stdout_lines = result.stdout.strip().splitlines() + assert stdout_lines == ["False", "None"] + + +def test_ensure_litellm_imported_defaults_to_production(): + if importlib.util.find_spec("litellm") is None: + pytest.skip("litellm is not installed") + + env = _subprocess_env() + env.pop("LITELLM_MODE", None) + + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import os\n" + "from google.adk.models.lite_llm import" + " _ensure_litellm_imported\n" + "_ensure_litellm_imported()\n" + "print(os.environ.get('LITELLM_MODE'))\n" + ), + ], + check=True, + capture_output=True, + text=True, + env=env, + ) + assert result.stdout.strip() == "PRODUCTION" + + +def test_ensure_litellm_imported_does_not_override(): + if importlib.util.find_spec("litellm") is None: + pytest.skip("litellm is not installed") + + env = _subprocess_env() + env["LITELLM_MODE"] = "DEV" + + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import os\n" + "from google.adk.models.lite_llm import" + " _ensure_litellm_imported\n" + "_ensure_litellm_imported()\n" + "print(os.environ.get('LITELLM_MODE'))\n" + ), + ], + check=True, + capture_output=True, + text=True, + env=env, + ) + assert result.stdout.strip() == "DEV"