diff --git a/contributing/samples/adk_stale_agent/agent.py b/contributing/samples/adk_stale_agent/agent.py index 2535f9cf45..e9fbe49bdf 100644 --- a/contributing/samples/adk_stale_agent/agent.py +++ b/contributing/samples/adk_stale_agent/agent.py @@ -49,6 +49,7 @@ BOT_ALERT_SIGNATURE = ( "**Notification:** The author has updated the issue description" ) +BOT_NAME = "adk-bot" # --- Global Cache --- _MAINTAINERS_CACHE: Optional[List[str]] = None @@ -246,8 +247,9 @@ def _build_history_timeline( if BOT_ALERT_SIGNATURE in c_body: if last_bot_alert_time is None or c_time > last_bot_alert_time: last_bot_alert_time = c_time + continue - if actor and not actor.endswith("[bot]"): + if actor and not actor.endswith("[bot]") and actor != BOT_NAME: # Use edit time if available, otherwise creation time e_time = c.get("lastEditedAt") actual_time = dateutil.parser.isoparse(e_time) if e_time else c_time @@ -263,7 +265,7 @@ def _build_history_timeline( if not e: continue actor = e.get("editor", {}).get("login") - if actor and not actor.endswith("[bot]"): + if actor and not actor.endswith("[bot]") and actor != BOT_NAME: history.append({ "type": "edited_description", "actor": actor, @@ -285,7 +287,7 @@ def _build_history_timeline( label_events.append(time_val) continue - if actor and not actor.endswith("[bot]"): + if actor and not actor.endswith("[bot]") and actor != BOT_NAME: pretty_type = ( "renamed_title" if etype == "RenamedTitleEvent" else "reopened" ) diff --git a/pyproject.toml b/pyproject.toml index af7e1840a4..19abaa3f70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,7 @@ classifiers = [ # List of https://pypi.org/classifiers/ dependencies = [ # go/keep-sorted start "PyYAML>=6.0.2, <7.0.0", # For APIHubToolset. - # TODO: Update aiosqlite version once https://github.com/omnilib/aiosqlite/issues/369 is fixed. - "aiosqlite==0.21.0", # For SQLite database + "aiosqlite>=0.21.0", # For SQLite database "anyio>=4.9.0, <5.0.0", # For MCP Session Manager "authlib>=1.5.1, <2.0.0", # For RestAPI Tool "click>=8.1.8, <9.0.0", # For CLI tools @@ -110,6 +109,7 @@ eval = [ "google-cloud-aiplatform[evaluation]>=1.100.0", "pandas>=2.2.3", "rouge-score>=0.1.2", + "scipy<1.16; python_version<'3.11'", "tabulate>=0.9.0", # go/keep-sorted end ] diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 5d71591466..96d84163c7 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -1657,7 +1657,8 @@ async def process_messages(): for task in done: task.result() except WebSocketDisconnect: - logger.info("Client disconnected during process_messages.") + # Disconnection could happen when receive or send text via websocket + logger.info("Client disconnected during live session.") except Exception as e: logger.exception("Error during live websocket communication: %s", e) traceback.print_exc() diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 131213ec07..b5abaa503b 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -211,75 +211,214 @@ def tear_down_observer(observer: Observer, _: AdkWebServer): **extra_fast_api_args, ) + agents_base_path = (Path.cwd() / agents_dir).resolve() + + def _get_app_root(app_name: str) -> Path: + if app_name in ("", ".", ".."): + raise ValueError(f"Invalid app name: {app_name!r}") + if Path(app_name).name != app_name or "\\" in app_name: + raise ValueError(f"Invalid app name: {app_name!r}") + app_root = (agents_base_path / app_name).resolve() + if not app_root.is_relative_to(agents_base_path): + raise ValueError(f"Invalid app name: {app_name!r}") + return app_root + + def _normalize_relative_path(path: str) -> str: + return path.replace("\\", "/").lstrip("/") + + def _has_parent_reference(path: str) -> bool: + return any(part == ".." for part in path.split("/")) + + def _parse_upload_filename(filename: Optional[str]) -> tuple[str, str]: + if not filename: + raise ValueError("Upload filename is missing.") + filename = _normalize_relative_path(filename) + if "/" not in filename: + raise ValueError(f"Invalid upload filename: {filename!r}") + app_name, rel_path = filename.split("/", 1) + if not app_name or not rel_path: + raise ValueError(f"Invalid upload filename: {filename!r}") + if rel_path.startswith("/"): + raise ValueError(f"Absolute upload path rejected: {filename!r}") + if _has_parent_reference(rel_path): + raise ValueError(f"Path traversal rejected: {filename!r}") + return app_name, rel_path + + def _parse_file_path(file_path: str) -> str: + file_path = _normalize_relative_path(file_path) + if not file_path: + raise ValueError("file_path is missing.") + if file_path.startswith("/"): + raise ValueError(f"Absolute file_path rejected: {file_path!r}") + if _has_parent_reference(file_path): + raise ValueError(f"Path traversal rejected: {file_path!r}") + return file_path + + def _resolve_under_dir(root_dir: Path, rel_path: str) -> Path: + file_path = root_dir / rel_path + resolved_root_dir = root_dir.resolve() + resolved_file_path = file_path.resolve() + if not resolved_file_path.is_relative_to(resolved_root_dir): + raise ValueError(f"Path escapes root_dir: {rel_path!r}") + return file_path + + def _get_tmp_agent_root(app_root: Path, app_name: str) -> Path: + tmp_agent_root = app_root / "tmp" / app_name + resolved_tmp_agent_root = tmp_agent_root.resolve() + if not resolved_tmp_agent_root.is_relative_to(app_root): + raise ValueError(f"Invalid tmp path for app: {app_name!r}") + return tmp_agent_root + + def copy_dir_contents(source_dir: Path, dest_dir: Path) -> None: + dest_dir.mkdir(parents=True, exist_ok=True) + for source_path in source_dir.iterdir(): + if source_path.name == "tmp": + continue + + dest_path = dest_dir / source_path.name + if source_path.is_dir(): + if dest_path.exists() and dest_path.is_file(): + dest_path.unlink() + shutil.copytree(source_path, dest_path, dirs_exist_ok=True) + elif source_path.is_file(): + if dest_path.exists() and dest_path.is_dir(): + shutil.rmtree(dest_path) + shutil.copy2(source_path, dest_path) + + def cleanup_tmp(app_name: str) -> bool: + try: + app_root = _get_app_root(app_name) + except ValueError as exc: + logger.exception("Error in cleanup_tmp: %s", exc) + return False + + try: + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + except ValueError as exc: + logger.exception("Error in cleanup_tmp: %s", exc) + return False + + try: + shutil.rmtree(tmp_agent_root) + except FileNotFoundError: + pass + except OSError as exc: + logger.exception("Error deleting tmp agent root: %s", exc) + return False + + tmp_dir = app_root / "tmp" + resolved_tmp_dir = tmp_dir.resolve() + if not resolved_tmp_dir.is_relative_to(app_root): + logger.error( + "Refusing to delete tmp outside app_root: %s", resolved_tmp_dir + ) + return False + + try: + tmp_dir.rmdir() + except OSError: + pass + + return True + + def ensure_tmp_exists(app_name: str) -> bool: + try: + app_root = _get_app_root(app_name) + except ValueError as exc: + logger.exception("Error in ensure_tmp_exists: %s", exc) + return False + + if not app_root.is_dir(): + return False + + try: + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + except ValueError as exc: + logger.exception("Error in ensure_tmp_exists: %s", exc) + return False + + if tmp_agent_root.exists(): + return True + + try: + tmp_agent_root.mkdir(parents=True, exist_ok=True) + copy_dir_contents(app_root, tmp_agent_root) + except OSError as exc: + logger.exception("Error in ensure_tmp_exists: %s", exc) + return False + + return True + @app.post("/builder/save", response_model_exclude_none=True) async def builder_build( files: list[UploadFile], tmp: Optional[bool] = False ) -> bool: - base_path = Path.cwd() / agents_dir - for file in files: - if not file.filename: - logger.exception("Agent name is missing in the input files") - return False - agent_name, filename = file.filename.split("/") - agent_dir = os.path.join(base_path, agent_name) - try: - # File name format: {app_name}/{agent_name}.yaml - if tmp: - agent_dir = os.path.join(agent_dir, "tmp/" + agent_name) - os.makedirs(agent_dir, exist_ok=True) - file_path = os.path.join(agent_dir, filename) - with open(file_path, "wb") as buffer: + try: + if tmp: + app_names = set() + uploads = [] + for file in files: + app_name, rel_path = _parse_upload_filename(file.filename) + app_names.add(app_name) + uploads.append((rel_path, file)) + + if len(app_names) != 1: + logger.error( + "Exactly one app name is required, found: %s", sorted(app_names) + ) + return False + + app_name = next(iter(app_names)) + app_root = _get_app_root(app_name) + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + tmp_agent_root.mkdir(parents=True, exist_ok=True) + + for rel_path, file in uploads: + destination_path = _resolve_under_dir(tmp_agent_root, rel_path) + destination_path.parent.mkdir(parents=True, exist_ok=True) + with destination_path.open("wb") as buffer: shutil.copyfileobj(file.file, buffer) - else: - source_dir = os.path.join(agent_dir, "tmp/" + agent_name) - destination_dir = agent_dir - for item in os.listdir(source_dir): - source_item = os.path.join(source_dir, item) - destination_item = os.path.join(destination_dir, item) - if os.path.isdir(source_item): - shutil.copytree(source_item, destination_item, dirs_exist_ok=True) - # Check if the item is a file - elif os.path.isfile(source_item): - shutil.copy2(source_item, destination_item) - except Exception as e: - logger.exception("Error in builder_build: %s", e) + return True + + app_names = set() + uploads = [] + for file in files: + app_name, rel_path = _parse_upload_filename(file.filename) + app_names.add(app_name) + uploads.append((rel_path, file)) + + if len(app_names) != 1: + logger.error( + "Exactly one app name is required, found: %s", sorted(app_names) + ) return False - return True + app_name = next(iter(app_names)) + app_root = _get_app_root(app_name) + app_root.mkdir(parents=True, exist_ok=True) + + tmp_agent_root = _get_tmp_agent_root(app_root, app_name) + if tmp_agent_root.is_dir(): + copy_dir_contents(tmp_agent_root, app_root) + + for rel_path, file in uploads: + destination_path = _resolve_under_dir(app_root, rel_path) + destination_path.parent.mkdir(parents=True, exist_ok=True) + with destination_path.open("wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + return cleanup_tmp(app_name) + except ValueError as exc: + logger.exception("Error in builder_build: %s", exc) + return False + except OSError as exc: + logger.exception("Error in builder_build: %s", exc) + return False @app.post("/builder/app/{app_name}/cancel", response_model_exclude_none=True) async def builder_cancel(app_name: str) -> bool: - base_path = Path.cwd() / agents_dir - agent_dir = os.path.join(base_path, app_name) - destination_dir = os.path.join(agent_dir, "tmp/" + app_name) - source_dir = agent_dir - source_items = set(os.listdir(source_dir)) - try: - for item in os.listdir(destination_dir): - if item in source_items: - continue - # If it doesn't exist in the source, delete it from the destination - item_path = os.path.join(destination_dir, item) - if os.path.isdir(item_path): - shutil.rmtree(item_path) - elif os.path.isfile(item_path): - os.remove(item_path) - - for item in os.listdir(source_dir): - source_item = os.path.join(source_dir, item) - destination_item = os.path.join(destination_dir, item) - if item == "tmp" and os.path.isdir(source_item): - continue - if os.path.isdir(source_item): - shutil.copytree(source_item, destination_item, dirs_exist_ok=True) - # Check if the item is a file - elif os.path.isfile(source_item): - shutil.copy2(source_item, destination_item) - except Exception as e: - logger.exception("Error in builder_build: %s", e) - return False - return True + return cleanup_tmp(app_name) @app.get( "/builder/app/{app_name}", @@ -291,34 +430,42 @@ async def get_agent_builder( file_path: Optional[str] = None, tmp: Optional[bool] = False, ): - base_path = Path.cwd() / agents_dir - agent_dir = base_path / app_name + try: + app_root = _get_app_root(app_name) + except ValueError as exc: + logger.exception("Error in get_agent_builder: %s", exc) + return "" + + agent_dir = app_root if tmp: - agent_dir = agent_dir / "tmp" - agent_dir = agent_dir / app_name - if not file_path: - file_name = "root_agent.yaml" - root_file_path = agent_dir / file_name - if not root_file_path.is_file(): + if not ensure_tmp_exists(app_name): return "" - else: - return FileResponse( - path=root_file_path, - media_type="application/x-yaml", - filename="${app_name}.yaml", - headers={"Cache-Control": "no-store"}, - ) + agent_dir = app_root / "tmp" / app_name + + if not file_path: + rel_path = "root_agent.yaml" else: - agent_file_path = agent_dir / file_path - if not agent_file_path.is_file(): + try: + rel_path = _parse_file_path(file_path) + except ValueError as exc: + logger.exception("Error in get_agent_builder: %s", exc) return "" - else: - return FileResponse( - path=agent_file_path, - media_type="application/x-yaml", - filename=file_path, - headers={"Cache-Control": "no-store"}, - ) + + try: + agent_file_path = _resolve_under_dir(agent_dir, rel_path) + except ValueError as exc: + logger.exception("Error in get_agent_builder: %s", exc) + return "" + + if not agent_file_path.is_file(): + return "" + + return FileResponse( + path=agent_file_path, + media_type="application/x-yaml", + filename=file_path or f"{app_name}.yaml", + headers={"Cache-Control": "no-store"}, + ) if a2a: from a2a.server.apps import A2AStarletteApplication diff --git a/src/google/adk/evaluation/eval_metrics.py b/src/google/adk/evaluation/eval_metrics.py index d937d7fe6d..f81059fb9d 100644 --- a/src/google/adk/evaluation/eval_metrics.py +++ b/src/google/adk/evaluation/eval_metrics.py @@ -24,6 +24,7 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field +from pydantic import field_validator from pydantic.json_schema import SkipJsonSchema from typing_extensions import TypeAlias @@ -225,6 +226,17 @@ class MatchType(Enum): ), ) + @field_validator("match_type", mode="before") + @classmethod + def _coerce_match_type(cls, value: object) -> object: + if isinstance(value, cls.MatchType): + return value + if isinstance(value, str): + normalized = value.strip().upper().replace("-", "_").replace(" ", "_") + if normalized in cls.MatchType.__members__: + return cls.MatchType[normalized] + return value + class LlmBackedUserSimulatorCriterion(LlmAsAJudgeCriterion): """Criterion for LLM-backed User Simulator Evaluators.""" diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index ce0df37e39..101b409c13 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -220,13 +220,15 @@ def _rearrange_events_for_latest_function_response( def _is_part_invisible(p: types.Part) -> bool: - """A part is considered invisble if it's a thought, or has no visible content.""" + """Returns whether a part is invisible for LLM context.""" return getattr(p, 'thought', False) or not ( p.text or p.inline_data or p.file_data or p.function_call or p.function_response + or p.executable_code + or p.code_execution_result ) @@ -236,9 +238,8 @@ def _contains_empty_content(event: Event) -> bool: This can happen to the events that only changed session state. When both content and transcriptions are empty, the event will be considered as empty. The content is considered empty if none of its parts contain text, - inline data, file data, function call, or function response. Parts with - only thoughts are also considered empty. - + inline data, file data, function call, function response, executable code, or + code execution result. Parts with only thoughts are also considered empty. Args: event: The event to check. @@ -520,7 +521,7 @@ def _present_other_agent_message(event: Event) -> Optional[Event]: if part.thought: # Exclude thoughts from the context. continue - elif part.text: + elif part.text is not None and part.text.strip(): content.parts.append( types.Part(text=f'[{event.author}] said: {part.text}') ) @@ -543,11 +544,17 @@ def _present_other_agent_message(event: Event) -> Optional[Event]: ) ) ) - # Fallback to the original part for non-text and non-functionCall parts. - else: + elif ( + part.inline_data + or part.file_data + or part.executable_code + or part.code_execution_result + ): content.parts.append(part) + else: + continue - # If no meaningful parts were added (only "For context:" remains), return None + # Return None when only "For context:" remains. if len(content.parts) == 1: return None diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 140473982f..975acc31b8 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -104,6 +104,11 @@ # Providers that require file_id instead of inline file_data _FILE_ID_REQUIRED_PROVIDERS = frozenset({"openai", "azure"}) +_MISSING_TOOL_RESULT_MESSAGE = ( + "Error: Missing tool result (tool execution may have been interrupted " + "before a response was recorded)." +) + def _get_provider_from_model(model: str) -> str: """Extracts the provider name from a LiteLLM model string. @@ -516,6 +521,65 @@ async def _content_to_message_param( ) +def _ensure_tool_results(messages: List[Message]) -> List[Message]: + """Insert placeholder tool messages for missing tool results. + + LiteLLM-backed providers like OpenAI and Anthropic reject histories where an + assistant tool call is not followed by tool responses before the next + non-tool message. This helps recover from interrupted tool execution. + """ + if not messages: + return messages + + healed_messages: List[Message] = [] + pending_tool_call_ids: List[str] = [] + + for message in messages: + role = message.get("role") + if pending_tool_call_ids and role != "tool": + logger.warning( + "Missing tool results for tool_call_id(s): %s", + pending_tool_call_ids, + ) + healed_messages.extend( + ChatCompletionToolMessage( + role="tool", + tool_call_id=tool_call_id, + content=_MISSING_TOOL_RESULT_MESSAGE, + ) + for tool_call_id in pending_tool_call_ids + ) + pending_tool_call_ids = [] + + if role == "assistant": + tool_calls = message.get("tool_calls") or [] + pending_tool_call_ids = [ + tool_call.get("id") for tool_call in tool_calls if tool_call.get("id") + ] + elif role == "tool": + tool_call_id = message.get("tool_call_id") + if tool_call_id in pending_tool_call_ids: + pending_tool_call_ids.remove(tool_call_id) + + healed_messages.append(message) + + if pending_tool_call_ids: + logger.warning( + "Missing tool results for tool_call_id(s): %s", + pending_tool_call_ids, + ) + healed_messages.extend( + ChatCompletionToolMessage( + role="tool", + tool_call_id=tool_call_id, + content=_MISSING_TOOL_RESULT_MESSAGE, + ) + for tool_call_id in pending_tool_call_ids + ) + + return healed_messages + + async def _get_content( parts: Iterable[types.Part], *, @@ -1266,6 +1330,7 @@ async def _get_completion_inputs( content=llm_request.config.system_instruction, ), ) + messages = _ensure_tool_results(messages) # 2. Convert tool declarations tools: Optional[List[Dict]] = None diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 1583026b65..8e0f646de7 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -1736,7 +1736,9 @@ async def before_agent_callback( TraceManager.init_trace(callback_context) TraceManager.push_span(callback_context) await self._log_event( - "AGENT_STARTING", callback_context, raw_content=agent.instruction + "AGENT_STARTING", + callback_context, + raw_content=getattr(agent, "instruction", ""), ) async def after_agent_callback( diff --git a/src/google/adk/plugins/context_filter_plugin.py b/src/google/adk/plugins/context_filter_plugin.py index b778de02ad..8b12f92fc1 100644 --- a/src/google/adk/plugins/context_filter_plugin.py +++ b/src/google/adk/plugins/context_filter_plugin.py @@ -14,11 +14,14 @@ from __future__ import annotations +from collections.abc import Sequence import logging from typing import Callable from typing import List from typing import Optional +from google.genai import types + from ..agents.callback_context import CallbackContext from ..events.event import Event from ..models.llm_request import LlmRequest @@ -28,6 +31,37 @@ logger = logging.getLogger("google_adk." + __name__) +def _adjust_split_index_to_avoid_orphaned_function_responses( + contents: Sequence[types.Content], split_index: int +) -> int: + """Moves `split_index` left until function calls/responses stay paired. + + When truncating context, we must avoid keeping a `function_response` while + dropping its matching preceding `function_call`. + + Args: + contents: Full conversation contents in chronological order. + split_index: Candidate split index (keep `contents[split_index:]`). + + Returns: + A (possibly smaller) split index that preserves call/response pairs. + """ + needed_call_ids = set() + for i in range(len(contents) - 1, -1, -1): + parts = contents[i].parts + if parts: + for part in reversed(parts): + if part.function_response and part.function_response.id: + needed_call_ids.add(part.function_response.id) + if part.function_call and part.function_call.id: + needed_call_ids.discard(part.function_call.id) + + if i <= split_index and not needed_call_ids: + return i + + return 0 + + class ContextFilterPlugin(BasePlugin): """A plugin that filters the LLM context to reduce its size.""" @@ -76,6 +110,12 @@ async def before_model_callback( start_index -= 1 split_index = start_index break + # Adjust split_index to avoid orphaned function_responses. + split_index = ( + _adjust_split_index_to_avoid_orphaned_function_responses( + contents, split_index + ) + ) contents = contents[split_index:] if self._custom_filter: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 1773729719..8723ea2e38 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -84,6 +84,19 @@ def _has_non_empty_transcription_text(transcription) -> bool: ) +def _apply_run_config_custom_metadata( + event: Event, run_config: RunConfig | None +) -> None: + """Merges run-level custom metadata into the event, if present.""" + if not run_config or not run_config.custom_metadata: + return + + event.custom_metadata = { + **run_config.custom_metadata, + **(event.custom_metadata or {}), + } + + class Runner: """The Runner class is used to run agents. @@ -695,6 +708,9 @@ async def _exec_with_plugin( author='model', content=early_exit_result, ) + _apply_run_config_custom_metadata( + early_exit_event, invocation_context.run_config + ) if self._should_append_event(early_exit_event, is_live_call): await self.session_service.append_event( session=session, @@ -721,6 +737,9 @@ async def _exec_with_plugin( async with Aclosing(execute_fn(invocation_context)) as agen: async for event in agen: + _apply_run_config_custom_metadata( + event, invocation_context.run_config + ) if is_live_call: if event.partial and _is_transcription(event): is_transcribing = True @@ -775,7 +794,13 @@ async def _exec_with_plugin( modified_event = await plugin_manager.run_on_event_callback( invocation_context=invocation_context, event=event ) - yield (modified_event if modified_event else event) + if modified_event: + _apply_run_config_custom_metadata( + modified_event, invocation_context.run_config + ) + yield modified_event + else: + yield event # Step 4: Run the after_run callbacks to perform global cleanup tasks or # finalizing logs and metrics data. @@ -846,6 +871,7 @@ async def _append_new_message_to_session( author='user', content=new_message, ) + _apply_run_config_custom_metadata(event, invocation_context.run_config) # If new_message is a function response, find the matching function call # and use its branch as the new event's branch. if function_call := invocation_context._find_matching_function_call(event): diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 3cc9bb6a68..c9762ad0c9 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -25,12 +25,14 @@ from sqlalchemy import event from sqlalchemy import select from sqlalchemy import text +from sqlalchemy.engine import make_url from sqlalchemy.exc import ArgumentError from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.inspection import inspect +from sqlalchemy.pool import StaticPool from typing_extensions import override from tzlocal import get_localzone @@ -103,7 +105,15 @@ def __init__(self, db_url: str, **kwargs: Any): # 2. Create all tables based on schema # 3. Initialize all properties try: - db_engine = create_async_engine(db_url, **kwargs) + engine_kwargs = dict(kwargs) + url = make_url(db_url) + if url.get_backend_name() == "sqlite" and url.database == ":memory:": + engine_kwargs.setdefault("poolclass", StaticPool) + connect_args = dict(engine_kwargs.get("connect_args", {})) + connect_args.setdefault("check_same_thread", False) + engine_kwargs["connect_args"] = connect_args + + db_engine = create_async_engine(db_url, **engine_kwargs) if db_engine.dialect.name == "sqlite": # Set sqlite pragma to enable foreign keys constraints event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma) @@ -477,3 +487,15 @@ async def append_event(self, session: Session, event: Event) -> Event: # Also update the in-memory session await super().append_event(session=session, event=event) return event + + async def close(self) -> None: + """Disposes the SQLAlchemy engine and closes pooled connections.""" + await self.db_engine.dispose() + + async def __aenter__(self) -> DatabaseSessionService: + """Enters the async context manager and returns this service.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Exits the async context manager and closes the service.""" + await self.close() diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index 16a11218d7..a69c29243a 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -310,7 +310,11 @@ def to_event(self) -> Event: branch=self.branch, # This is needed as previous ADK version pickled actions might not have # value defined in the current version of the EventActions model. - actions=EventActions().model_copy(update=self.actions.model_dump()), + actions=( + EventActions.model_validate(self.actions.model_dump()) + if self.actions + else EventActions() + ), timestamp=self.timestamp.timestamp(), long_running_tool_ids=self.long_running_tool_ids, partial=self.partial, diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 5bcd734e70..1f03835ba7 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -229,7 +229,7 @@ def execute_sql( >>> execute_sql("my_project", ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [ @@ -253,7 +253,7 @@ def execute_sql( >>> execute_sql( ... "my_project", ... "SELECT island FROM " - ... "bigquery-public-data.ml_datasets.penguins", + ... "`bigquery-public-data`.`ml_datasets`.`penguins`", ... dry_run=True ... ) { @@ -269,7 +269,7 @@ def execute_sql( "tableId": "anon..." }, "priority": "INTERACTIVE", - "query": "SELECT island FROM bigquery-public-data.ml_datasets.penguins", + "query": "SELECT island FROM `bigquery-public-data`.`ml_datasets`.`penguins`", "useLegacySql": False, "writeDisposition": "WRITE_TRUNCATE" } @@ -319,7 +319,7 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: >>> execute_sql("my_project", ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [ @@ -343,7 +343,7 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: >>> execute_sql( ... "my_project", ... "SELECT island FROM " - ... "bigquery-public-data.ml_datasets.penguins", + ... "`bigquery-public-data`.`ml_datasets`.`penguins`", ... dry_run=True ... ) { @@ -359,7 +359,7 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: "tableId": "anon..." }, "priority": "INTERACTIVE", - "query": "SELECT island FROM bigquery-public-data.ml_datasets.penguins", + "query": "SELECT island FROM `bigquery-public-data`.`ml_datasets`.`penguins`", "useLegacySql": False, "writeDisposition": "WRITE_TRUNCATE" } @@ -374,7 +374,7 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Create a table with schema prescribed: >>> execute_sql("my_project", - ... "CREATE TABLE my_project.my_dataset.my_table " + ... "CREATE TABLE `my_project`.`my_dataset`.`my_table` " ... "(island STRING, population INT64)") { "status": "SUCCESS", @@ -384,7 +384,7 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Insert data into an existing table: >>> execute_sql("my_project", - ... "INSERT INTO my_project.my_dataset.my_table (island, population) " + ... "INSERT INTO `my_project`.`my_dataset`.`my_table` (island, population) " ... "VALUES ('Dream', 124), ('Biscoe', 168)") { "status": "SUCCESS", @@ -394,9 +394,9 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Create a table from the result of a query: >>> execute_sql("my_project", - ... "CREATE TABLE my_project.my_dataset.my_table AS " + ... "CREATE TABLE `my_project`.`my_dataset`.`my_table` AS " ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [] @@ -405,7 +405,7 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Delete a table: >>> execute_sql("my_project", - ... "DROP TABLE my_project.my_dataset.my_table") + ... "DROP TABLE `my_project`.`my_dataset`.`my_table`") { "status": "SUCCESS", "rows": [] @@ -414,8 +414,8 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Copy a table to another table: >>> execute_sql("my_project", - ... "CREATE TABLE my_project.my_dataset.my_table_clone " - ... "CLONE my_project.my_dataset.my_table") + ... "CREATE TABLE `my_project`.`my_dataset`.`my_table_clone` " + ... "CLONE `my_project`.`my_dataset`.`my_table`") { "status": "SUCCESS", "rows": [] @@ -425,8 +425,8 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: table: >>> execute_sql("my_project", - ... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot " - ... "CLONE my_project.my_dataset.my_table") + ... "CREATE SNAPSHOT TABLE `my_project`.`my_dataset`.`my_table_snapshot` " + ... "CLONE `my_project`.`my_dataset`.`my_table`") { "status": "SUCCESS", "rows": [] @@ -435,9 +435,9 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Create a BigQuery ML linear regression model: >>> execute_sql("my_project", - ... "CREATE MODEL `my_dataset.my_model` " + ... "CREATE MODEL `my_dataset`.`my_model` " ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS " - ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` " + ... "SELECT * FROM `bigquery-public-data`.`ml_datasets`.`penguins` " ... "WHERE body_mass_g IS NOT NULL") { "status": "SUCCESS", @@ -447,7 +447,7 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Evaluate BigQuery ML model: >>> execute_sql("my_project", - ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`)") + ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset`.`my_model`)") { "status": "SUCCESS", "rows": [{'mean_absolute_error': 227.01223667447218, @@ -461,8 +461,8 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Evaluate BigQuery ML model on custom data: >>> execute_sql("my_project", - ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`, " - ... "(SELECT * FROM `my_dataset.my_table`))") + ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset`.`my_model`, " + ... "(SELECT * FROM `my_dataset`.`my_table`))") { "status": "SUCCESS", "rows": [{'mean_absolute_error': 227.01223667447218, @@ -476,8 +476,8 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Predict using BigQuery ML model: >>> execute_sql("my_project", - ... "SELECT * FROM ML.PREDICT(MODEL `my_dataset.my_model`, " - ... "(SELECT * FROM `my_dataset.my_table`))") + ... "SELECT * FROM ML.PREDICT(MODEL `my_dataset`.`my_model`, " + ... "(SELECT * FROM `my_dataset`.`my_table`))") { "status": "SUCCESS", "rows": [ @@ -494,7 +494,7 @@ def _execute_sql_write_mode(*args, **kwargs) -> dict: Delete a BigQuery ML model: - >>> execute_sql("my_project", "DROP MODEL `my_dataset.my_model`") + >>> execute_sql("my_project", "DROP MODEL `my_dataset`.`my_model`") { "status": "SUCCESS", "rows": [] @@ -539,7 +539,7 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: >>> execute_sql("my_project", ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [ @@ -563,7 +563,7 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: >>> execute_sql( ... "my_project", ... "SELECT island FROM " - ... "bigquery-public-data.ml_datasets.penguins", + ... "`bigquery-public-data`.`ml_datasets`.`penguins`", ... dry_run=True ... ) { @@ -579,7 +579,7 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: "tableId": "anon..." }, "priority": "INTERACTIVE", - "query": "SELECT island FROM bigquery-public-data.ml_datasets.penguins", + "query": "SELECT island FROM `bigquery-public-data`.`ml_datasets`.`penguins`", "useLegacySql": False, "writeDisposition": "WRITE_TRUNCATE" } @@ -594,7 +594,7 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Create a temporary table with schema prescribed: >>> execute_sql("my_project", - ... "CREATE TEMP TABLE my_table (island STRING, population INT64)") + ... "CREATE TEMP TABLE `my_table` (island STRING, population INT64)") { "status": "SUCCESS", "rows": [] @@ -603,7 +603,7 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Insert data into an existing temporary table: >>> execute_sql("my_project", - ... "INSERT INTO my_table (island, population) " + ... "INSERT INTO `my_table` (island, population) " ... "VALUES ('Dream', 124), ('Biscoe', 168)") { "status": "SUCCESS", @@ -613,9 +613,9 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Create a temporary table from the result of a query: >>> execute_sql("my_project", - ... "CREATE TEMP TABLE my_table AS " + ... "CREATE TEMP TABLE `my_table` AS " ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [] @@ -623,7 +623,7 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Delete a temporary table: - >>> execute_sql("my_project", "DROP TABLE my_table") + >>> execute_sql("my_project", "DROP TABLE `my_table`") { "status": "SUCCESS", "rows": [] @@ -632,7 +632,7 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Copy a temporary table to another temporary table: >>> execute_sql("my_project", - ... "CREATE TEMP TABLE my_table_clone CLONE my_table") + ... "CREATE TEMP TABLE `my_table_clone` CLONE `my_table`") { "status": "SUCCESS", "rows": [] @@ -641,9 +641,9 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Create a temporary BigQuery ML linear regression model: >>> execute_sql("my_project", - ... "CREATE TEMP MODEL my_model " + ... "CREATE TEMP MODEL `my_model` " ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS" - ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` " + ... "SELECT * FROM `bigquery-public-data`.`ml_datasets`.`penguins` " ... "WHERE body_mass_g IS NOT NULL") { "status": "SUCCESS", @@ -652,7 +652,7 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Evaluate BigQuery ML model: - >>> execute_sql("my_project", "SELECT * FROM ML.EVALUATE(MODEL my_model)") + >>> execute_sql("my_project", "SELECT * FROM ML.EVALUATE(MODEL `my_model`)") { "status": "SUCCESS", "rows": [{'mean_absolute_error': 227.01223667447218, @@ -666,8 +666,8 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Evaluate BigQuery ML model on custom data: >>> execute_sql("my_project", - ... "SELECT * FROM ML.EVALUATE(MODEL my_model, " - ... "(SELECT * FROM `my_dataset.my_table`))") + ... "SELECT * FROM ML.EVALUATE(MODEL `my_model`, " + ... "(SELECT * FROM `my_dataset`.`my_table`))") { "status": "SUCCESS", "rows": [{'mean_absolute_error': 227.01223667447218, @@ -681,8 +681,8 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Predict using BigQuery ML model: >>> execute_sql("my_project", - ... "SELECT * FROM ML.PREDICT(MODEL my_model, " - ... "(SELECT * FROM `my_dataset.my_table`))") + ... "SELECT * FROM ML.PREDICT(MODEL `my_model`, " + ... "(SELECT * FROM `my_dataset`.`my_table`))") { "status": "SUCCESS", "rows": [ @@ -699,7 +699,7 @@ def _execute_sql_protected_write_mode(*args, **kwargs) -> dict: Delete a BigQuery ML model: - >>> execute_sql("my_project", "DROP MODEL my_model") + >>> execute_sql("my_project", "DROP MODEL `my_model`") { "status": "SUCCESS", "rows": [] diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index c9c4c2ae66..722730e4fa 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -25,6 +25,8 @@ from typing import Any from typing import Dict from typing import Optional +from typing import Protocol +from typing import runtime_checkable from typing import TextIO from typing import Union @@ -33,8 +35,11 @@ from mcp import StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import create_mcp_http_client +from mcp.client.streamable_http import McpHttpClientFactory from mcp.client.streamable_http import streamablehttp_client from pydantic import BaseModel +from pydantic import ConfigDict logger = logging.getLogger('google_adk.' + __name__) @@ -73,6 +78,11 @@ class SseConnectionParams(BaseModel): sse_read_timeout: float = 60 * 5.0 +@runtime_checkable +class CheckableMcpHttpClientFactory(McpHttpClientFactory, Protocol): + pass + + class StreamableHTTPConnectionParams(BaseModel): """Parameters for the MCP Streamable HTTP connection. @@ -88,13 +98,18 @@ class StreamableHTTPConnectionParams(BaseModel): Streamable HTTP server. terminate_on_close: Whether to terminate the MCP Streamable HTTP server when the connection is closed. + httpx_client_factory: Factory function to create a custom HTTPX client. If + not provided, a default factory will be used. """ + model_config = ConfigDict(arbitrary_types_allowed=True) + url: str headers: dict[str, Any] | None = None timeout: float = 5.0 sse_read_timeout: float = 60 * 5.0 terminate_on_close: bool = True + httpx_client_factory: CheckableMcpHttpClientFactory = create_mcp_http_client def retry_on_errors(func): @@ -275,6 +290,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): seconds=self._connection_params.sse_read_timeout ), terminate_on_close=self._connection_params.terminate_on_close, + httpx_client_factory=self._connection_params.httpx_client_factory, ) else: raise ValueError( diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 263e47043a..7a5627ffc3 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -17,6 +17,7 @@ import logging import os from pathlib import Path +import signal import sys import tempfile import time @@ -31,6 +32,7 @@ from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App from google.adk.artifacts.base_artifact_service import ArtifactVersion +from google.adk.cli import fast_api as fast_api_module from google.adk.cli.fast_api import get_fast_api_app from google.adk.errors.input_validation_error import InputValidationError from google.adk.evaluation.eval_case import EvalCase @@ -414,29 +416,41 @@ def test_app( # Patch multiple services and signal handlers with ( - patch("signal.signal", return_value=None), - patch( - "google.adk.cli.fast_api.create_session_service_from_options", + patch.object(signal, "signal", autospec=True, return_value=None), + patch.object( + fast_api_module, + "create_session_service_from_options", + autospec=True, return_value=mock_session_service, ), - patch( - "google.adk.cli.fast_api.create_artifact_service_from_options", + patch.object( + fast_api_module, + "create_artifact_service_from_options", + autospec=True, return_value=mock_artifact_service, ), - patch( - "google.adk.cli.fast_api.create_memory_service_from_options", + patch.object( + fast_api_module, + "create_memory_service_from_options", + autospec=True, return_value=mock_memory_service, ), - patch( - "google.adk.cli.fast_api.AgentLoader", + patch.object( + fast_api_module, + "AgentLoader", + autospec=True, return_value=mock_agent_loader, ), - patch( - "google.adk.cli.fast_api.LocalEvalSetsManager", + patch.object( + fast_api_module, + "LocalEvalSetsManager", + autospec=True, return_value=mock_eval_sets_manager, ), - patch( - "google.adk.cli.fast_api.LocalEvalSetResultsManager", + patch.object( + fast_api_module, + "LocalEvalSetResultsManager", + autospec=True, return_value=mock_eval_set_results_manager, ), ): @@ -459,6 +473,70 @@ def test_app( return client +@pytest.fixture +def builder_test_client( + tmp_path, + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """Return a TestClient rooted in a temporary agents directory.""" + with ( + patch.object(signal, "signal", autospec=True, return_value=None), + patch.object( + fast_api_module, + "create_session_service_from_options", + autospec=True, + return_value=mock_session_service, + ), + patch.object( + fast_api_module, + "create_artifact_service_from_options", + autospec=True, + return_value=mock_artifact_service, + ), + patch.object( + fast_api_module, + "create_memory_service_from_options", + autospec=True, + return_value=mock_memory_service, + ), + patch.object( + fast_api_module, + "AgentLoader", + autospec=True, + return_value=mock_agent_loader, + ), + patch.object( + fast_api_module, + "LocalEvalSetsManager", + autospec=True, + return_value=mock_eval_sets_manager, + ), + patch.object( + fast_api_module, + "LocalEvalSetResultsManager", + autospec=True, + return_value=mock_eval_set_results_manager, + ), + ): + app = get_fast_api_app( + agents_dir=str(tmp_path), + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=False, + host="127.0.0.1", + port=8000, + ) + return TestClient(app) + + @pytest.fixture async def create_test_session( test_app, test_session_info, mock_session_service @@ -1175,5 +1253,103 @@ def test_patch_memory(test_app, create_test_session, mock_memory_service): logger.info("Add session to memory test completed successfully") +def test_builder_final_save_preserves_tools_and_cleans_tmp( + builder_test_client, tmp_path +): + files = [ + ("files", ("app/__init__.py", b"from . import agent\n", "text/plain")), + ("files", ("app/tools.py", b"def tool():\n return 1\n", "text/plain")), + ( + "files", + ("app/root_agent.yaml", b"name: app\n", "application/x-yaml"), + ), + ] + response = builder_test_client.post("/builder/save?tmp=true", files=files) + assert response.status_code == 200 + assert response.json() is True + + response = builder_test_client.post( + "/builder/save", + files=[( + "files", + ( + "app/root_agent.yaml", + b"name: app_updated\n", + "application/x-yaml", + ), + )], + ) + assert response.status_code == 200 + assert response.json() is True + + assert (tmp_path / "app" / "tools.py").is_file() + assert not (tmp_path / "app" / "tmp" / "app").exists() + tmp_dir = tmp_path / "app" / "tmp" + assert not tmp_dir.exists() or not any(tmp_dir.iterdir()) + + +def test_builder_cancel_deletes_tmp_idempotent(builder_test_client, tmp_path): + tmp_agent_root = tmp_path / "app" / "tmp" / "app" + tmp_agent_root.mkdir(parents=True, exist_ok=True) + (tmp_agent_root / "root_agent.yaml").write_text("name: app\n") + + response = builder_test_client.post("/builder/app/app/cancel") + assert response.status_code == 200 + assert response.json() is True + assert not (tmp_path / "app" / "tmp").exists() + + response = builder_test_client.post("/builder/app/app/cancel") + assert response.status_code == 200 + assert response.json() is True + assert not (tmp_path / "app" / "tmp").exists() + + +def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): + app_root = tmp_path / "app" + app_root.mkdir(parents=True, exist_ok=True) + (app_root / "root_agent.yaml").write_text("name: app\n") + nested_dir = app_root / "nested" + nested_dir.mkdir(parents=True, exist_ok=True) + (nested_dir / "nested.yaml").write_text("nested: true\n") + + assert not (app_root / "tmp").exists() + response = builder_test_client.get("/builder/app/app?tmp=true") + assert response.status_code == 200 + assert response.text == "name: app\n" + + tmp_agent_root = app_root / "tmp" / "app" + assert (tmp_agent_root / "root_agent.yaml").is_file() + assert (tmp_agent_root / "nested" / "nested.yaml").is_file() + + response = builder_test_client.get( + "/builder/app/app?tmp=true&file_path=nested/nested.yaml" + ) + assert response.status_code == 200 + assert response.text == "nested: true\n" + + +def test_builder_get_tmp_true_missing_app_returns_empty( + builder_test_client, tmp_path +): + response = builder_test_client.get("/builder/app/missing?tmp=true") + assert response.status_code == 200 + assert response.text == "" + assert not (tmp_path / "missing").exists() + + +def test_builder_save_rejects_traversal(builder_test_client, tmp_path): + response = builder_test_client.post( + "/builder/save?tmp=true", + files=[( + "files", + ("app/../escape.yaml", b"nope\n", "application/x-yaml"), + )], + ) + assert response.status_code == 200 + assert response.json() is False + assert not (tmp_path / "escape.yaml").exists() + assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists() + + if __name__ == "__main__": pytest.main(["-xvs", __file__]) diff --git a/tests/unittests/evaluation/test_trajectory_evaluator.py b/tests/unittests/evaluation/test_trajectory_evaluator.py index 5edbe06807..47854ca6a5 100644 --- a/tests/unittests/evaluation/test_trajectory_evaluator.py +++ b/tests/unittests/evaluation/test_trajectory_evaluator.py @@ -23,6 +23,7 @@ from google.adk.evaluation.evaluator import EvalStatus from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator from google.genai import types as genai_types +from pydantic import ValidationError import pytest _USER_CONTENT = genai_types.Content( @@ -30,6 +31,71 @@ ) +def test_tool_trajectory_criterion_accepts_string_match_type(): + criterion = ToolTrajectoryCriterion(threshold=0.5, match_type="in_order") + assert criterion.match_type == ToolTrajectoryCriterion.MatchType.IN_ORDER + + +@pytest.mark.parametrize( + ("match_type", "expected"), + [ + ("exact", ToolTrajectoryCriterion.MatchType.EXACT), + ("EXACT", ToolTrajectoryCriterion.MatchType.EXACT), + (" exact ", ToolTrajectoryCriterion.MatchType.EXACT), + ("in order", ToolTrajectoryCriterion.MatchType.IN_ORDER), + ("IN ORDER", ToolTrajectoryCriterion.MatchType.IN_ORDER), + ("In OrDeR", ToolTrajectoryCriterion.MatchType.IN_ORDER), + ("in-order", ToolTrajectoryCriterion.MatchType.IN_ORDER), + ("IN-ORDER", ToolTrajectoryCriterion.MatchType.IN_ORDER), + ("in_order", ToolTrajectoryCriterion.MatchType.IN_ORDER), + ("any order", ToolTrajectoryCriterion.MatchType.ANY_ORDER), + ("ANY ORDER", ToolTrajectoryCriterion.MatchType.ANY_ORDER), + ("any-order", ToolTrajectoryCriterion.MatchType.ANY_ORDER), + ("ANY-ORDER", ToolTrajectoryCriterion.MatchType.ANY_ORDER), + ("any_order", ToolTrajectoryCriterion.MatchType.ANY_ORDER), + ], +) +def test_tool_trajectory_criterion_normalizes_string_match_type( + match_type: str, expected: ToolTrajectoryCriterion.MatchType +): + criterion = ToolTrajectoryCriterion(threshold=0.5, match_type=match_type) + assert criterion.match_type == expected + + +def test_tool_trajectory_criterion_rejects_unknown_string_match_type(): + with pytest.raises(ValidationError): + ToolTrajectoryCriterion(threshold=0.5, match_type="random string") + + +def test_trajectory_evaluator_accepts_string_match_type_from_eval_metric_dict(): + eval_metric = EvalMetric( + threshold=0.5, + metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value, + criterion={ + "threshold": 0.5, + "match_type": "ANY_ORDER", + }, + ) + evaluator = TrajectoryEvaluator(eval_metric=eval_metric) + + tool_call1 = genai_types.FunctionCall(name="test_func1", args={}) + tool_call2 = genai_types.FunctionCall(name="test_func2", args={}) + + actual_invocation = Invocation( + user_content=_USER_CONTENT, + intermediate_data=IntermediateData(tool_uses=[tool_call1, tool_call2]), + ) + expected_invocation = Invocation( + user_content=_USER_CONTENT, + intermediate_data=IntermediateData(tool_uses=[tool_call2, tool_call1]), + ) + + result = evaluator.evaluate_invocations( + [actual_invocation], [expected_invocation] + ) + assert result.overall_score == 1.0 + + @pytest.fixture def evaluator() -> TrajectoryEvaluator: """Returns a TrajectoryEvaluator.""" diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py index bafaebed39..dab0639413 100644 --- a/tests/unittests/flows/llm_flows/test_contents.py +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -572,6 +572,38 @@ async def test_events_with_empty_content_are_skipped(): role="user", ), ), + # Event with content that has executable code part + Event( + invocation_id="inv10", + author="test_agent", + content=types.Content( + parts=[ + types.Part( + executable_code=types.ExecutableCode( + code="print('hello')", + language="PYTHON", + ) + ) + ], + role="model", + ), + ), + # Event with content that has code execution result part + Event( + invocation_id="inv11", + author="test_agent", + content=types.Content( + parts=[ + types.Part( + code_execution_result=types.CodeExecutionResult( + outcome="OUTCOME_OK", + output="hello", + ) + ) + ], + role="model", + ), + ), ] invocation_context.session.events = events @@ -608,4 +640,153 @@ async def test_events_with_empty_content_are_skipped(): parts=[types.Part(text=""), types.Part(text="Mixed content")], role="user", ), + types.Content( + parts=[ + types.Part( + executable_code=types.ExecutableCode( + code="print('hello')", + language="PYTHON", + ) + ) + ], + role="model", + ), + types.Content( + parts=[ + types.Part( + code_execution_result=types.CodeExecutionResult( + outcome="OUTCOME_OK", + output="hello", + ) + ) + ], + role="model", + ), ] + + +@pytest.mark.asyncio +async def test_code_execution_result_events_are_not_skipped(): + """Test that events with code execution result are not skipped. + + This is a regression test for the endless loop bug where code executor + outputs were not passed to the LLM because the events were incorrectly + filtered as empty. + """ + agent = Agent(model="gemini-2.5-flash", name="test_agent") + llm_request = LlmRequest(model="gemini-2.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + events = [ + Event( + invocation_id="inv1", + author="user", + content=types.UserContent("Write code to calculate factorial"), + ), + # Model generates code + Event( + invocation_id="inv2", + author="test_agent", + content=types.Content( + parts=[ + types.Part(text="Here's the code:"), + types.Part( + executable_code=types.ExecutableCode( + code=( + "def factorial(n):\n return 1 if n <= 1 else n *" + " factorial(n-1)\nprint(factorial(5))" + ), + language="PYTHON", + ) + ), + ], + role="model", + ), + ), + # Code execution result + Event( + invocation_id="inv3", + author="test_agent", + content=types.Content( + parts=[ + types.Part( + code_execution_result=types.CodeExecutionResult( + outcome="OUTCOME_OK", + output="120", + ) + ) + ], + role="model", + ), + ), + ] + invocation_context.session.events = events + + # Process the request + async for _ in contents.request_processor.run_async( + invocation_context, llm_request + ): + pass + + # Verify all three events are included, especially the code execution result + assert len(llm_request.contents) == 3 + assert llm_request.contents[0] == types.UserContent( + "Write code to calculate factorial" + ) + # Second event has executable code + assert llm_request.contents[1].parts[1].executable_code is not None + # Third event has code execution result - this was the bug! + assert llm_request.contents[2].parts[0].code_execution_result is not None + assert llm_request.contents[2].parts[0].code_execution_result.output == "120" + + +@pytest.mark.asyncio +async def test_code_execution_result_not_in_first_part_is_not_skipped(): + """Test that code execution results aren't skipped. + + This covers results that appear in a non-first part. + """ + agent = Agent(model="gemini-2.5-flash", name="test_agent") + llm_request = LlmRequest(model="gemini-2.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + events = [ + Event( + invocation_id="inv1", + author="user", + content=types.UserContent("Run some code."), + ), + Event( + invocation_id="inv2", + author="test_agent", + content=types.Content( + parts=[ + types.Part(text=""), + types.Part( + code_execution_result=types.CodeExecutionResult( + outcome="OUTCOME_OK", + output="42", + ) + ), + ], + role="model", + ), + ), + ] + invocation_context.session.events = events + + async for _ in contents.request_processor.run_async( + invocation_context, llm_request + ): + pass + + assert len(llm_request.contents) == 2 + assert any( + part.code_execution_result is not None + and part.code_execution_result.output == "42" + for part in llm_request.contents[1].parts + ) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 4cf0329aa0..ca36966c1e 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -32,6 +32,7 @@ from google.adk.models.lite_llm import _get_content from google.adk.models.lite_llm import _get_provider_from_model from google.adk.models.lite_llm import _message_to_generate_content_response +from google.adk.models.lite_llm import _MISSING_TOOL_RESULT_MESSAGE from google.adk.models.lite_llm import _model_response_to_chunk from google.adk.models.lite_llm import _model_response_to_generate_content_response from google.adk.models.lite_llm import _parse_tool_calls_from_text @@ -470,6 +471,43 @@ async def test_get_completion_inputs_uses_passed_model_for_gemini_format(): assert "response_schema" in response_format +@pytest.mark.asyncio +async def test_get_completion_inputs_inserts_missing_tool_results(): + user_content = types.Content( + role="user", parts=[types.Part.from_text(text="Hi")] + ) + assistant_content = types.Content( + role="assistant", + parts=[ + types.Part.from_text(text="Calling tool."), + types.Part.from_function_call( + name="get_weather", args={"location": "Seoul"} + ), + ], + ) + assistant_content.parts[1].function_call.id = "tool_call_1" + followup_user = types.Content( + role="user", parts=[types.Part.from_text(text="Next question.")] + ) + + llm_request = LlmRequest( + contents=[user_content, assistant_content, followup_user] + ) + messages, _, _, _ = await _get_completion_inputs( + llm_request, model="openai/gpt-4o" + ) + + assert [message["role"] for message in messages] == [ + "user", + "assistant", + "tool", + "user", + ] + tool_message = messages[2] + assert tool_message["tool_call_id"] == "tool_call_1" + assert tool_message["content"] == _MISSING_TOOL_RESULT_MESSAGE + + def test_schema_to_dict_filters_none_enum_values(): # Use model_construct to bypass strict enum validation. top_level_schema = types.Schema.model_construct( diff --git a/tests/unittests/plugins/test_context_filtering_plugin.py b/tests/unittests/plugins/test_context_filtering_plugin.py index f9c8222ea3..de72b32bf4 100644 --- a/tests/unittests/plugins/test_context_filtering_plugin.py +++ b/tests/unittests/plugins/test_context_filtering_plugin.py @@ -183,3 +183,111 @@ def faulty_filter(contents): ) assert llm_request.contents == original_contents + + +def _create_function_call_content(name: str, call_id: str) -> types.Content: + """Creates a model content with a function call.""" + return types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall(id=call_id, name=name, args={}) + ) + ], + role="model", + ) + + +def _create_function_response_content(name: str, call_id: str) -> types.Content: + """Creates a user content with a function response.""" + return types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + id=call_id, name=name, response={"result": "ok"} + ) + ) + ], + role="user", + ) + + +@pytest.mark.asyncio +async def test_filter_preserves_function_call_response_pairs(): + """Tests that function_call and function_response pairs are kept together. + + This tests the fix for issue #4027 where filtering could create orphaned + function_response messages without their corresponding function_call. + """ + plugin = ContextFilterPlugin(num_invocations_to_keep=2) + + # Simulate conversation from issue #4027: + # user -> model -> user -> model(function_call) -> user(function_response) + # -> model -> user -> model(function_call) -> user(function_response) + contents = [ + _create_content("user", "Hello"), + _create_content("model", "Hi there!"), + _create_content("user", "I want to know about X"), + _create_function_call_content("knowledge_base", "call_1"), + _create_function_response_content("knowledge_base", "call_1"), + _create_content("model", "I found some information..."), + _create_content("user", "can you explain more about Y"), + _create_function_call_content("knowledge_base", "call_2"), + _create_function_response_content("knowledge_base", "call_2"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # Verify function_call for call_1 is included (not orphaned function_response) + call_ids_present = set() + response_ids_present = set() + for content in llm_request.contents: + if content.parts: + for part in content.parts: + if part.function_call and part.function_call.id: + call_ids_present.add(part.function_call.id) + if part.function_response and part.function_response.id: + response_ids_present.add(part.function_response.id) + + # Every function_response should have a matching function_call + assert response_ids_present.issubset(call_ids_present), ( + "Orphaned function_responses found. " + f"Responses: {response_ids_present}, Calls: {call_ids_present}" + ) + + +@pytest.mark.asyncio +async def test_filter_with_nested_function_calls(): + """Tests filtering with multiple nested function call sequences.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=1) + + contents = [ + _create_content("user", "Hello"), + _create_content("model", "Hi!"), + _create_content("user", "Do task"), + _create_function_call_content("tool_a", "call_a"), + _create_function_response_content("tool_a", "call_a"), + _create_function_call_content("tool_b", "call_b"), + _create_function_response_content("tool_b", "call_b"), + _create_content("model", "Done with tasks"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # Verify no orphaned function_responses + call_ids = set() + response_ids = set() + for content in llm_request.contents: + if content.parts: + for part in content.parts: + if part.function_call and part.function_call.id: + call_ids.add(part.function_call.id) + if part.function_response and part.function_response.id: + response_ids.add(part.function_response.id) + + assert response_ids.issubset(call_ids) diff --git a/tests/unittests/sessions/migration/test_database_schema.py b/tests/unittests/sessions/migration/test_database_schema.py index 4fc0d03d96..239da2f1e2 100644 --- a/tests/unittests/sessions/migration/test_database_schema.py +++ b/tests/unittests/sessions/migration/test_database_schema.py @@ -29,17 +29,20 @@ async def create_v0_db(db_path): await engine.dispose() +# Use async context managers so DatabaseSessionService always closes. + + @pytest.mark.asyncio async def test_new_db_uses_latest_schema(tmp_path): db_path = tmp_path / 'new_db.db' db_url = f'sqlite+aiosqlite:///{db_path}' - session_service = DatabaseSessionService(db_url) - assert session_service._db_schema_version is None - await session_service.create_session(app_name='my_app', user_id='test_user') - assert ( - session_service._db_schema_version - == _schema_check_utils.LATEST_SCHEMA_VERSION - ) + async with DatabaseSessionService(db_url) as session_service: + assert session_service._db_schema_version is None + await session_service.create_session(app_name='my_app', user_id='test_user') + assert ( + session_service._db_schema_version + == _schema_check_utils.LATEST_SCHEMA_VERSION + ) # Verify metadata table engine = create_async_engine(db_url) @@ -71,21 +74,20 @@ async def test_existing_v0_db_uses_v0_schema(tmp_path): db_path = tmp_path / 'v0_db.db' await create_v0_db(db_path) db_url = f'sqlite+aiosqlite:///{db_path}' - session_service = DatabaseSessionService(db_url) - - assert session_service._db_schema_version is None - await session_service.create_session( - app_name='my_app', user_id='test_user', session_id='s1' - ) - assert ( - session_service._db_schema_version - == _schema_check_utils.SCHEMA_VERSION_0_PICKLE - ) - - session = await session_service.get_session( - app_name='my_app', user_id='test_user', session_id='s1' - ) - assert session.id == 's1' + async with DatabaseSessionService(db_url) as session_service: + assert session_service._db_schema_version is None + await session_service.create_session( + app_name='my_app', user_id='test_user', session_id='s1' + ) + assert ( + session_service._db_schema_version + == _schema_check_utils.SCHEMA_VERSION_0_PICKLE + ) + + session = await session_service.get_session( + app_name='my_app', user_id='test_user', session_id='s1' + ) + assert session.id == 's1' # Verify schema tables engine = create_async_engine(db_url) @@ -111,38 +113,38 @@ async def test_existing_latest_db_uses_latest_schema(tmp_path): db_url = f'sqlite+aiosqlite:///{db_path}' # Create session service which creates db with latest schema - session_service1 = DatabaseSessionService(db_url) - await session_service1.create_session( - app_name='my_app', user_id='test_user', session_id='s1' - ) - assert ( - session_service1._db_schema_version - == _schema_check_utils.LATEST_SCHEMA_VERSION - ) - - # Create another session service on same db and check it detects latest schema - session_service2 = DatabaseSessionService(db_url) - await session_service2.create_session( - app_name='my_app', user_id='test_user2', session_id='s2' - ) - assert ( - session_service2._db_schema_version - == _schema_check_utils.LATEST_SCHEMA_VERSION - ) - s2 = await session_service2.get_session( - app_name='my_app', user_id='test_user2', session_id='s2' - ) - assert s2.id == 's2' - - s1 = await session_service2.get_session( - app_name='my_app', user_id='test_user', session_id='s1' - ) - assert s1.id == 's1' - - list_sessions_response = await session_service2.list_sessions( - app_name='my_app' - ) - assert len(list_sessions_response.sessions) == 2 + async with DatabaseSessionService(db_url) as session_service1: + await session_service1.create_session( + app_name='my_app', user_id='test_user', session_id='s1' + ) + assert ( + session_service1._db_schema_version + == _schema_check_utils.LATEST_SCHEMA_VERSION + ) + + # Create another session service on same db and check it detects latest schema + async with DatabaseSessionService(db_url) as session_service2: + await session_service2.create_session( + app_name='my_app', user_id='test_user2', session_id='s2' + ) + assert ( + session_service2._db_schema_version + == _schema_check_utils.LATEST_SCHEMA_VERSION + ) + s2 = await session_service2.get_session( + app_name='my_app', user_id='test_user2', session_id='s2' + ) + assert s2.id == 's2' + + s1 = await session_service2.get_session( + app_name='my_app', user_id='test_user', session_id='s1' + ) + assert s1.id == 's1' + + list_sessions_response = await session_service2.list_sessions( + app_name='my_app' + ) + assert len(list_sessions_response.sessions) == 2 # Verify schema tables engine = create_async_engine(db_url) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 45aa3feede..556d78ae57 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -45,33 +45,30 @@ def get_session_service( return InMemorySessionService() -@pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ +@pytest.fixture( + params=[ SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE, SessionServiceType.SQLITE, - ], + ] ) -async def test_get_empty_session(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def session_service(request, tmp_path): + """Provides a session service and closes database backends on teardown.""" + service = get_session_service(request.param, tmp_path) + yield service + if isinstance(service, DatabaseSessionService): + await service.close() + + +@pytest.mark.asyncio +async def test_get_empty_session(session_service): assert not await session_service.get_session( app_name='my_app', user_id='test_user', session_id='123' ) @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_create_get_session(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_create_get_session(session_service): app_name = 'my_app' user_id = 'test_user' state = {'key': 'value'} @@ -111,16 +108,7 @@ async def test_create_get_session(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_create_and_list_sessions(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_create_and_list_sessions(session_service): app_name = 'my_app' user_id = 'test_user' @@ -144,16 +132,7 @@ async def test_create_and_list_sessions(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_list_sessions_all_users(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_list_sessions_all_users(session_service): app_name = 'my_app' user_id_1 = 'user1' user_id_2 = 'user2' @@ -209,16 +188,7 @@ async def test_list_sessions_all_users(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_app_state_is_shared_by_all_users_of_app(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_app_state_is_shared_by_all_users_of_app(session_service): app_name = 'my_app' # User 1 creates a session, establishing app:k1 session1 = await session_service.create_session( @@ -247,18 +217,7 @@ async def test_app_state_is_shared_by_all_users_of_app(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_user_state_is_shared_only_by_user_sessions( - service_type, tmp_path -): - session_service = get_session_service(service_type, tmp_path) +async def test_user_state_is_shared_only_by_user_sessions(session_service): app_name = 'my_app' # User 1 creates a session, establishing user:k1 for user 1 session1 = await session_service.create_session( @@ -286,16 +245,7 @@ async def test_user_state_is_shared_only_by_user_sessions( @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_session_state_is_not_shared(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_session_state_is_not_shared(session_service): app_name = 'my_app' # User 1 creates a session session1, establishing sk1 only for session1 session1 = await session_service.create_session( @@ -324,18 +274,7 @@ async def test_session_state_is_not_shared(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_temp_state_is_not_persisted_in_state_or_events( - service_type, tmp_path -): - session_service = get_session_service(service_type, tmp_path) +async def test_temp_state_is_not_persisted_in_state_or_events(session_service): app_name = 'my_app' user_id = 'u1' session = await session_service.create_session( @@ -361,16 +300,7 @@ async def test_temp_state_is_not_persisted_in_state_or_events( @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_get_session_respects_user_id(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_get_session_respects_user_id(session_service): app_name = 'my_app' # u1 creates session 's1' and adds an event session1 = await session_service.create_session( @@ -392,18 +322,7 @@ async def test_get_session_respects_user_id(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_create_session_with_existing_id_raises_error( - service_type, tmp_path -): - session_service = get_session_service(service_type, tmp_path) +async def test_create_session_with_existing_id_raises_error(session_service): app_name = 'my_app' user_id = 'test_user' session_id = 'existing_session' @@ -425,16 +344,7 @@ async def test_create_session_with_existing_id_raises_error( @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_append_event_bytes(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_append_event_bytes(session_service): app_name = 'my_app' user_id = 'user' @@ -471,16 +381,7 @@ async def test_append_event_bytes(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_append_event_complete(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_append_event_complete(session_service): app_name = 'my_app' user_id = 'user' @@ -532,18 +433,7 @@ async def test_append_event_complete(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_session_last_update_time_updates_on_event( - service_type, tmp_path -): - session_service = get_session_service(service_type, tmp_path) +async def test_session_last_update_time_updates_on_event(session_service): app_name = 'my_app' user_id = 'user' @@ -573,16 +463,7 @@ async def test_session_last_update_time_updates_on_event( @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_get_session_with_config(service_type): - session_service = get_session_service(service_type) +async def test_get_session_with_config(session_service): app_name = 'my_app' user_id = 'user' @@ -605,16 +486,7 @@ async def test_get_session_with_config(service_type): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_get_session_with_config(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_get_session_with_config(session_service): app_name = 'my_app' user_id = 'user' @@ -674,16 +546,7 @@ async def test_get_session_with_config(service_type, tmp_path): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'service_type', - [ - SessionServiceType.IN_MEMORY, - SessionServiceType.DATABASE, - SessionServiceType.SQLITE, - ], -) -async def test_partial_events_are_not_persisted(service_type, tmp_path): - session_service = get_session_service(service_type, tmp_path) +async def test_partial_events_are_not_persisted(session_service): app_name = 'my_app' user_id = 'user' session = await session_service.create_session( diff --git a/tests/unittests/sessions/test_v0_storage_event.py b/tests/unittests/sessions/test_v0_storage_event.py new file mode 100644 index 0000000000..6ac62dde10 --- /dev/null +++ b/tests/unittests/sessions/test_v0_storage_event.py @@ -0,0 +1,50 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from datetime import timezone + +from google.adk.events.event_actions import EventActions +from google.adk.events.event_actions import EventCompaction +from google.adk.sessions.schemas.v0 import StorageEvent +from google.genai import types + + +def test_storage_event_v0_to_event_rehydrates_compaction_model(): + compaction = EventCompaction( + start_timestamp=1.0, + end_timestamp=2.0, + compacted_content=types.Content( + role="user", + parts=[types.Part(text="compacted")], + ), + ) + actions = EventActions(compaction=compaction) + storage_event = StorageEvent( + id="event_id", + invocation_id="invocation_id", + author="author", + actions=actions, + session_id="session_id", + app_name="app_name", + user_id="user_id", + timestamp=datetime.fromtimestamp(3.0, tz=timezone.utc), + ) + + event = storage_event.to_event() + + assert event.actions is not None + assert isinstance(event.actions.compaction, EventCompaction) + assert event.actions.compaction.start_timestamp == 1.0 + assert event.actions.compaction.end_timestamp == 2.0 diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index d692f7e380..c347a78931 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -16,6 +16,7 @@ from pathlib import Path import sys import textwrap +from typing import AsyncGenerator from typing import Optional from unittest.mock import AsyncMock @@ -23,6 +24,7 @@ from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App from google.adk.apps.app import ResumabilityConfig from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService @@ -54,7 +56,9 @@ def __init__( if parent_agent: self.parent_agent = parent_agent - async def _run_async_impl(self, invocation_context): + async def _run_async_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: yield Event( invocation_id=invocation_context.invocation_id, author=self.name, @@ -78,7 +82,9 @@ def __init__( self.disallow_transfer_to_parent = disallow_transfer_to_parent self.parent_agent = parent_agent - async def _run_async_impl(self, invocation_context): + async def _run_async_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: yield Event( invocation_id=invocation_context.invocation_id, author=self.name, @@ -88,6 +94,25 @@ async def _run_async_impl(self, invocation_context): ) +class MockAgentWithMetadata(BaseAgent): + """Mock agent that returns event-level custom metadata.""" + + def __init__(self, name: str): + super().__init__(name=name, sub_agents=[]) + + async def _run_async_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=invocation_context.invocation_id, + author=self.name, + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ), + custom_metadata={"event_key": "event_value"}, + ) + + class MockPlugin(BasePlugin): """Mock plugin for unit testing.""" @@ -495,6 +520,41 @@ def test_is_transferable_across_agent_tree_with_non_llm_agent(self): assert result is False +@pytest.mark.asyncio +async def test_run_config_custom_metadata_propagates_to_events(): + session_service = InMemorySessionService() + runner = Runner( + app_name=TEST_APP_ID, + agent=MockAgentWithMetadata("metadata_agent"), + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + await session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + run_config = RunConfig(custom_metadata={"request_id": "req-1"}) + events = [ + event + async for event in runner.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content(role="user", parts=[types.Part(text="hi")]), + run_config=run_config, + ) + ] + + assert events[0].custom_metadata is not None + assert events[0].custom_metadata["request_id"] == "req-1" + assert events[0].custom_metadata["event_key"] == "event_value" + + session = await session_service.get_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + user_event = next(event for event in session.events if event.author == "user") + assert user_event.custom_metadata == {"request_id": "req-1"} + + class TestRunnerWithPlugins: """Tests for Runner with plugins.""" diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index 1791100e1f..a0873046e2 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -114,7 +114,7 @@ async def test_execute_sql_declaration_read_only(tool_settings): >>> execute_sql("my_project", ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [ @@ -138,7 +138,7 @@ async def test_execute_sql_declaration_read_only(tool_settings): >>> execute_sql( ... "my_project", ... "SELECT island FROM " - ... "bigquery-public-data.ml_datasets.penguins", + ... "`bigquery-public-data`.`ml_datasets`.`penguins`", ... dry_run=True ... ) { @@ -154,7 +154,7 @@ async def test_execute_sql_declaration_read_only(tool_settings): "tableId": "anon..." }, "priority": "INTERACTIVE", - "query": "SELECT island FROM bigquery-public-data.ml_datasets.penguins", + "query": "SELECT island FROM `bigquery-public-data`.`ml_datasets`.`penguins`", "useLegacySql": False, "writeDisposition": "WRITE_TRUNCATE" } @@ -213,7 +213,7 @@ async def test_execute_sql_declaration_write(tool_settings): >>> execute_sql("my_project", ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [ @@ -237,7 +237,7 @@ async def test_execute_sql_declaration_write(tool_settings): >>> execute_sql( ... "my_project", ... "SELECT island FROM " - ... "bigquery-public-data.ml_datasets.penguins", + ... "`bigquery-public-data`.`ml_datasets`.`penguins`", ... dry_run=True ... ) { @@ -253,7 +253,7 @@ async def test_execute_sql_declaration_write(tool_settings): "tableId": "anon..." }, "priority": "INTERACTIVE", - "query": "SELECT island FROM bigquery-public-data.ml_datasets.penguins", + "query": "SELECT island FROM `bigquery-public-data`.`ml_datasets`.`penguins`", "useLegacySql": False, "writeDisposition": "WRITE_TRUNCATE" } @@ -268,7 +268,7 @@ async def test_execute_sql_declaration_write(tool_settings): Create a table with schema prescribed: >>> execute_sql("my_project", - ... "CREATE TABLE my_project.my_dataset.my_table " + ... "CREATE TABLE `my_project`.`my_dataset`.`my_table` " ... "(island STRING, population INT64)") { "status": "SUCCESS", @@ -278,7 +278,7 @@ async def test_execute_sql_declaration_write(tool_settings): Insert data into an existing table: >>> execute_sql("my_project", - ... "INSERT INTO my_project.my_dataset.my_table (island, population) " + ... "INSERT INTO `my_project`.`my_dataset`.`my_table` (island, population) " ... "VALUES ('Dream', 124), ('Biscoe', 168)") { "status": "SUCCESS", @@ -288,9 +288,9 @@ async def test_execute_sql_declaration_write(tool_settings): Create a table from the result of a query: >>> execute_sql("my_project", - ... "CREATE TABLE my_project.my_dataset.my_table AS " + ... "CREATE TABLE `my_project`.`my_dataset`.`my_table` AS " ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [] @@ -299,7 +299,7 @@ async def test_execute_sql_declaration_write(tool_settings): Delete a table: >>> execute_sql("my_project", - ... "DROP TABLE my_project.my_dataset.my_table") + ... "DROP TABLE `my_project`.`my_dataset`.`my_table`") { "status": "SUCCESS", "rows": [] @@ -308,8 +308,8 @@ async def test_execute_sql_declaration_write(tool_settings): Copy a table to another table: >>> execute_sql("my_project", - ... "CREATE TABLE my_project.my_dataset.my_table_clone " - ... "CLONE my_project.my_dataset.my_table") + ... "CREATE TABLE `my_project`.`my_dataset`.`my_table_clone` " + ... "CLONE `my_project`.`my_dataset`.`my_table`") { "status": "SUCCESS", "rows": [] @@ -319,8 +319,8 @@ async def test_execute_sql_declaration_write(tool_settings): table: >>> execute_sql("my_project", - ... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot " - ... "CLONE my_project.my_dataset.my_table") + ... "CREATE SNAPSHOT TABLE `my_project`.`my_dataset`.`my_table_snapshot` " + ... "CLONE `my_project`.`my_dataset`.`my_table`") { "status": "SUCCESS", "rows": [] @@ -329,9 +329,9 @@ async def test_execute_sql_declaration_write(tool_settings): Create a BigQuery ML linear regression model: >>> execute_sql("my_project", - ... "CREATE MODEL `my_dataset.my_model` " + ... "CREATE MODEL `my_dataset`.`my_model` " ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS " - ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` " + ... "SELECT * FROM `bigquery-public-data`.`ml_datasets`.`penguins` " ... "WHERE body_mass_g IS NOT NULL") { "status": "SUCCESS", @@ -341,7 +341,7 @@ async def test_execute_sql_declaration_write(tool_settings): Evaluate BigQuery ML model: >>> execute_sql("my_project", - ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`)") + ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset`.`my_model`)") { "status": "SUCCESS", "rows": [{'mean_absolute_error': 227.01223667447218, @@ -355,8 +355,8 @@ async def test_execute_sql_declaration_write(tool_settings): Evaluate BigQuery ML model on custom data: >>> execute_sql("my_project", - ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`, " - ... "(SELECT * FROM `my_dataset.my_table`))") + ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset`.`my_model`, " + ... "(SELECT * FROM `my_dataset`.`my_table`))") { "status": "SUCCESS", "rows": [{'mean_absolute_error': 227.01223667447218, @@ -370,8 +370,8 @@ async def test_execute_sql_declaration_write(tool_settings): Predict using BigQuery ML model: >>> execute_sql("my_project", - ... "SELECT * FROM ML.PREDICT(MODEL `my_dataset.my_model`, " - ... "(SELECT * FROM `my_dataset.my_table`))") + ... "SELECT * FROM ML.PREDICT(MODEL `my_dataset`.`my_model`, " + ... "(SELECT * FROM `my_dataset`.`my_table`))") { "status": "SUCCESS", "rows": [ @@ -388,7 +388,7 @@ async def test_execute_sql_declaration_write(tool_settings): Delete a BigQuery ML model: - >>> execute_sql("my_project", "DROP MODEL `my_dataset.my_model`") + >>> execute_sql("my_project", "DROP MODEL `my_dataset`.`my_model`") { "status": "SUCCESS", "rows": [] @@ -450,7 +450,7 @@ async def test_execute_sql_declaration_protected_write(tool_settings): >>> execute_sql("my_project", ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [ @@ -474,7 +474,7 @@ async def test_execute_sql_declaration_protected_write(tool_settings): >>> execute_sql( ... "my_project", ... "SELECT island FROM " - ... "bigquery-public-data.ml_datasets.penguins", + ... "`bigquery-public-data`.`ml_datasets`.`penguins`", ... dry_run=True ... ) { @@ -490,7 +490,7 @@ async def test_execute_sql_declaration_protected_write(tool_settings): "tableId": "anon..." }, "priority": "INTERACTIVE", - "query": "SELECT island FROM bigquery-public-data.ml_datasets.penguins", + "query": "SELECT island FROM `bigquery-public-data`.`ml_datasets`.`penguins`", "useLegacySql": False, "writeDisposition": "WRITE_TRUNCATE" } @@ -505,7 +505,7 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Create a temporary table with schema prescribed: >>> execute_sql("my_project", - ... "CREATE TEMP TABLE my_table (island STRING, population INT64)") + ... "CREATE TEMP TABLE `my_table` (island STRING, population INT64)") { "status": "SUCCESS", "rows": [] @@ -514,7 +514,7 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Insert data into an existing temporary table: >>> execute_sql("my_project", - ... "INSERT INTO my_table (island, population) " + ... "INSERT INTO `my_table` (island, population) " ... "VALUES ('Dream', 124), ('Biscoe', 168)") { "status": "SUCCESS", @@ -524,9 +524,9 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Create a temporary table from the result of a query: >>> execute_sql("my_project", - ... "CREATE TEMP TABLE my_table AS " + ... "CREATE TEMP TABLE `my_table` AS " ... "SELECT island, COUNT(*) AS population " - ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + ... "FROM `bigquery-public-data`.`ml_datasets`.`penguins` GROUP BY island") { "status": "SUCCESS", "rows": [] @@ -534,7 +534,7 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Delete a temporary table: - >>> execute_sql("my_project", "DROP TABLE my_table") + >>> execute_sql("my_project", "DROP TABLE `my_table`") { "status": "SUCCESS", "rows": [] @@ -543,7 +543,7 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Copy a temporary table to another temporary table: >>> execute_sql("my_project", - ... "CREATE TEMP TABLE my_table_clone CLONE my_table") + ... "CREATE TEMP TABLE `my_table_clone` CLONE `my_table`") { "status": "SUCCESS", "rows": [] @@ -552,9 +552,9 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Create a temporary BigQuery ML linear regression model: >>> execute_sql("my_project", - ... "CREATE TEMP MODEL my_model " + ... "CREATE TEMP MODEL `my_model` " ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS" - ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` " + ... "SELECT * FROM `bigquery-public-data`.`ml_datasets`.`penguins` " ... "WHERE body_mass_g IS NOT NULL") { "status": "SUCCESS", @@ -563,7 +563,7 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Evaluate BigQuery ML model: - >>> execute_sql("my_project", "SELECT * FROM ML.EVALUATE(MODEL my_model)") + >>> execute_sql("my_project", "SELECT * FROM ML.EVALUATE(MODEL `my_model`)") { "status": "SUCCESS", "rows": [{'mean_absolute_error': 227.01223667447218, @@ -577,8 +577,8 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Evaluate BigQuery ML model on custom data: >>> execute_sql("my_project", - ... "SELECT * FROM ML.EVALUATE(MODEL my_model, " - ... "(SELECT * FROM `my_dataset.my_table`))") + ... "SELECT * FROM ML.EVALUATE(MODEL `my_model`, " + ... "(SELECT * FROM `my_dataset`.`my_table`))") { "status": "SUCCESS", "rows": [{'mean_absolute_error': 227.01223667447218, @@ -592,8 +592,8 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Predict using BigQuery ML model: >>> execute_sql("my_project", - ... "SELECT * FROM ML.PREDICT(MODEL my_model, " - ... "(SELECT * FROM `my_dataset.my_table`))") + ... "SELECT * FROM ML.PREDICT(MODEL `my_model`, " + ... "(SELECT * FROM `my_dataset`.`my_table`))") { "status": "SUCCESS", "rows": [ @@ -610,7 +610,7 @@ async def test_execute_sql_declaration_protected_write(tool_settings): Delete a BigQuery ML model: - >>> execute_sql("my_project", "DROP MODEL my_model") + >>> execute_sql("my_project", "DROP MODEL `my_model`") { "status": "SUCCESS", "rows": [] diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 74eabe9d4d..cac156ae40 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -114,6 +114,54 @@ def test_init_with_streamable_http_params(self): assert manager._connection_params == http_params + @patch("google.adk.tools.mcp_tool.mcp_session_manager.streamablehttp_client") + def test_init_with_streamable_http_custom_httpx_factory( + self, mock_streamablehttp_client + ): + """Test that streamablehttp_client is called with custom httpx_client_factory.""" + custom_httpx_factory = Mock() + + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", + timeout=15.0, + httpx_client_factory=custom_httpx_factory, + ) + manager = MCPSessionManager(http_params) + + manager._create_client() + + mock_streamablehttp_client.assert_called_once_with( + url="https://example.com/mcp", + headers=None, + timeout=timedelta(seconds=15.0), + sse_read_timeout=timedelta(seconds=300.0), + terminate_on_close=True, + httpx_client_factory=custom_httpx_factory, + ) + + @patch("google.adk.tools.mcp_tool.mcp_session_manager.streamablehttp_client") + def test_init_with_streamable_http_default_httpx_factory( + self, mock_streamablehttp_client + ): + """Test that streamablehttp_client is called with default httpx_client_factory.""" + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", timeout=15.0 + ) + manager = MCPSessionManager(http_params) + + manager._create_client() + + mock_streamablehttp_client.assert_called_once_with( + url="https://example.com/mcp", + headers=None, + timeout=timedelta(seconds=15.0), + sse_read_timeout=timedelta(seconds=300.0), + terminate_on_close=True, + httpx_client_factory=StreamableHTTPConnectionParams.model_fields[ + "httpx_client_factory" + ].get_default(), + ) + def test_generate_session_key_stdio(self): """Test session key generation for stdio connections.""" manager = MCPSessionManager(self.mock_stdio_connection_params)