src/google/adk/models/lite_llm.py (56 additions, 13 deletions)
@@ -982,8 +982,20 @@ def _message_to_generate_content_response(

def _to_litellm_response_format(
response_schema: types.SchemaUnion,
) -> Optional[Dict[str, Any]]:
"""Converts ADK response schema objects into LiteLLM-compatible payloads."""
model: str,
) -> dict[str, Any] | None:
"""Converts ADK response schema objects into LiteLLM-compatible payloads.

Args:
response_schema: The response schema to convert.
    model: The model string, used to pick the appropriate format. Gemini
      models use the 'response_schema' key, while OpenAI-compatible models
      use the 'json_schema' key.

Returns:
    A dictionary with the appropriate response format for LiteLLM, or None
    if the schema type is unsupported.
"""
schema_name = "response"

if isinstance(response_schema, dict):
schema_type = response_schema.get("type")
@@ -993,33 +1005,63 @@ def _to_litellm_response_format(
):
return response_schema
schema_dict = dict(response_schema)
if "title" in schema_dict:
schema_name = str(schema_dict["title"])
elif isinstance(response_schema, type) and issubclass(
response_schema, BaseModel
):
schema_dict = response_schema.model_json_schema()
schema_name = response_schema.__name__
elif isinstance(response_schema, BaseModel):
if isinstance(response_schema, types.Schema):
# GenAI Schema instances already represent JSON schema definitions.
schema_dict = response_schema.model_dump(exclude_none=True, mode="json")
if "title" in schema_dict:
schema_name = str(schema_dict["title"])
else:
schema_dict = response_schema.__class__.model_json_schema()
schema_name = response_schema.__class__.__name__
elif hasattr(response_schema, "model_dump"):
schema_dict = response_schema.model_dump(exclude_none=True, mode="json")
schema_name = response_schema.__class__.__name__
else:
logger.warning(
"Unsupported response_schema type %s for LiteLLM structured outputs.",
type(response_schema),
)
return None

  # Gemini models use a special response format with a 'response_schema' key.
if _is_litellm_gemini_model(model):
return {
"type": "json_object",
"response_schema": schema_dict,
}

# OpenAI-compatible format (default) per LiteLLM docs:
# https://docs.litellm.ai/docs/completion/json_mode
if (
isinstance(schema_dict, dict)
and schema_dict.get("type") == "object"
and "additionalProperties" not in schema_dict
):
# OpenAI structured outputs require explicit additionalProperties: false.
schema_dict = dict(schema_dict)
schema_dict["additionalProperties"] = False

return {
"type": "json_object",
"response_schema": schema_dict,
"type": "json_schema",
"json_schema": {
"name": schema_name,
"strict": True,
"schema": schema_dict,
},
}


async def _get_completion_inputs(
llm_request: LlmRequest,
model: str,
) -> Tuple[
List[Message],
Optional[List[Dict]],
@@ -1030,13 +1072,14 @@ async def _get_completion_inputs(

Args:
llm_request: The LlmRequest to convert.
model: The model string to use for determining provider-specific behavior.

Returns:
The litellm inputs (message list, tool dictionary, response format and
generation params).
"""
# Determine provider for file handling
provider = _get_provider_from_model(llm_request.model or "")
provider = _get_provider_from_model(model)

# 1. Construct messages
messages: List[Message] = []
@@ -1071,14 +1114,15 @@ async def _get_completion_inputs(
]

# 3. Handle response format
response_format: Optional[Dict[str, Any]] = None
response_format: dict[str, Any] | None = None
if llm_request.config and llm_request.config.response_schema:
response_format = _to_litellm_response_format(
llm_request.config.response_schema
llm_request.config.response_schema,
model=model,
)

# 4. Extract generation parameters
generation_params: Optional[Dict] = None
generation_params: dict | None = None
if llm_request.config:
config_dict = llm_request.config.model_dump(exclude_none=True)
# Generate LiteLlm parameters here,
@@ -1190,9 +1234,7 @@ def _is_litellm_gemini_model(model_string: str) -> bool:
Returns:
True if it's a Gemini model accessed via LiteLLM, False otherwise
"""
# Matches "gemini/gemini-*" (Google AI Studio) or "vertex_ai/gemini-*" (Vertex AI).
pattern = r"^(gemini|vertex_ai)/gemini-"
return bool(re.match(pattern, model_string))
return model_string.startswith(("gemini/gemini-", "vertex_ai/gemini-"))


def _extract_gemini_model_from_litellm(litellm_model: str) -> str:
Expand Down Expand Up @@ -1308,16 +1350,17 @@ async def generate_content_async(
_append_fallback_user_content_if_missing(llm_request)
logger.debug(_build_request_log(llm_request))

model = llm_request.model or self.model
messages, tools, response_format, generation_params = (
await _get_completion_inputs(llm_request)
await _get_completion_inputs(llm_request, model)
)

if "functions" in self._additional_args:
# LiteLLM does not support both tools and functions together.
tools = None

completion_args = {
"model": llm_request.model or self.model,
"model": model,
"messages": messages,
"tools": tools,
"response_format": response_format,
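For reference, a condensed, standalone sketch of the two payload shapes this change produces. `to_response_format` and the `Answer` model are hypothetical stand-ins that mirror the dispatch above; the model strings are illustrative:

```python
from typing import Any

from pydantic import BaseModel


class Answer(BaseModel):  # illustrative schema
  text: str
  score: float


def to_response_format(schema: type[BaseModel], model: str) -> dict[str, Any]:
  # Mirrors _to_litellm_response_format: Gemini keeps the legacy
  # json_object + response_schema shape; everything else gets the
  # OpenAI-style json_schema payload with strict mode.
  schema_dict = schema.model_json_schema()
  if model.startswith(("gemini/gemini-", "vertex_ai/gemini-")):
    return {"type": "json_object", "response_schema": schema_dict}
  if schema_dict.get("type") == "object":
    # OpenAI structured outputs require explicit additionalProperties: false.
    schema_dict.setdefault("additionalProperties", False)
  return {
      "type": "json_schema",
      "json_schema": {
          "name": schema.__name__,
          "strict": True,
          "schema": schema_dict,
      },
  }


print(to_response_format(Answer, "gemini/gemini-2.0-flash")["type"])  # json_object
print(to_response_format(Answer, "openai/gpt-4o-mini")["type"])      # json_schema
```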
src/google/adk/tools/_automatic_function_calling_util.py (17 additions, 0 deletions)
@@ -30,6 +30,9 @@
from pydantic import fields as pydantic_fields

from . import _function_parameter_parse_util
from . import _function_tool_declarations
from ..features import FeatureName
from ..features import is_feature_enabled
from ..utils.variant_utils import GoogleLLMVariant

_py_type_2_schema_type = {
@@ -196,6 +199,20 @@ def build_function_declaration(
ignore_params: Optional[list[str]] = None,
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
) -> types.FunctionDeclaration:
# ========== Pydantic-based function tool declaration (new feature) ==========
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
declaration = (
_function_tool_declarations.build_function_declaration_with_json_schema(
func, ignore_params=ignore_params
)
)
    # Keep the response schema only for VERTEX_AI.
# TODO(b/421991354): Remove this check once the bug is fixed.
if variant != GoogleLLMVariant.VERTEX_AI:
declaration.response_json_schema = None
return declaration

# ========== ADK defined function tool declaration (old behavior) ==========
signature = inspect.signature(func)
should_update_signature = False
new_func = None
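A minimal usage sketch of `build_function_declaration` with a plain typed function. The example function is hypothetical, and whether the new JSON-schema path is taken depends on the `JSON_SCHEMA_FOR_FUNC_DECL` feature flag:

```python
from google.adk.tools._automatic_function_calling_util import (
    build_function_declaration,
)
from google.adk.utils.variant_utils import GoogleLLMVariant


def get_weather(city: str, unit: str = "celsius") -> dict:
  """Returns the current weather for a city."""
  return {"city": city, "unit": unit, "temp_c": 21}


declaration = build_function_declaration(
    get_weather,
    variant=GoogleLLMVariant.VERTEX_AI,
)
print(declaration.name)  # get_weather
# With JSON_SCHEMA_FOR_FUNC_DECL enabled, the declaration is built from the
# function's JSON schema, and response_json_schema is kept only on VERTEX_AI
# (it is nulled out for other variants until b/421991354 is fixed).
```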
src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py
@@ -18,6 +18,7 @@
import logging
import ssl
from typing import Any
from typing import Callable
from typing import Dict
from typing import Final
from typing import List
@@ -71,6 +72,9 @@ def __init__(
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
tool_name_prefix: Optional[str] = None,
ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None,
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
):
"""Initializes the OpenAPIToolset.

@@ -116,8 +120,14 @@ def __init__(
- ssl.SSLContext: Custom SSL context for advanced configuration
This is useful for enterprise environments where requests go through
a TLS-intercepting proxy with a custom CA certificate.
header_provider: A callable that returns a dictionary of headers to be
included in API requests. The callable receives the ReadonlyContext as
an argument, allowing dynamic header generation based on the current
context. Useful for adding custom headers like correlation IDs,
authentication tokens, or other request metadata.
"""
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
self._header_provider = header_provider
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self._ssl_verify = ssl_verify
Expand Down Expand Up @@ -189,7 +199,11 @@ def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:

tools = []
for o in operations:
tool = RestApiTool.from_parsed_operation(o, ssl_verify=self._ssl_verify)
tool = RestApiTool.from_parsed_operation(
o,
ssl_verify=self._ssl_verify,
header_provider=self._header_provider,
)
logger.info("Parsed tool: %s", tool.name)
tools.append(tool)
return tools
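A short usage sketch for the new `header_provider` hook. The spec file and the correlation-id scheme are hypothetical, and it assumes `ReadonlyContext.invocation_id` as the per-request value:

```python
from pathlib import Path

from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.tools.openapi_tool import OpenAPIToolset


def correlation_headers(ctx: ReadonlyContext) -> dict[str, str]:
  # Build per-request headers from the current invocation context.
  return {"X-Correlation-Id": ctx.invocation_id}


toolset = OpenAPIToolset(
    spec_str=Path("petstore.yaml").read_text(),  # hypothetical spec
    spec_str_type="yaml",
    header_provider=correlation_headers,
)
# Every RestApiTool parsed from the spec now merges these headers into its
# outgoing requests at call time.
```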
src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py
@@ -16,6 +16,7 @@

import ssl
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Literal
@@ -29,6 +30,7 @@
import requests
from typing_extensions import override

from ....agents.readonly_context import ReadonlyContext
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ..._gemini_schema_util import _to_gemini_schema
@@ -90,6 +92,9 @@ def __init__(
auth_credential: Optional[Union[AuthCredential, str]] = None,
should_parse_operation=True,
ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None,
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
):
"""Initializes the RestApiTool with the given parameters.

@@ -122,6 +127,11 @@ def __init__(
- False: Disable SSL verification (insecure, not recommended)
- str: Path to a CA bundle file or directory for custom CA
- ssl.SSLContext: Custom SSL context for advanced configuration
header_provider: A callable that returns a dictionary of headers to be
included in API requests. The callable receives the ReadonlyContext as
an argument, allowing dynamic header generation based on the current
context. Useful for adding custom headers like correlation IDs,
authentication tokens, or other request metadata.
"""
    # Gemini restricts function names to fewer than 64 characters.
self.name = name[:60]
@@ -145,6 +155,7 @@ def __init__(
self.credential_exchanger = AutoAuthCredentialExchanger()
self._default_headers: Dict[str, str] = {}
self._ssl_verify = ssl_verify
self._header_provider = header_provider
if should_parse_operation:
self._operation_parser = OperationParser(self.operation)

@@ -153,12 +164,20 @@ def from_parsed_operation(
cls,
parsed: ParsedOperation,
ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None,
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
) -> "RestApiTool":
"""Initializes the RestApiTool from a ParsedOperation object.

Args:
parsed: A ParsedOperation object.
ssl_verify: SSL certificate verification option.
header_provider: A callable that returns a dictionary of headers to be
included in API requests. The callable receives the ReadonlyContext as
an argument, allowing dynamic header generation based on the current
context. Useful for adding custom headers like correlation IDs,
authentication tokens, or other request metadata.

Returns:
A RestApiTool object.
@@ -178,6 +197,7 @@ def from_parsed_operation(
auth_scheme=parsed.auth_scheme,
auth_credential=parsed.auth_credential,
ssl_verify=ssl_verify,
header_provider=header_provider,
)
generated._operation_parser = operation_parser
return generated
@@ -450,6 +470,13 @@ async def call(
request_params = self._prepare_request_params(api_params, api_args)
if self._ssl_verify is not None:
request_params["verify"] = self._ssl_verify

# Add headers from header_provider if configured
if self._header_provider is not None and tool_context is not None:
provider_headers = self._header_provider(tool_context)
if provider_headers:
request_params.setdefault("headers", {}).update(provider_headers)

response = requests.request(**request_params)

# Parse API response
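Note the merge semantics in `call`: `setdefault(...).update(...)` layers the provider's headers onto whatever headers were already prepared for the request, with provider values winning on key collisions. A quick standalone illustration (the endpoint and header values are hypothetical):

```python
request_params = {
    "method": "GET",
    "url": "https://api.example.com/items",  # hypothetical endpoint
    "headers": {"Accept": "application/json"},
}
provider_headers = {"X-Correlation-Id": "abc-123", "Accept": "application/xml"}

# setdefault returns the existing headers dict (or installs {}), then
# update merges the provider headers into it in place.
request_params.setdefault("headers", {}).update(provider_headers)
print(request_params["headers"])
# {'Accept': 'application/xml', 'X-Correlation-Id': 'abc-123'}
```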