From 53b67ce6340f3f3f8c3d732f9f7811e445c76359 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Tue, 20 Jan 2026 10:45:08 -0800 Subject: [PATCH 1/5] feat: Add `--disable_features` CLI option to ADK CLI This flag can be used to override default feature enable state. Co-authored-by: Xuan Yang PiperOrigin-RevId: 858659818 --- src/google/adk/cli/cli_tools_click.py | 62 +++++--- .../unittests/cli/test_cli_feature_options.py | 136 ++++++++++++++++-- .../test_cli_tools_click_option_mismatch.py | 14 +- 3 files changed, 179 insertions(+), 33 deletions(-) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index c80115be8f..a4b4418f55 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -50,28 +50,44 @@ ) -def _apply_feature_overrides(enable_features: tuple[str, ...]) -> None: +def _apply_feature_overrides( + *, + enable_features: tuple[str, ...] = (), + disable_features: tuple[str, ...] = (), +) -> None: """Apply feature overrides from CLI flags. Args: enable_features: Tuple of feature names to enable. + disable_features: Tuple of feature names to disable. """ + feature_overrides: dict[str, bool] = {} + for features_str in enable_features: for feature_name_str in features_str.split(","): feature_name_str = feature_name_str.strip() - if not feature_name_str: - continue - try: - feature_name = FeatureName(feature_name_str) - override_feature_enabled(feature_name, True) - except ValueError: - valid_names = ", ".join(f.value for f in FeatureName) - click.secho( - f"WARNING: Unknown feature name '{feature_name_str}'. " - f"Valid names are: {valid_names}", - fg="yellow", - err=True, - ) + if feature_name_str: + feature_overrides[feature_name_str] = True + + for features_str in disable_features: + for feature_name_str in features_str.split(","): + feature_name_str = feature_name_str.strip() + if feature_name_str: + feature_overrides[feature_name_str] = False + + # Apply all overrides + for feature_name_str, enabled in feature_overrides.items(): + try: + feature_name = FeatureName(feature_name_str) + override_feature_enabled(feature_name, enabled) + except ValueError: + valid_names = ", ".join(f.value for f in FeatureName) + click.secho( + f"WARNING: Unknown feature name '{feature_name_str}'. " + f"Valid names are: {valid_names}", + fg="yellow", + err=True, + ) def feature_options(): @@ -88,11 +104,25 @@ def decorator(func): ), multiple=True, ) + @click.option( + "--disable_features", + help=( + "Optional. Comma-separated list of feature names to disable. " + "This provides an alternative to environment variables for " + "disabling features. Example: " + "--disable_features=JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING" + ), + multiple=True, + ) @functools.wraps(func) def wrapper(*args, **kwargs): enable_features = kwargs.pop("enable_features", ()) - if enable_features: - _apply_feature_overrides(enable_features) + disable_features = kwargs.pop("disable_features", ()) + if enable_features or disable_features: + _apply_feature_overrides( + enable_features=enable_features, + disable_features=disable_features, + ) return func(*args, **kwargs) return wrapper diff --git a/tests/unittests/cli/test_cli_feature_options.py b/tests/unittests/cli/test_cli_feature_options.py index 70bfec2dda..8a507f5c75 100644 --- a/tests/unittests/cli/test_cli_feature_options.py +++ b/tests/unittests/cli/test_cli_feature_options.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for --enable_features CLI option.""" - from __future__ import annotations import click @@ -42,45 +40,96 @@ class TestApplyFeatureOverrides: def test_single_feature(self): """Single feature name is applied correctly.""" - _apply_feature_overrides(("JSON_SCHEMA_FOR_FUNC_DECL",)) + _apply_feature_overrides(enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",)) assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) def test_comma_separated_features(self): """Comma-separated feature names are applied correctly.""" - _apply_feature_overrides(( - "JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING", - )) + _apply_feature_overrides( + enable_features=("JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING",) + ) assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) assert is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) def test_multiple_flag_values(self): """Multiple --enable_features flags are applied correctly.""" - _apply_feature_overrides(( - "JSON_SCHEMA_FOR_FUNC_DECL", - "PROGRESSIVE_SSE_STREAMING", - )) + _apply_feature_overrides( + enable_features=( + "JSON_SCHEMA_FOR_FUNC_DECL", + "PROGRESSIVE_SSE_STREAMING", + ) + ) assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) assert is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) def test_whitespace_handling(self): """Whitespace around feature names is stripped.""" - _apply_feature_overrides((" JSON_SCHEMA_FOR_FUNC_DECL , COMPUTER_USE ",)) + _apply_feature_overrides( + enable_features=(" JSON_SCHEMA_FOR_FUNC_DECL , COMPUTER_USE ",) + ) assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) assert is_feature_enabled(FeatureName.COMPUTER_USE) def test_empty_string_ignored(self): """Empty strings in the list are ignored.""" - _apply_feature_overrides(("",)) + _apply_feature_overrides(enable_features=("",)) # No error should be raised def test_unknown_feature_warns(self, capsys): """Unknown feature names emit a warning.""" - _apply_feature_overrides(("UNKNOWN_FEATURE_XYZ",)) + _apply_feature_overrides(enable_features=("UNKNOWN_FEATURE_XYZ",)) captured = capsys.readouterr() assert "WARNING" in captured.err assert "UNKNOWN_FEATURE_XYZ" in captured.err assert "Valid names are:" in captured.err + def test_single_disable_feature(self): + """Single feature name is disabled correctly.""" + # First enable a feature + _apply_feature_overrides(enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",)) + assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + + # Then disable it + _apply_feature_overrides(disable_features=("JSON_SCHEMA_FOR_FUNC_DECL",)) + assert not is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + + def test_comma_separated_disable_features(self): + """Comma-separated feature names are disabled correctly.""" + # First enable features + _apply_feature_overrides( + enable_features=("JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING",) + ) + + # Then disable them + _apply_feature_overrides( + disable_features=( + "JSON_SCHEMA_FOR_FUNC_DECL,PROGRESSIVE_SSE_STREAMING", + ) + ) + assert not is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + assert not is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) + + def test_disable_overrides_enable(self): + """Disable is applied after enable, so disable wins for same feature.""" + _apply_feature_overrides( + enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",), + disable_features=("JSON_SCHEMA_FOR_FUNC_DECL",), + ) + # disable_features is processed after enable_features + assert not is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + + def test_enable_and_disable_different_features(self): + """Enable and disable can be used together for different features.""" + # First enable a feature that we'll disable + _apply_feature_overrides(enable_features=("PROGRESSIVE_SSE_STREAMING",)) + + _apply_feature_overrides( + enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",), + disable_features=("PROGRESSIVE_SSE_STREAMING",), + ) + assert is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + assert not is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) + class TestFeatureOptionsDecorator: """Tests for feature_options decorator.""" @@ -195,3 +244,64 @@ def my_test_command(): "my_test_command" in my_test_command.name or my_test_command.callback.__name__ == "my_test_command" ) + + def test_decorator_adds_disable_features_option(self): + """Decorator adds --disable_features option to command.""" + + @click.command() + @feature_options() + def test_cmd(): + pass + + runner = CliRunner() + result = runner.invoke(test_cmd, ["--help"]) + assert "--disable_features" in result.output + + def test_disable_features_applied_before_command(self): + """Features are disabled before the command function runs.""" + # First enable the feature via override + _apply_feature_overrides(enable_features=("JSON_SCHEMA_FOR_FUNC_DECL",)) + + feature_was_disabled = [] + + @click.command() + @feature_options() + def test_cmd(): + feature_was_disabled.append( + not is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + ) + + runner = CliRunner() + runner.invoke( + test_cmd, + ["--disable_features=JSON_SCHEMA_FOR_FUNC_DECL"], + catch_exceptions=False, + ) + assert feature_was_disabled == [True] + + def test_enable_and_disable_together(self): + """Both --enable_features and --disable_features work together.""" + feature_states = [] + + @click.command() + @feature_options() + def test_cmd(): + feature_states.append( + is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL) + ) + feature_states.append( + is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING) + ) + + runner = CliRunner() + runner.invoke( + test_cmd, + [ + "--enable_features=JSON_SCHEMA_FOR_FUNC_DECL", + "--disable_features=PROGRESSIVE_SSE_STREAMING", + ], + catch_exceptions=False, + ) + # JSON_SCHEMA_FOR_FUNC_DECL should be enabled + # PROGRESSIVE_SSE_STREAMING should be disabled + assert feature_states == [True, False] diff --git a/tests/unittests/cli/test_cli_tools_click_option_mismatch.py b/tests/unittests/cli/test_cli_tools_click_option_mismatch.py index 3c67e9ae39..9b81345312 100644 --- a/tests/unittests/cli/test_cli_tools_click_option_mismatch.py +++ b/tests/unittests/cli/test_cli_tools_click_option_mismatch.py @@ -95,7 +95,10 @@ def test_adk_run(): assert run_command is not None, "Run command not found" _check_options_in_parameters( - run_command, cli_run.callback, "run", ignore_params={"enable_features"} + run_command, + cli_run.callback, + "run", + ignore_params={"enable_features", "disable_features"}, ) @@ -105,7 +108,10 @@ def test_adk_eval(): assert eval_command is not None, "Eval command not found" _check_options_in_parameters( - eval_command, cli_eval.callback, "eval", ignore_params={"enable_features"} + eval_command, + cli_eval.callback, + "eval", + ignore_params={"enable_features", "disable_features"}, ) @@ -118,7 +124,7 @@ def test_adk_web(): web_command, cli_web.callback, "web", - ignore_params={"verbose", "enable_features"}, + ignore_params={"verbose", "enable_features", "disable_features"}, ) @@ -131,7 +137,7 @@ def test_adk_api_server(): api_server_command, cli_api_server.callback, "api_server", - ignore_params={"verbose", "enable_features"}, + ignore_params={"verbose", "enable_features", "disable_features"}, ) From 5257869d91a77ebd1381538a85e7fdc3a600da90 Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 20 Jan 2026 12:26:48 -0800 Subject: [PATCH 2/5] fix: Redact sensitive information from URIs in logs This change introduces a helper function `_redact_uri_for_log` to sanitize URIs before logging. It removes user credentials from the netloc and redacts the values of query parameters, ensuring that sensitive information like passwords is not exposed in log outputs. The function is applied to all log statements and error messages that include service URIs for session, memory, and artifact services Co-authored-by: George Weale PiperOrigin-RevId: 858703465 --- src/google/adk/cli/utils/service_factory.py | 64 ++++++++++-- .../cli/utils/test_service_factory.py | 97 +++++++++++++++++-- 2 files changed, 147 insertions(+), 14 deletions(-) diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py index c03ac10b85..d8903ece14 100644 --- a/src/google/adk/cli/utils/service_factory.py +++ b/src/google/adk/cli/utils/service_factory.py @@ -19,6 +19,9 @@ from pathlib import Path from typing import Any from typing import Optional +from urllib.parse import parse_qsl +from urllib.parse import urlsplit +from urllib.parse import urlunsplit from ...artifacts.base_artifact_service import BaseArtifactService from ...memory.base_memory_service import BaseMemoryService @@ -42,6 +45,41 @@ _KUBERNETES_HOST_ENV = "KUBERNETES_SERVICE_HOST" +def _redact_uri_for_log(uri: str) -> str: + """Returns a safe-to-log representation of a URI. + + Redacts user info (username/password) and query parameter values. + """ + if not uri or not uri.strip(): + return "" + sanitized = uri.replace("\r", "\\r").replace("\n", "\\n") + if "://" not in sanitized: + return "" + try: + parsed = urlsplit(sanitized) + except ValueError: + return "" + + if not parsed.scheme: + return "" + + netloc = parsed.netloc + if "@" in netloc: + _, netloc = netloc.rsplit("@", 1) + + if parsed.query: + try: + redacted_pairs = parse_qsl(parsed.query, keep_blank_values=True) + except ValueError: + query = "" + else: + query = "&".join(f"{key}=" for key, _ in redacted_pairs) + else: + query = "" + + return urlunsplit((parsed.scheme, netloc, parsed.path, query, "")) + + def _is_cloud_run() -> bool: """Returns True when running in Cloud Run.""" return bool(os.environ.get(_CLOUD_RUN_SERVICE_ENV)) @@ -148,7 +186,10 @@ def create_session_service_from_options( kwargs.update(session_db_kwargs) if session_service_uri: - logger.info("Using session service URI: %s", session_service_uri) + logger.info( + "Using session service URI: %s", + _redact_uri_for_log(session_service_uri), + ) service = registry.create_session_service(session_service_uri, **kwargs) if service is not None: return service @@ -162,7 +203,7 @@ def create_session_service_from_options( fallback_kwargs.pop("agents_dir", None) logger.info( "Falling back to DatabaseSessionService for URI: %s", - session_service_uri, + _redact_uri_for_log(session_service_uri), ) return DatabaseSessionService(db_url=session_service_uri, **fallback_kwargs) @@ -208,13 +249,18 @@ def create_memory_service_from_options( registry = get_service_registry() if memory_service_uri: - logger.info("Using memory service URI: %s", memory_service_uri) + logger.info( + "Using memory service URI: %s", _redact_uri_for_log(memory_service_uri) + ) service = registry.create_memory_service( memory_service_uri, agents_dir=str(base_path), ) if service is None: - raise ValueError(f"Unsupported memory service URI: {memory_service_uri}") + raise ValueError( + "Unsupported memory service URI: %s" + % _redact_uri_for_log(memory_service_uri) + ) return service logger.info("Using in-memory memory service") @@ -235,7 +281,10 @@ def create_artifact_service_from_options( registry = get_service_registry() if artifact_service_uri: - logger.info("Using artifact service URI: %s", artifact_service_uri) + logger.info( + "Using artifact service URI: %s", + _redact_uri_for_log(artifact_service_uri), + ) service = registry.create_artifact_service( artifact_service_uri, agents_dir=str(base_path), @@ -243,11 +292,12 @@ def create_artifact_service_from_options( if service is None: if strict_uri: raise ValueError( - f"Unsupported artifact service URI: {artifact_service_uri}" + "Unsupported artifact service URI: %s" + % _redact_uri_for_log(artifact_service_uri) ) return _create_in_memory_artifact_service( "Unsupported artifact service URI: %s, falling back to in-memory", - artifact_service_uri, + _redact_uri_for_log(artifact_service_uri), ) return service diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py index 87b567be73..ad9a238985 100644 --- a/tests/unittests/cli/utils/test_service_factory.py +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -16,12 +16,14 @@ from __future__ import annotations +import logging import os from pathlib import Path -from unittest.mock import Mock +from unittest import mock from google.adk.artifacts.file_artifact_service import FileArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.cli.service_registry import ServiceRegistry from google.adk.cli.utils.local_storage import PerAgentDatabaseSessionService import google.adk.cli.utils.service_factory as service_factory from google.adk.memory.in_memory_memory_service import InMemoryMemoryService @@ -31,7 +33,7 @@ def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) expected = object() registry.create_session_service.return_value = expected monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -48,6 +50,87 @@ def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): ) +def test_create_session_service_logs_redacted_uri( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) + registry.create_session_service.return_value = object() + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + session_service_uri = ( + "postgresql://user:supersecret@localhost:5432/dbname?sslmode=require" + ) + caplog.set_level(logging.INFO, logger=service_factory.logger.name) + + service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri=session_service_uri, + ) + + assert "supersecret" not in caplog.text + assert "sslmode=require" not in caplog.text + assert "localhost:5432" in caplog.text + + +def test_redact_uri_for_log_removes_credentials_with_at_in_password() -> None: + uri = "postgresql://user:super@secret@localhost:5432/dbname" + + assert ( + service_factory._redact_uri_for_log(uri) + == "postgresql://localhost:5432/dbname" + ) + + +def test_redact_uri_for_log_preserves_host_when_no_credentials() -> None: + uri = "postgresql://localhost:5432/dbname?sslmode=require&password=secret" + + redacted = service_factory._redact_uri_for_log(uri) + + assert redacted.startswith("postgresql://localhost:5432/dbname?") + assert "require" not in redacted + assert "secret" not in redacted + assert "sslmode=" in redacted + assert "password=" in redacted + + +def test_redact_uri_for_log_redacts_when_parse_qsl_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def _raise_value_error(*_args, **_kwargs): + raise ValueError("bad query") + + monkeypatch.setattr(service_factory, "parse_qsl", _raise_value_error) + + uri = "postgresql://user:pass@localhost:5432/dbname?sslmode=require" + redacted = service_factory._redact_uri_for_log(uri) + + assert "pass" not in redacted + assert "require" not in redacted + assert redacted.endswith("?") + + +def test_redact_uri_for_log_escapes_crlf() -> None: + uri = ( + "postgresql://user:pass@localhost:5432/dbname\rINJECT\nINJECT" + "?sslmode=require" + ) + + redacted = service_factory._redact_uri_for_log(uri) + + assert "\r" not in redacted + assert "\n" not in redacted + assert "\\rINJECT\\nINJECT" in redacted + + +def test_redact_uri_for_log_returns_scheme_missing_without_separator() -> None: + assert ( + service_factory._redact_uri_for_log("user:pass@localhost:5432/dbname") + == "" + ) + + @pytest.mark.asyncio async def test_create_session_service_defaults_to_per_agent_sqlite( tmp_path: Path, @@ -88,7 +171,7 @@ async def test_create_session_service_respects_app_name_mapping( def test_create_session_service_fallbacks_to_database( tmp_path: Path, monkeypatch ): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) registry.create_session_service.return_value = None monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -109,7 +192,7 @@ def test_create_session_service_fallbacks_to_database( def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) expected = object() registry.create_artifact_service.return_value = expected monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -129,7 +212,7 @@ def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch): def test_create_artifact_service_raises_on_unknown_scheme_when_strict( tmp_path: Path, monkeypatch ): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) registry.create_artifact_service.return_value = None monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -142,7 +225,7 @@ def test_create_artifact_service_raises_on_unknown_scheme_when_strict( def test_create_memory_service_uses_registry(tmp_path: Path, monkeypatch): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) expected = object() registry.create_memory_service.return_value = expected monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -170,7 +253,7 @@ def test_create_memory_service_defaults_to_in_memory(tmp_path: Path): def test_create_memory_service_raises_on_unknown_scheme( tmp_path: Path, monkeypatch ): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) registry.create_memory_service.return_value = None monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) From 135f7633253f6a415302142abc3579b664601d5b Mon Sep 17 00:00:00 2001 From: Lusha Wang Date: Tue, 20 Jan 2026 12:36:22 -0800 Subject: [PATCH 3/5] feat: Remove @experimental decorator from AgentEngineSandboxCodeExecutor Co-authored-by: Lusha Wang PiperOrigin-RevId: 858706929 --- .../adk/code_executors/agent_engine_sandbox_code_executor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py index 2e3e978bc7..bfe3497c86 100644 --- a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py +++ b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py @@ -23,7 +23,6 @@ from typing_extensions import override from ..agents.invocation_context import InvocationContext -from ..utils.feature_decorator import experimental from .base_code_executor import BaseCodeExecutor from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult @@ -32,7 +31,6 @@ logger = logging.getLogger('google_adk.' + __name__) -@experimental class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): """A code executor that uses Agent Engine Code Execution Sandbox to execute code. From 215c2f506e21a3d8c39551b80f6356943ecae320 Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 20 Jan 2026 13:16:05 -0800 Subject: [PATCH 4/5] fix: Set LITELLM_MODE to PRODUCTION before importing LiteLLM LiteLLM defaults to DEV mode, which automatically loads environment variables from a local `.env` file. This change sets LITELLM_MODE to PRODUCTION to prevent LiteLLM from implicitly loading `.env` files when used within ADK. Co-authored-by: George Weale PiperOrigin-RevId: 858723362 --- src/google/adk/models/lite_llm.py | 120 ++++++++++++++---- tests/unittests/models/test_litellm_import.py | 108 ++++++++++++++++ 2 files changed, 205 insertions(+), 23 deletions(-) create mode 100644 tests/unittests/models/test_litellm_import.py diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index f6705c1de9..a7899e788d 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -16,6 +16,7 @@ import base64 import copy +import importlib.util import json import logging import mimetypes @@ -32,6 +33,7 @@ from typing import Literal from typing import Optional from typing import Tuple +from typing import TYPE_CHECKING from typing import TypedDict from typing import Union from urllib.parse import urlparse @@ -39,20 +41,12 @@ import warnings from google.genai import types -import litellm -from litellm import acompletion -from litellm import ChatCompletionAssistantMessage -from litellm import ChatCompletionAssistantToolCall -from litellm import ChatCompletionMessageToolCall -from litellm import ChatCompletionSystemMessage -from litellm import ChatCompletionToolMessage -from litellm import ChatCompletionUserMessage -from litellm import completion -from litellm import CustomStreamWrapper -from litellm import Function -from litellm import Message -from litellm import ModelResponse -from litellm import OpenAIMessageContent + +if not TYPE_CHECKING and importlib.util.find_spec("litellm") is None: + raise ImportError( + "LiteLLM support requires: pip install google-adk[extensions]" + ) + from pydantic import BaseModel from pydantic import Field from typing_extensions import override @@ -61,8 +55,36 @@ from .llm_request import LlmRequest from .llm_response import LlmResponse -# This will add functions to prompts if functions are provided. -litellm.add_function_to_prompt = True +if TYPE_CHECKING: + import litellm + from litellm import acompletion + from litellm import ChatCompletionAssistantMessage + from litellm import ChatCompletionAssistantToolCall + from litellm import ChatCompletionMessageToolCall + from litellm import ChatCompletionSystemMessage + from litellm import ChatCompletionToolMessage + from litellm import ChatCompletionUserMessage + from litellm import completion + from litellm import CustomStreamWrapper + from litellm import Function + from litellm import Message + from litellm import ModelResponse + from litellm import OpenAIMessageContent +else: + litellm = None + acompletion = None + ChatCompletionAssistantMessage = None + ChatCompletionAssistantToolCall = None + ChatCompletionMessageToolCall = None + ChatCompletionSystemMessage = None + ChatCompletionToolMessage = None + ChatCompletionUserMessage = None + completion = None + CustomStreamWrapper = None + Function = None + Message = None + ModelResponse = None + OpenAIMessageContent = None logger = logging.getLogger("google_adk." + __name__) @@ -109,6 +131,50 @@ "before a response was recorded)." ) +_LITELLM_IMPORTED = False +_LITELLM_GLOBAL_SYMBOLS = ( + "ChatCompletionAssistantMessage", + "ChatCompletionAssistantToolCall", + "ChatCompletionMessageToolCall", + "ChatCompletionSystemMessage", + "ChatCompletionToolMessage", + "ChatCompletionUserMessage", + "CustomStreamWrapper", + "Function", + "Message", + "ModelResponse", + "OpenAIMessageContent", + "acompletion", + "completion", +) + + +def _ensure_litellm_imported() -> None: + """Imports LiteLLM with safe defaults. + + LiteLLM defaults to DEV mode, which auto-loads a local `.env` at import time. + ADK should not implicitly load `.env` just because LiteLLM is installed. + + Users can opt into LiteLLM's default behavior by setting LITELLM_MODE=DEV. + """ + global _LITELLM_IMPORTED + if _LITELLM_IMPORTED: + return + + # https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py#L80-L82 + os.environ.setdefault("LITELLM_MODE", "PRODUCTION") + + import litellm as litellm_module + + litellm_module.add_function_to_prompt = True + + globals()["litellm"] = litellm_module + for symbol in _LITELLM_GLOBAL_SYMBOLS: + globals()[symbol] = getattr(litellm_module, symbol) + + _redirect_litellm_loggers_to_stdout() + _LITELLM_IMPORTED = True + def _map_finish_reason( finish_reason: Any, @@ -344,6 +410,7 @@ async def acompletion( Returns: The model response as a message. """ + _ensure_litellm_imported() return await acompletion( model=model, @@ -367,6 +434,7 @@ def completion( Returns: The response from the model. """ + _ensure_litellm_imported() return completion( model=model, @@ -513,6 +581,7 @@ async def _content_to_message_param( Returns: A litellm Message, a list of litellm Messages. """ + _ensure_litellm_imported() tool_messages: list[Message] = [] non_tool_parts: list[types.Part] = [] @@ -622,6 +691,8 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]: if not messages: return messages + _ensure_litellm_imported() + healed_messages: List[Message] = [] pending_tool_call_ids: List[str] = [] @@ -691,6 +762,7 @@ async def _get_content( Returns: The litellm content. """ + _ensure_litellm_imported() parts_list = list(parts) if len(parts_list) == 1: @@ -925,6 +997,7 @@ def _build_tool_call_from_json_dict( candidate: Any, *, index: int ) -> Optional[ChatCompletionMessageToolCall]: """Creates a tool call object from JSON content embedded in text.""" + _ensure_litellm_imported() if not isinstance(candidate, dict): return None @@ -972,11 +1045,12 @@ def _parse_tool_calls_from_text( text_block: str, ) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]: """Extracts inline JSON tool calls from LiteLLM text responses.""" - tool_calls = [] if not text_block: return tool_calls, None + _ensure_litellm_imported() + remainder_segments = [] cursor = 0 text_length = len(text_block) @@ -1014,7 +1088,6 @@ def _split_message_content_and_tool_calls( message: Message, ) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]: """Returns message content and tool calls, parsing inline JSON when needed.""" - existing_tool_calls = message.get("tool_calls") or [] normalized_tool_calls = ( list(existing_tool_calls) if existing_tool_calls else [] @@ -1180,6 +1253,7 @@ def _model_response_to_chunk( Yields: A tuple of text or function or usage metadata chunk and finish reason. """ + _ensure_litellm_imported() message = None if response.get("choices", None): @@ -1255,6 +1329,7 @@ def _model_response_to_generate_content_response( Returns: The LlmResponse. """ + _ensure_litellm_imported() message = None finish_reason = None @@ -1313,6 +1388,7 @@ def _message_to_generate_content_response( Returns: The LlmResponse. """ + _ensure_litellm_imported() parts: List[types.Part] = [] if not thought_parts: @@ -1440,6 +1516,8 @@ async def _get_completion_inputs( The litellm inputs (message list, tool dictionary, response format and generation params). """ + _ensure_litellm_imported() + # Determine provider for file handling provider = _get_provider_from_model(model) @@ -1665,11 +1743,6 @@ def _redirect_litellm_loggers_to_stdout() -> None: handler.stream = sys.stdout -# Redirect LiteLLM loggers to stdout immediately after import to ensure -# INFO-level logs are not incorrectly treated as errors in cloud environments. -_redirect_litellm_loggers_to_stdout() - - class LiteLlm(BaseLlm): """Wrapper around litellm. @@ -1732,6 +1805,7 @@ async def generate_content_async( Yields: LlmResponse: The model response. """ + _ensure_litellm_imported() self._maybe_append_user_content(llm_request) _append_fallback_user_content_if_missing(llm_request) diff --git a/tests/unittests/models/test_litellm_import.py b/tests/unittests/models/test_litellm_import.py new file mode 100644 index 0000000000..179dd4703e --- /dev/null +++ b/tests/unittests/models/test_litellm_import.py @@ -0,0 +1,108 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import os +import subprocess +import sys + +import pytest + + +def _subprocess_env() -> dict[str, str]: + env = dict(os.environ) + src_path = os.path.join(os.getcwd(), "src") + pythonpath = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = ( + f"{src_path}{os.pathsep}{pythonpath}" if pythonpath else src_path + ) + return env + + +def test_importing_models_does_not_import_litellm_or_set_mode(): + env = _subprocess_env() + env.pop("LITELLM_MODE", None) + + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import os, sys\n" + "import google.adk.models\n" + "print('litellm' in sys.modules)\n" + "print(os.environ.get('LITELLM_MODE'))\n" + ), + ], + check=True, + capture_output=True, + text=True, + env=env, + ) + stdout_lines = result.stdout.strip().splitlines() + assert stdout_lines == ["False", "None"] + + +def test_ensure_litellm_imported_defaults_to_production(): + if importlib.util.find_spec("litellm") is None: + pytest.skip("litellm is not installed") + + env = _subprocess_env() + env.pop("LITELLM_MODE", None) + + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import os\n" + "from google.adk.models.lite_llm import" + " _ensure_litellm_imported\n" + "_ensure_litellm_imported()\n" + "print(os.environ.get('LITELLM_MODE'))\n" + ), + ], + check=True, + capture_output=True, + text=True, + env=env, + ) + assert result.stdout.strip() == "PRODUCTION" + + +def test_ensure_litellm_imported_does_not_override(): + if importlib.util.find_spec("litellm") is None: + pytest.skip("litellm is not installed") + + env = _subprocess_env() + env["LITELLM_MODE"] = "DEV" + + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import os\n" + "from google.adk.models.lite_llm import" + " _ensure_litellm_imported\n" + "_ensure_litellm_imported()\n" + "print(os.environ.get('LITELLM_MODE'))\n" + ), + ], + check=True, + capture_output=True, + text=True, + env=env, + ) + assert result.stdout.strip() == "DEV" From 7955177fb28b8e5dc19aae8be94015a7b5d9882a Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 20 Jan 2026 13:28:23 -0800 Subject: [PATCH 5/5] fix: Update dependency versions in pyproject.toml Bump authlib to >=1.6.6 and mcp to >=1.23.0. Co-authored-by: George Weale PiperOrigin-RevId: 858728187 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1af967046f..ad1da6fff1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "PyYAML>=6.0.2, <7.0.0", # For APIHubToolset. "aiosqlite>=0.21.0", # For SQLite database "anyio>=4.9.0, <5.0.0", # For MCP Session Manager - "authlib>=1.5.1, <2.0.0", # For RestAPI Tool + "authlib>=1.6.6, <2.0.0", # For RestAPI Tool "click>=8.1.8, <9.0.0", # For CLI tools "fastapi>=0.115.0, <0.124.0", # FastAPI framework "google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery @@ -46,7 +46,7 @@ dependencies = [ "google-genai>=1.56.0, <2.0.0", # Google GenAI SDK "graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering "jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation - "mcp>=1.10.0, <2.0.0", # For MCP Toolset + "mcp>=1.23.0, <2.0.0", # For MCP Toolset "opentelemetry-api>=1.37.0, <=1.37.0", # OpenTelemetry - limit upper version for sdk and api to not risk breaking changes from unstable _logs package. "opentelemetry-exporter-gcp-logging>=1.9.0a0, <2.0.0", "opentelemetry-exporter-gcp-monitoring>=1.9.0a0, <2.0.0",