Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 57 additions & 22 deletions src/google/adk/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
from __future__ import annotations

from datetime import datetime
from pathlib import Path
from typing import Optional
from typing import Union

import click
from google.genai import types
from pydantic import BaseModel

from ..agents.base_agent import BaseAgent
from ..agents.llm_agent import LlmAgent
from ..apps.app import App
from ..artifacts.base_artifact_service import BaseArtifactService
Expand All @@ -35,8 +35,11 @@
from ..sessions.session import Session
from ..utils.context_utils import Aclosing
from ..utils.env_utils import is_env_enabled
from .service_registry import load_services_module
from .utils import envs
from .utils.agent_loader import AgentLoader
from .utils.service_factory import create_artifact_service_from_options
from .utils.service_factory import create_session_service_from_options


class InputFile(BaseModel):
Expand Down Expand Up @@ -66,7 +69,7 @@ async def run_input_file(
)
with open(input_path, 'r', encoding='utf-8') as f:
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()
input_file.state['_time'] = datetime.now().isoformat()

session = await session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
Expand Down Expand Up @@ -134,6 +137,8 @@ async def run_cli(
saved_session_file: Optional[str] = None,
save_session: bool,
session_id: Optional[str] = None,
session_service_uri: Optional[str] = None,
artifact_service_uri: Optional[str] = None,
) -> None:
"""Runs an interactive CLI for a certain agent.

Expand All @@ -148,24 +153,47 @@ async def run_cli(
contains a previously saved session, exclusive with input_file.
save_session: bool, whether to save the session on exit.
session_id: Optional[str], the session ID to save the session to on exit.
session_service_uri: Optional[str], custom session service URI.
artifact_service_uri: Optional[str], custom artifact service URI.
"""
agent_parent_path = Path(agent_parent_dir).resolve()
agent_root = agent_parent_path / agent_folder_name
load_services_module(str(agent_root))
user_id = 'test_user'

artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
credential_service = InMemoryCredentialService()
# Create session and artifact services using factory functions
session_service = create_session_service_from_options(
base_dir=agent_root,
session_service_uri=session_service_uri,
)

user_id = 'test_user'
agent_or_app = AgentLoader(agents_dir=agent_parent_dir).load_agent(
artifact_service = create_artifact_service_from_options(
base_dir=agent_root,
artifact_service_uri=artifact_service_uri,
)

credential_service = InMemoryCredentialService()
agents_dir = str(agent_parent_path)
agent_or_app = AgentLoader(agents_dir=agents_dir).load_agent(
agent_folder_name
)
session_app_name = (
agent_or_app.name if isinstance(agent_or_app, App) else agent_folder_name
)
session = await session_service.create_session(
app_name=session_app_name, user_id=user_id
)
if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'):
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
envs.load_dotenv_for_agent(agent_folder_name, agents_dir)

# Helper function for printing events
def _print_event(event) -> None:
content = event.content
if not content or not content.parts:
return
text_parts = [part.text for part in content.parts if part.text]
if not text_parts:
return
author = event.author or 'system'
click.echo(f'[{author}]: {"".join(text_parts)}')

if input_file:
session = await run_input_file(
app_name=session_app_name,
Expand All @@ -177,16 +205,22 @@ async def run_cli(
input_path=input_file,
)
elif saved_session_file:
# Load the saved session from file
with open(saved_session_file, 'r', encoding='utf-8') as f:
loaded_session = Session.model_validate_json(f.read())

# Create a new session in the service, copying state from the file
session = await session_service.create_session(
app_name=session_app_name,
user_id=user_id,
state=loaded_session.state if loaded_session else None,
)

# Append events from the file to the new session and display them
if loaded_session:
for event in loaded_session.events:
await session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
click.echo(f'[{event.author}]: {content.parts[0].text}')
_print_event(event)

await run_interactively(
agent_or_app,
Expand All @@ -196,6 +230,9 @@ async def run_cli(
credential_service,
)
else:
session = await session_service.create_session(
app_name=session_app_name, user_id=user_id
)
click.echo(f'Running agent {agent_or_app.name}, type exit to exit.')
await run_interactively(
agent_or_app,
Expand All @@ -207,19 +244,17 @@ async def run_cli(

if save_session:
session_id = session_id or input('Session ID to save: ')
session_path = (
f'{agent_parent_dir}/{agent_folder_name}/{session_id}.session.json'
)
session_path = agent_root / f'{session_id}.session.json'

# Fetch the session again to get all the details.
session = await session_service.get_session(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
)
with open(session_path, 'w', encoding='utf-8') as f:
f.write(
session.model_dump_json(indent=2, exclude_none=True, by_alias=True)
)
session_path.write_text(
session.model_dump_json(indent=2, exclude_none=True, by_alias=True),
encoding='utf-8',
)

print('Session saved to', session_path)
128 changes: 74 additions & 54 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import os
from pathlib import Path
import tempfile
import textwrap
from typing import Optional

import click
Expand Down Expand Up @@ -354,7 +355,62 @@ def validate_exclusive(ctx, param, value):
return value


def adk_services_options():
"""Decorator to add ADK services options to click commands."""

def decorator(func):
@click.option(
"--session_service_uri",
help=textwrap.dedent(
"""\
Optional. The URI of the session service.
- Leave unset to use the in-memory session service (default).
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
sessions. <agent_engine> can either be the full qualified resource
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
the resource id '123'.
- Use 'memory://' to run with the in-memory session service.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported database URIs."""
),
)
@click.option(
"--artifact_service_uri",
type=str,
help=textwrap.dedent(
"""\
Optional. The URI of the artifact service.
- Leave unset to store artifacts under '.adk/artifacts' locally.
- Use 'gs://<bucket_name>' to connect to the GCS artifact service.
- Use 'memory://' to force the in-memory artifact service.
- Use 'file://<path>' to store artifacts in a custom local directory."""
),
default=None,
)
@click.option(
"--memory_service_uri",
type=str,
help=textwrap.dedent("""\
Optional. The URI of the memory service.
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service.
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
sessions. <agent_engine> can either be the full qualified resource
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
the resource id '123'.
- Use 'memory://' to force the in-memory memory service."""),
default=None,
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper

return decorator


@main.command("run", cls=HelpfulCommand)
@adk_services_options()
@click.option(
"--save_session",
type=bool,
Expand Down Expand Up @@ -409,6 +465,9 @@ def cli_run(
session_id: Optional[str],
replay: Optional[str],
resume: Optional[str],
session_service_uri: Optional[str] = None,
artifact_service_uri: Optional[str] = None,
memory_service_uri: Optional[str] = None,
):
"""Runs an interactive CLI for a certain agent.

Expand All @@ -420,6 +479,14 @@ def cli_run(
"""
logs.log_to_tmp_folder()

# Validation warning for memory_service_uri (not supported for adk run)
if memory_service_uri:
click.secho(
"WARNING: --memory_service_uri is not supported for adk run.",
fg="yellow",
err=True,
)

agent_parent_folder = os.path.dirname(agent)
agent_folder_name = os.path.basename(agent)

Expand All @@ -431,6 +498,8 @@ def cli_run(
saved_session_file=resume,
save_session=save_session,
session_id=session_id,
session_service_uri=session_service_uri,
artifact_service_uri=artifact_service_uri,
)
)

Expand Down Expand Up @@ -865,63 +934,14 @@ def wrapper(*args, **kwargs):
return decorator


def adk_services_options():
"""Decorator to add ADK services options to click commands."""

def decorator(func):
@click.option(
"--session_service_uri",
help=(
"""Optional. The URI of the session service.
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
sessions. <agent_engine> can either be the full qualified resource
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
the resource id '123'.
- Use 'sqlite://<path_to_sqlite_file>' to connect to an aio-sqlite
based session service, which is good for local development.
- Use 'postgresql://<user>:<password>@<host>:<port>/<database_name>'
to connect to a PostgreSQL DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls
for more details on other database URIs supported by SQLAlchemy."""
),
)
@click.option(
"--artifact_service_uri",
type=str,
help=(
"Optional. The URI of the artifact service,"
" supported URIs: gs://<bucket name> for GCS artifact service."
),
default=None,
)
@click.option(
"--memory_service_uri",
type=str,
help=("""Optional. The URI of the memory service.
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service.
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
sessions. <agent_engine> can either be the full qualified resource
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
the resource id '123'."""),
default=None,
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper

return decorator


def deprecated_adk_services_options():
"""Deprecated ADK services options."""

def warn(alternative_param, ctx, param, value):
if value:
click.echo(
click.style(
f"WARNING: Deprecated option {param.name} is used. Please use"
f"WARNING: Deprecated option --{param.name} is used. Please use"
f" {alternative_param} instead.",
fg="yellow",
),
Expand Down Expand Up @@ -1116,6 +1136,8 @@ def cli_web(

adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir
"""
session_service_uri = session_service_uri or session_db_url
artifact_service_uri = artifact_service_uri or artifact_storage_uri
logs.setup_adk_logger(getattr(logging, log_level.upper()))

@asynccontextmanager
Expand All @@ -1140,8 +1162,6 @@ async def _lifespan(app: FastAPI):
fg="green",
)

session_service_uri = session_service_uri or session_db_url
artifact_service_uri = artifact_service_uri or artifact_storage_uri
app = get_fast_api_app(
agents_dir=agents_dir,
session_service_uri=session_service_uri,
Expand Down Expand Up @@ -1215,10 +1235,10 @@ def cli_api_server(

adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir
"""
logs.setup_adk_logger(getattr(logging, log_level.upper()))

session_service_uri = session_service_uri or session_db_url
artifact_service_uri = artifact_service_uri or artifact_storage_uri
logs.setup_adk_logger(getattr(logging, log_level.upper()))

config = uvicorn.Config(
get_fast_api_app(
agents_dir=agents_dir,
Expand Down
Loading
Loading