Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"click>=8.1.8, <9.0.0", # For CLI tools
"fastapi>=0.115.0, <0.124.0", # FastAPI framework
"google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery
"google-auth>=2.47.0", # Google Auth library
"google-cloud-aiplatform[agent_engines]>=1.132.0, <2.0.0", # For VertexAI integrations, e.g. example store.
"google-cloud-bigquery-storage>=2.0.0",
"google-cloud-bigquery>=2.2.0",
Expand Down
4 changes: 2 additions & 2 deletions src/google/adk/agents/config_schemas/AgentConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -2461,7 +2461,7 @@
}
],
"default": null,
"description": "Optional. LlmAgent.model. If not set, the model will be inherited from the ancestor.",
"description": "Optional. LlmAgent.model. Provide a model name string (e.g. \"gemini-2.0-flash\"). If not set, the model will be inherited from the ancestor or fall back to the system default (gemini-2.5-flash unless overridden via LlmAgent.set_default_model). To construct a model instance from code, use model_code.",
"title": "Model"
},
"instruction": {
Expand Down Expand Up @@ -4601,4 +4601,4 @@
}
],
"title": "AgentConfig"
}
}
9 changes: 0 additions & 9 deletions src/google/adk/agents/live_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,6 @@ class LiveRequestQueue:
"""Queue used to send LiveRequest in a live(bidirectional streaming) way."""

def __init__(self):
# Ensure there's an event loop available in this thread
try:
asyncio.get_running_loop()
except RuntimeError:
# No running loop, create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# Now create the queue (it will use the event loop we just ensured exists)
self._queue = asyncio.Queue()

def close(self):
Expand Down
32 changes: 29 additions & 3 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,18 @@ async def _convert_tool_union_to_tools(
class LlmAgent(BaseAgent):
"""LLM-based Agent."""

DEFAULT_MODEL: ClassVar[str] = 'gemini-2.5-flash'
"""System default model used when no model is set on an agent."""

_default_model: ClassVar[Union[str, BaseLlm]] = DEFAULT_MODEL
"""Current default model used when an agent has no model set."""

model: Union[str, BaseLlm] = ''
"""The model to use for the agent.
When not set, the agent will inherit the model from its ancestor.
When not set, the agent will inherit the model from its ancestor. If no
ancestor provides a model, the agent uses the default model configured via
LlmAgent.set_default_model. The built-in default is gemini-2.5-flash.
"""

config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig
Expand Down Expand Up @@ -503,7 +511,24 @@ def canonical_model(self) -> BaseLlm:
if isinstance(ancestor_agent, LlmAgent):
return ancestor_agent.canonical_model
ancestor_agent = ancestor_agent.parent_agent
raise ValueError(f'No model found for {self.name}.')
return self._resolve_default_model()

@classmethod
def set_default_model(cls, model: Union[str, BaseLlm]) -> None:
  """Overrides the default model used when an agent has no model set."""
  # Validate by shape: strings must be non-empty model names; anything
  # else must already be a BaseLlm instance.
  if isinstance(model, str):
    if not model:
      raise ValueError('Default model must be a non-empty string.')
  elif not isinstance(model, BaseLlm):
    raise TypeError('Default model must be a model name or BaseLlm.')
  cls._default_model = model

@classmethod
def _resolve_default_model(cls) -> BaseLlm:
  """Resolves the current default model to a BaseLlm instance."""
  configured = cls._default_model
  # Already-constructed model instances are used as-is; bare model names
  # are materialized through the registry.
  return (
      configured
      if isinstance(configured, BaseLlm)
      else LLMRegistry.new_llm(configured)
  )

async def canonical_instruction(
self, ctx: ReadonlyContext
Expand Down Expand Up @@ -575,10 +600,11 @@ async def canonical_tools(
# because the built-in tools cannot be used together with other tools.
# TODO(b/448114567): Remove once the workaround is no longer needed.
multiple_tools = len(self.tools) > 1
model = self.canonical_model
for tool_union in self.tools:
resolved_tools.extend(
await _convert_tool_union_to_tools(
tool_union, ctx, self.model, multiple_tools
tool_union, ctx, model, multiple_tools
)
)
return resolved_tools
Expand Down
5 changes: 3 additions & 2 deletions src/google/adk/agents/llm_agent_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ class LlmAgentConfig(BaseAgentConfig):
description=(
'Optional. LlmAgent.model. Provide a model name string (e.g.'
' "gemini-2.0-flash"). If not set, the model will be inherited from'
' the ancestor. To construct a model instance from code, use'
' model_code.'
' the ancestor or fall back to the system default (gemini-2.5-flash'
' unless overridden via LlmAgent.set_default_model). To construct a'
' model instance from code, use model_code.'
),
)

Expand Down
6 changes: 1 addition & 5 deletions src/google/adk/auth/credential_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,13 @@ async def get_auth_credential(
async def _load_existing_credential(
self, callback_context: CallbackContext
) -> Optional[AuthCredential]:
"""Load existing credential from credential service or cached exchanged credential."""
"""Load existing credential from credential service."""

# Try loading from credential service first
credential = await self._load_from_credential_service(callback_context)
if credential:
return credential

# Check if we have a cached exchanged credential
if self._auth_config.exchanged_auth_credential:
return self._auth_config.exchanged_auth_credential

return None

async def _load_from_credential_service(
Expand Down
33 changes: 25 additions & 8 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,14 +1531,31 @@ async def event_generator():
)
) as agen:
async for event in agen:
# Format as SSE data
sse_event = event.model_dump_json(
exclude_none=True, by_alias=True
)
logger.debug(
"Generated event in agent run streaming: %s", sse_event
)
yield f"data: {sse_event}\n\n"
# ADK Web renders artifacts from `actions.artifactDelta`
# during part processing *and* during action processing
# 1) the original event with `artifactDelta` cleared (content)
# 2) a content-less "action-only" event carrying `artifactDelta`
events_to_stream = [event]
if (
event.actions.artifact_delta
and event.content
and event.content.parts
):
content_event = event.model_copy(deep=True)
content_event.actions.artifact_delta = {}
artifact_event = event.model_copy(deep=True)
artifact_event.content = None
events_to_stream = [content_event, artifact_event]

for event_to_stream in events_to_stream:
sse_event = event_to_stream.model_dump_json(
exclude_none=True,
by_alias=True,
)
logger.debug(
"Generated event in agent run streaming: %s", sse_event
)
yield f"data: {sse_event}\n\n"
except Exception as e:
logger.exception("Error in event_generator: %s", e)
# You might want to yield an error event here
Expand Down
8 changes: 8 additions & 0 deletions src/google/adk/cli/cli_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import click

from ..apps.app import validate_app_name

_INIT_PY_TEMPLATE = """\
from . import agent
"""
Expand Down Expand Up @@ -294,6 +296,12 @@ def run_cmd(
VertexAI as backend.
type: Optional[str], Whether to define agent with config file or code.
"""
app_name = os.path.basename(os.path.normpath(agent_name))
try:
validate_app_name(app_name)
except ValueError as exc:
raise click.BadParameter(str(exc)) from exc

agent_folder = os.path.join(os.getcwd(), agent_name)
# check folder doesn't exist or it's empty. Otherwise, throw
if os.path.exists(agent_folder) and os.listdir(agent_folder):
Expand Down
42 changes: 42 additions & 0 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import cli_deploy
from .. import version
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
from ..sessions.migration import migration_runner
from .cli import run_cli
from .fast_api import get_fast_api_app
from .utils import envs
Expand Down Expand Up @@ -1507,6 +1508,47 @@ def cli_deploy_cloud_run(
click.secho(f"Deploy failed: {e}", fg="red", err=True)


@main.group()
def migrate():
  """ADK migration commands."""


@migrate.command("session", cls=HelpfulCommand)
@click.option(
    "--source_db_url",
    required=True,
    help=(
        "SQLAlchemy URL of source database in database session service, e.g."
        " sqlite:///source.db."
    ),
)
@click.option(
    "--dest_db_url",
    required=True,
    help=(
        "SQLAlchemy URL of destination database in database session service,"
        " e.g. sqlite:///dest.db."
    ),
)
@click.option(
    "--log_level",
    type=LOG_LEVELS,
    default="INFO",
    help="Optional. Set the logging level",
)
def cli_migrate_session(
    *, source_db_url: str, dest_db_url: str, log_level: str
):
  """Migrates a session database to the latest schema version.

  Reads sessions from the source database and writes them to the destination
  database via the migration runner.

  Args:
    source_db_url: SQLAlchemy URL of the source session database.
    dest_db_url: SQLAlchemy URL of the destination session database.
    log_level: Logging level name (e.g. INFO, DEBUG); resolved against the
      `logging` module.
  """
  logs.setup_adk_logger(getattr(logging, log_level.upper()))
  try:
    migration_runner.upgrade(source_db_url, dest_db_url)
    click.secho("Migration check and upgrade process finished.", fg="green")
  except Exception as e:
    # Broad catch: surface the failure to the user without a traceback,
    # matching the error-reporting style of the deploy commands.
    # NOTE(review): the command still exits with status 0 on failure —
    # confirm whether a non-zero exit code is desired.
    click.secho(f"Migration failed: {e}", fg="red", err=True)


@deploy.command("agent_engine")
@click.option(
"--api_key",
Expand Down
40 changes: 38 additions & 2 deletions src/google/adk/cli/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import logging
import os

from dotenv import load_dotenv

logger = logging.getLogger(__file__)
from ...utils.env_utils import is_env_enabled

logger = logging.getLogger('google_adk.' + __name__)

_ADK_DISABLE_LOAD_DOTENV_ENV_VAR = 'ADK_DISABLE_LOAD_DOTENV'


@functools.lru_cache(maxsize=1)
def _get_explicit_env_keys() -> frozenset[str]:
"""Returns env var keys set before ADK loads any `.env` files.

This snapshot is used to preserve user-provided environment variables while
still allowing later `.env` files to override earlier ones via
`override=True`.
"""
return frozenset(os.environ)


def _walk_to_root_until_found(folder, filename) -> str:
Expand All @@ -35,15 +53,33 @@ def _walk_to_root_until_found(folder, filename) -> str:
def load_dotenv_for_agent(
agent_name: str, agent_parent_folder: str, filename: str = '.env'
):
"""Loads the .env file for the agent module."""
"""Loads the `.env` file for the agent module.

Explicit environment variables (present before the first `.env` load) are
preserved, while values loaded from `.env` may be overridden by later `.env`
loads.
"""
if is_env_enabled(_ADK_DISABLE_LOAD_DOTENV_ENV_VAR):
logger.info(
'Skipping %s loading because %s is enabled.',
filename,
_ADK_DISABLE_LOAD_DOTENV_ENV_VAR,
)
return

# Gets the folder of agent_module as starting_folder
starting_folder = os.path.abspath(
os.path.join(agent_parent_folder, agent_name)
)
dotenv_file_path = _walk_to_root_until_found(starting_folder, filename)
if dotenv_file_path:
explicit_env_keys = _get_explicit_env_keys()
explicit_env = {
key: os.environ[key] for key in explicit_env_keys if key in os.environ
}

load_dotenv(dotenv_file_path, override=True, verbose=True)
os.environ.update(explicit_env)
logger.info(
'Loaded %s file for %s at %s',
filename,
Expand Down
2 changes: 1 addition & 1 deletion src/google/adk/flows/llm_flows/_output_schema_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def run_async(
if (
not agent.output_schema
or not agent.tools
or can_use_output_schema_with_tools(agent.model)
or can_use_output_schema_with_tools(agent.canonical_model)
):
return

Expand Down
46 changes: 21 additions & 25 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
_ADK_AGENT_NAME_LABEL_KEY = 'adk_agent_name'

# Timing configuration
DEFAULT_REQUEST_QUEUE_TIMEOUT = 0.25
DEFAULT_TRANSFER_AGENT_DELAY = 1.0
DEFAULT_TASK_COMPLETION_DELAY = 1.0

Expand Down Expand Up @@ -238,29 +237,22 @@ async def _send_to_model(
"""Sends data to model."""
while True:
live_request_queue = invocation_context.live_request_queue
try:
# Streamlit's execution model doesn't preemptively yield to the event
# loop. Therefore, we must explicitly introduce timeouts to allow the
# event loop to process events.
# TODO: revert back(remove timeout) once we move off streamlit.
live_request = await asyncio.wait_for(
live_request_queue.get(), timeout=DEFAULT_REQUEST_QUEUE_TIMEOUT
)
# duplicate the live_request to all the active streams
logger.debug(
'Sending live request %s to active streams: %s',
live_request,
invocation_context.active_streaming_tools,
)
if invocation_context.active_streaming_tools:
for active_streaming_tool in (
invocation_context.active_streaming_tools
).values():
if active_streaming_tool.stream:
active_streaming_tool.stream.send(live_request)
await asyncio.sleep(0)
except asyncio.TimeoutError:
continue
live_request = await live_request_queue.get()
# duplicate the live_request to all the active streams
logger.debug(
'Sending live request %s to active streams: %s',
live_request,
invocation_context.active_streaming_tools,
)
if invocation_context.active_streaming_tools:
for active_streaming_tool in (
invocation_context.active_streaming_tools
).values():
if active_streaming_tool.stream:
active_streaming_tool.stream.send(live_request)
# Yield to event loop for cooperative multitasking
await asyncio.sleep(0)

if live_request.close:
await llm_connection.close()
return
Expand Down Expand Up @@ -484,7 +476,11 @@ async def _preprocess_async(
# We may need to wrap some built-in tools if there are other tools
# because the built-in tools cannot be used together with other tools.
# TODO(b/448114567): Remove once the workaround is no longer needed.
if not agent.tools:
return

multiple_tools = len(agent.tools) > 1
model = agent.canonical_model
for tool_union in agent.tools:
tool_context = ToolContext(invocation_context)

Expand All @@ -500,7 +496,7 @@ async def _preprocess_async(
tools = await _convert_tool_union_to_tools(
tool_union,
ReadonlyContext(invocation_context),
agent.model,
model,
multiple_tools,
)
for tool in tools:
Expand Down
Loading
Loading