diff --git a/contributing/samples/adk_documentation/adk_release_analyzer/agent.py b/contributing/samples/adk_documentation/adk_release_analyzer/agent.py index 738217c3e2..ddad17d310 100644 --- a/contributing/samples/adk_documentation/adk_release_analyzer/agent.py +++ b/contributing/samples/adk_documentation/adk_release_analyzer/agent.py @@ -12,8 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""ADK Release Analyzer Agent - Multi-agent architecture for analyzing releases. + +This agent uses a SequentialAgent + LoopAgent pattern to handle large releases +without context overflow: + +1. PlannerAgent: Collects changed files and creates analysis groups +2. LoopAgent + FileGroupAnalyzer: Processes one group at a time +3. SummaryAgent: Compiles all findings and creates the GitHub issue + +State keys used: +- start_tag, end_tag: Release tags being compared +- compare_url: GitHub compare URL +- file_groups: List of file groups to analyze +- current_group_index: Index of current group being processed +- recommendations: Accumulated recommendations from all groups +""" + import os import sys +from typing import Any SAMPLES_DIR = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "..") @@ -29,12 +47,21 @@ from adk_documentation.settings import LOCAL_REPOS_DIR_PATH from adk_documentation.tools import clone_or_pull_repo from adk_documentation.tools import create_issue -from adk_documentation.tools import get_changed_files_between_releases +from adk_documentation.tools import get_changed_files_summary +from adk_documentation.tools import get_file_diff_for_release from adk_documentation.tools import list_directory_contents from adk_documentation.tools import list_releases from adk_documentation.tools import read_local_git_repo_file_content from adk_documentation.tools import search_local_git_repo from google.adk import Agent +from google.adk.agents.loop_agent import LoopAgent +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.tools.exit_loop_tool import exit_loop +from google.adk.tools.tool_context import ToolContext + +# Maximum number of files per analysis group to avoid context overflow +MAX_FILES_PER_GROUP = 5 if IS_INTERACTIVE: APPROVAL_INSTRUCTION = ( @@ -43,96 +70,551 @@ ) else: APPROVAL_INSTRUCTION = ( - "**Do not** wait or ask for user approval or confirmation for creating or" - " updating the issue." + "**Do not** wait or ask for user approval or confirmation for creating" + " or updating the issue." ) + +# ============================================================================= +# Tool functions for state management +# ============================================================================= + + +def get_next_file_group(tool_context: ToolContext) -> dict[str, Any]: + """Gets the next group of files to analyze from the state. + + This tool retrieves the next file group from state["file_groups"] + and increments the current_group_index. + + Args: + tool_context: The tool context providing access to state. + + Returns: + A dictionary with the next file group or indication that all groups + are processed. + """ + file_groups = tool_context.state.get("file_groups", []) + current_index = tool_context.state.get("current_group_index", 0) + + if current_index >= len(file_groups): + return { + "status": "complete", + "message": "All file groups have been processed.", + "total_groups": len(file_groups), + "processed": current_index, + } + + current_group = file_groups[current_index] + tool_context.state["current_group_index"] = current_index + 1 + + return { + "status": "success", + "group_index": current_index, + "total_groups": len(file_groups), + "remaining": len(file_groups) - current_index - 1, + "files": current_group, + } + + +def save_group_recommendations( + tool_context: ToolContext, + group_index: int, + recommendations: list[dict[str, str]], +) -> dict[str, Any]: + """Saves recommendations for a file group to state. + + Args: + tool_context: The tool context providing access to state. + group_index: The index of the group these recommendations belong to. + recommendations: List of recommendation dicts with keys: + - summary: Brief summary of the change + - doc_file: Path to the doc file to update + - current_state: Current content in the doc + - proposed_change: What should be changed + - reasoning: Why this change is needed + - reference: Reference to the code file + + Returns: + A dictionary confirming the save operation. + """ + all_recommendations = tool_context.state.get("recommendations", []) + all_recommendations.extend(recommendations) + tool_context.state["recommendations"] = all_recommendations + + return { + "status": "success", + "group_index": group_index, + "new_recommendations": len(recommendations), + "total_recommendations": len(all_recommendations), + } + + +def get_all_recommendations(tool_context: ToolContext) -> dict[str, Any]: + """Retrieves all accumulated recommendations from state. + + Args: + tool_context: The tool context providing access to state. + + Returns: + A dictionary with all recommendations and metadata. + """ + recommendations = tool_context.state.get("recommendations", []) + start_tag = tool_context.state.get("start_tag", "unknown") + end_tag = tool_context.state.get("end_tag", "unknown") + compare_url = tool_context.state.get("compare_url", "") + + return { + "status": "success", + "start_tag": start_tag, + "end_tag": end_tag, + "compare_url": compare_url, + "total_recommendations": len(recommendations), + "recommendations": recommendations, + } + + +def save_release_info( + tool_context: ToolContext, + start_tag: str, + end_tag: str, + compare_url: str, + file_groups: list[list[dict[str, Any]]], + release_summary: str, + all_changed_files: list[str], +) -> dict[str, Any]: + """Saves release info and file groups to state for processing. + + Args: + tool_context: The tool context providing access to state. + start_tag: The starting release tag. + end_tag: The ending release tag. + compare_url: The GitHub compare URL. + file_groups: List of file groups, where each group is a list of file + info dicts. + release_summary: A high-level summary of all changes in this release, + including the main themes (e.g., "new feature X", "refactoring Y", + "bug fixes in Z"). This helps individual analyzers understand the + bigger picture. + all_changed_files: List of all changed file paths (for cross-reference). + + Returns: + A dictionary confirming the save operation. + """ + tool_context.state["start_tag"] = start_tag + tool_context.state["end_tag"] = end_tag + tool_context.state["compare_url"] = compare_url + tool_context.state["file_groups"] = file_groups + tool_context.state["current_group_index"] = 0 + tool_context.state["recommendations"] = [] + tool_context.state["release_summary"] = release_summary + tool_context.state["all_changed_files"] = all_changed_files + + return { + "status": "success", + "start_tag": start_tag, + "end_tag": end_tag, + "total_groups": len(file_groups), + "total_files": sum(len(group) for group in file_groups), + } + + +def get_release_context(tool_context: ToolContext) -> dict[str, Any]: + """Gets the global release context for cross-group awareness. + + This allows individual file group analyzers to understand: + - The overall theme of the release + - What other files were changed (for identifying related changes) + - What recommendations have already been made (to avoid duplicates) + + Args: + tool_context: The tool context providing access to state. + + Returns: + A dictionary with global release context. + """ + return { + "status": "success", + "start_tag": tool_context.state.get("start_tag", "unknown"), + "end_tag": tool_context.state.get("end_tag", "unknown"), + "release_summary": tool_context.state.get("release_summary", ""), + "all_changed_files": tool_context.state.get("all_changed_files", []), + "existing_recommendations": tool_context.state.get("recommendations", []), + "current_group_index": tool_context.state.get("current_group_index", 0), + "total_groups": len(tool_context.state.get("file_groups", [])), + } + + +# ============================================================================= +# Agent 1: Planner Agent +# ============================================================================= + +planner_agent = Agent( + model="gemini-2.5-pro", + name="release_planner", + description=( + "Plans the analysis by fetching release info and organizing files into" + " groups for incremental processing." + ), + instruction=f""" +# 1. Identity +You are the Release Planner, responsible for setting up the analysis of ADK +Python releases. You gather information about changes and organize them for +efficient processing. + +# 2. Workflow +1. First, call `clone_or_pull_repo` for both repositories: + - ADK Python codebase: owner={CODE_OWNER}, repo={CODE_REPO}, path={LOCAL_REPOS_DIR_PATH}/{CODE_REPO} + - ADK Docs: owner={DOC_OWNER}, repo={DOC_REPO}, path={LOCAL_REPOS_DIR_PATH}/{DOC_REPO} + +2. Call `list_releases` to find the release tags for {CODE_OWNER}/{CODE_REPO}. + - By default, compare the two most recent releases. + - If the user specifies tags, use those instead. + +3. Call `get_changed_files_summary` to get the list of changed files WITHOUT + the full patches (to save context space). + +4. Filter and organize the files: + - **INCLUDE** only files in `src/google/adk/` directory + - **EXCLUDE** test files, `__init__.py`, and files outside src/ + - **IMPORTANT**: Do NOT exclude any file just because it has few changes. + Even single-line changes to public APIs need documentation updates. + - **PRIORITIZE** by importance: + a) New files (status: "added") - ALWAYS include these + b) CLI files (cli/) - often contain user-facing flags and options + c) Tool files (tools/) - may contain new tools or tool parameters + d) Core files (agents/, models/, sessions/, memory/, a2a/, flows/, + plugins/, evaluation/) + e) Files with many changes (high additions + deletions) + +5. **Create a high-level release summary** based on the changed files: + - Identify the main themes (e.g., "new tool X added", "refactoring of Y") + - Note any files that appear related (e.g., same feature area) + - This summary will be shared with individual file analyzers so they + understand the bigger picture. + +6. Group the filtered files into groups of at most {MAX_FILES_PER_GROUP} files each. + - **IMPORTANT**: Group RELATED files together (same directory or feature) + - Files that are part of the same feature should be in the same group + - Each group should be independently analyzable + +7. Call `save_release_info` to save: + - start_tag, end_tag + - compare_url + - file_groups (the organized groups) + - release_summary (the high-level summary you created) + - all_changed_files (list of all file paths for cross-reference) + +# 3. Output +Provide a summary of: +- Which releases are being compared +- The high-level themes of this release +- How many files changed in total +- How many files are relevant for doc analysis +- How many groups were created +""", + tools=[ + clone_or_pull_repo, + list_releases, + get_changed_files_summary, + save_release_info, + ], + output_key="planner_output", +) + + +# ============================================================================= +# Agent 2: File Group Analyzer (runs inside LoopAgent) +# ============================================================================= + + +def file_analyzer_instruction(readonly_context: ReadonlyContext) -> str: + """Dynamic instruction that includes current state info.""" + start_tag = readonly_context.state.get("start_tag", "unknown") + end_tag = readonly_context.state.get("end_tag", "unknown") + release_summary = readonly_context.state.get("release_summary", "") + + return f""" +# 1. Identity +You are the File Group Analyzer, responsible for analyzing a group of changed +files and finding related documentation that needs updating. + +# 2. Context +- Comparing releases: {start_tag} to {end_tag} +- Code repository: {CODE_OWNER}/{CODE_REPO} +- Docs repository: {DOC_OWNER}/{DOC_REPO} +- Docs local path: {LOCAL_REPOS_DIR_PATH}/{DOC_REPO} +- Code local path: {LOCAL_REPOS_DIR_PATH}/{CODE_REPO} + +## Release Summary (from Planner) +{release_summary} + +# 3. Workflow +1. Call `get_next_file_group` to get the next group of files to analyze. + - If status is "complete", call the `exit_loop` tool to exit the loop. + +2. **FIRST**, call `get_release_context` to understand: + - The overall release themes (to understand how your files fit in) + - What other files were changed (to identify related changes) + - What recommendations already exist (to AVOID DUPLICATES) + +3. For each file in the group: + a) Call `get_file_diff_for_release` to get the patch content for that file. + b) Analyze the changes THOROUGHLY. Look for: + **API Changes:** + - New functions, classes, methods (especially public ones) + - New parameters added to existing functions + - New CLI arguments or flags (look for argparse, click decorators) + - New environment variables (look for os.environ, getenv) + - New tools or features being added + - Renamed or deprecated functionality + **Behavior Changes (even without API changes):** + - Default values changed + - Error handling or exception types changed + - Return value format or content changed + - Side effects added or removed + - Performance characteristics changed + - Edge case handling changed + - Validation rules changed + c) Consider how this file relates to OTHER changed files in this release. + d) Generate MULTIPLE search patterns based on: + - Class/function names that changed + - Feature names mentioned in the file path + - Keywords from the patch content (e.g., "local_storage", "allow_origins") + - Tool names, parameter names, environment variable names + +4. For EACH significant change, call `search_local_git_repo` to find related docs + in {LOCAL_REPOS_DIR_PATH}/{DOC_REPO}/docs/ + - Search for the feature name, class name, or related keywords + - If no docs found, recommend creating new documentation + +5. Call `read_local_git_repo_file_content` to read the relevant doc files + and check if they need updating. + +6. For each documentation update needed, create a recommendation with: + - summary: Brief summary of what needs to change + - doc_file: Relative path in the docs repo (e.g., docs/tools/google-search.md) + - current_state: What the doc currently says + - proposed_change: What it should say instead + - reasoning: Why this update is needed + - reference: The source code file path + - related_files: Other changed files that are part of the same change (if any) + +7. Call `save_group_recommendations` with all recommendations for this group. + +8. After saving, output a brief summary of what you found for this group. + +# 4. Rules +- **BE THOROUGH**: Check EVERY change in the diff that could affect users. + This includes API changes AND behavior changes (default values, error handling, + return formats, side effects, etc.). +- Focus on changes that users need to know about +- Include behavior changes even if the API signature stays the same +- If a change only affects auto-generated API reference docs, note that + regeneration is needed instead of manual updates +- **AVOID DUPLICATES**: Check existing_recommendations before adding new ones +- **CROSS-REFERENCE**: If files in your group relate to files in other groups, + mention this in your recommendation so the Summary agent can consolidate +- **DON'T MISS ITEMS**: Better to have too many recommendations than too few. + If unsure whether something needs documentation, include it. +- For new features with no existing docs, recommend creating a new page +""" + + +file_group_analyzer = Agent( + model="gemini-2.5-pro", + name="file_group_analyzer", + description=( + "Analyzes a group of changed files and generates recommendations." + ), + instruction=file_analyzer_instruction, + tools=[ + get_next_file_group, + get_release_context, # Get global context to avoid duplicates + get_file_diff_for_release, + search_local_git_repo, + read_local_git_repo_file_content, + list_directory_contents, + save_group_recommendations, + exit_loop, # Call this when all groups are processed + ], + output_key="analyzer_output", +) + +# Loop agent that processes file groups one at a time +file_analysis_loop = LoopAgent( + name="file_analysis_loop", + sub_agents=[file_group_analyzer], + max_iterations=50, # Safety limit +) + + +# ============================================================================= +# Agent 3: Summary Agent +# ============================================================================= + + +def summary_instruction(readonly_context: ReadonlyContext) -> str: + """Dynamic instruction with release info.""" + start_tag = readonly_context.state.get("start_tag", "unknown") + end_tag = readonly_context.state.get("end_tag", "unknown") + + return f""" +# 1. Identity +You are the Summary Agent, responsible for compiling all recommendations into +a well-formatted GitHub issue. + +# 2. Workflow +1. Call `get_all_recommendations` to retrieve all accumulated recommendations. + +2. Organize the recommendations: + - Group by importance: Feature changes > Bug fixes > Other + - Within each group, sort by number of affected files + - Remove duplicates or merge similar recommendations + +3. Format the issue body using this template for each recommendation: + ``` + ### N. **Summary of the change** + + **Doc file**: path/to/doc.md + + **Current state**: + > Current content in the doc + + **Proposed Change**: + > What it should say instead + + **Reasoning**: + Explanation of why this change is necessary. + + **Reference**: src/google/adk/path/to/file.py + ``` + +4. Create the GitHub issue: + - Title: "Found docs updates needed from ADK python release {start_tag} to {end_tag}" + - Include the compare link at the top + - {APPROVAL_INSTRUCTION} + +5. Call `create_issue` for {DOC_OWNER}/{DOC_REPO} with the formatted content. + +# 3. Output +Present a summary of: +- Total recommendations created +- Issue URL if created +- Any notes about the analysis +""" + + +summary_agent = Agent( + model="gemini-2.5-pro", + name="summary_agent", + description="Compiles recommendations and creates the GitHub issue.", + instruction=summary_instruction, + tools=[ + get_all_recommendations, + create_issue, + ], + output_key="summary_output", +) + + +# ============================================================================= +# Pipeline Agent: Sequential orchestration of the analysis +# ============================================================================= + +analysis_pipeline = SequentialAgent( + name="analysis_pipeline", + description=( + "Executes the release analysis pipeline: planning, file analysis, and" + " summary generation." + ), + sub_agents=[ + planner_agent, + file_analysis_loop, + summary_agent, + ], +) + + +# ============================================================================= +# Root Agent: Entry point that understands user requests +# ============================================================================= + root_agent = Agent( model="gemini-2.5-pro", name="adk_release_analyzer", description=( - "Analyze the changes between two ADK releases and generate instructions" - " about how to update the ADK docs." + "Analyzes ADK Python releases and generates documentation update" + " recommendations." ), instruction=f""" - # 1. Identity - You are a helper bot that checks if ADK docs in GitHub Repository {DOC_REPO} owned by {DOC_OWNER} - should be updated based on the changes in the ADK Python codebase in GitHub Repository {CODE_REPO} owned by {CODE_OWNER}. - - You are very familiar with GitHub, especially how to search for files in a GitHub repository using git grep. - - # 2. Responsibilities - Your core responsibility includes: - - Find all the code changes between the two ADK releases. - - Find **all** the related docs files in ADK Docs repository under the "/docs/" directory. - - Compare the code changes with the docs files and analyze the differences. - - Write the instructions about how to update the ADK docs in markdown format and create a GitHub issue in the GitHub Repository {DOC_REPO} with the instructions. - - # 3. Workflow - 1. Always call the `clone_or_pull_repo` tool to make sure the ADK docs and codebase repos exist in the local folder {LOCAL_REPOS_DIR_PATH}/repo_name and are the latest version. - 2. Find the code changes between the two ADK releases. - - You should call the `get_changed_files_between_releases` tool to find all the code changes between the two ADK releases. - - You can call the `list_releases` tool to find the release tags. - 3. Understand the code changes between the two ADK releases. - - You should focus on the main ADK Python codebase, ignore the changes in tests or other auxiliary files. - 4. Come up with a list of regex search patterns to search for related docs files. - 5. Use the `search_local_git_repo` tool to search for related docs files using the regex patterns. - - You should look into all the related docs files, not only the most relevant one. - - Prefer searching from the root directory of the ADK Docs repository (i.e. /docs/), unless you are certain that the file is in a specific directory. - 6. Read the found docs files using the `read_local_git_repo_file_content` tool to find all the docs to update. - - You should read all the found docs files and check if they are up to date. - 7. Compare the code changes and docs files, and analyze the differences. - - You should not only check the code snippets in the docs, but also the text contents. - 8. Write the instructions about how to update the ADK docs in a markdown format. - - For **each** recommended change, reference the code changes. - - For **each** recommended change, follow the format of the following template: - ``` - 1. **Highlighted summary of the change**. - Details of the change. - - **Current state**: - Current content in the doc - - **Proposed Change**: - Proposed change to the doc. - - **Reasoning**: - Explanation of why this change is necessary. - - **Reference**: - Reference to the code file (e.g. src/google/adk/tools/spanner/metadata_tool.py). - ``` - - When referencing doc file, use the full relative path of the doc file in the ADK Docs repository (e.g. docs/sessions/memory.md). - 9. Create or recommend to create a GitHub issue in the GitHub Repository {DOC_REPO} with the instructions using the `create_issue` tool. - - The title of the issue should be "Found docs updates needed from ADK python release to ", where start_tag and end_tag are the release tags. - - The body of the issue should be the instructions about how to update the ADK docs. - - Include the compare link between the two ADK releases in the issue body, e.g. https://github.com/google/adk-python/compare/v1.14.0...v1.14.1. - - **{APPROVAL_INSTRUCTION}** - - # 4. Guidelines & Rules - - **File Paths:** Always use absolute paths when calling the tools to read files, list directories, or search the codebase. - - **Tool Call Parallelism:** Execute multiple independent tool calls in parallel when feasible (i.e. searching the codebase). - - **Explanation:** Provide concise explanations for your actions and reasoning for each step. - - **Reference:** For each recommended change, reference the code changes (i.e. links to the commits) **AND** the code files (i.e. relative paths to the code files in the codebase). - - **Sorting:** Sort the recommended changes by the importance of the changes, from the most important to the least important. - - Here are the importance groups: Feature changes > Bug fixes > Other changes. - - Within each importance group, sort the changes by the number of files they affect. - - Within each group of changes with the same number of files, sort by the number of lines changed in each file. - - **API Reference Updates:** ADK Docs repository has auto-generated API reference docs for the ADK Python codebase, which can be found in the "/docs/api-reference/python" directory. - - If a change in the codebase can be covered by the auto-generated API reference docs, you should just recommend to update the API reference docs (i.e. regenerate the API reference docs) instead of the other human-written ADK docs. - - # 5. Output - Present the following in an easy to read format as the final output to the user. - - The actions you took and the reasoning - - The summary of the differences found - """, +# 1. Identity +You are the ADK Release Analyzer, a helper bot that analyzes changes between +ADK Python releases and identifies documentation updates needed in the ADK +Docs repository. + +# 2. Capabilities +You can help users in several ways: + +## A. Full Release Analysis (delegate to analysis_pipeline) +When users want a complete analysis of releases, delegate to the +`analysis_pipeline` sub-agent. This will: +- Clone/update repositories +- Analyze all changed files +- Generate recommendations +- Create a GitHub issue + +Use this when users say things like: +- "Analyze the latest releases" +- "Check what docs need updating for v1.15.0" +- "Run a full analysis" + +## B. Quick Queries (use your tools directly) +For targeted questions, use your tools directly WITHOUT delegating: + +- **"How should I modify doc1.md?"** → Use `search_local_git_repo` to find + mentions of doc1.md in the codebase, then use `get_changed_files_summary` + to see what changed, and provide specific guidance. + +- **"What changed in the tools module?"** → Use `get_changed_files_summary` + and filter for tools/ directory. + +- **"Show me the recommendations from the last analysis"** → Use + `get_all_recommendations` to retrieve stored recommendations. + +- **"What releases are available?"** → Use `list_releases` directly. + +# 3. Workflow Decision +1. First, understand what the user is asking: + - Full analysis request → delegate to analysis_pipeline + - Specific question about a file/module → use tools directly + - Query about previous results → use get_all_recommendations + +2. For quick queries, ensure repos are cloned first using `clone_or_pull_repo` + if needed. + +3. Always explain what you're doing and provide clear, actionable answers. + +# 4. Available Tools +- `clone_or_pull_repo`: Ensure local repos are up to date +- `list_releases`: See available release tags +- `get_changed_files_summary`: Get list of changed files (lightweight) +- `get_file_diff_for_release`: Get patch for a specific file +- `search_local_git_repo`: Search for patterns in repos +- `read_local_git_repo_file_content`: Read file contents +- `get_all_recommendations`: Retrieve recommendations from previous analysis + +# 5. Repository Info +- Code repo: {CODE_OWNER}/{CODE_REPO} at {LOCAL_REPOS_DIR_PATH}/{CODE_REPO} +- Docs repo: {DOC_OWNER}/{DOC_REPO} at {LOCAL_REPOS_DIR_PATH}/{DOC_REPO} +""", tools=[ - list_releases, - get_changed_files_between_releases, clone_or_pull_repo, - list_directory_contents, + list_releases, + get_changed_files_summary, + get_file_diff_for_release, search_local_git_repo, read_local_git_repo_file_content, - create_issue, + get_all_recommendations, ], + sub_agents=[analysis_pipeline], ) diff --git a/contributing/samples/adk_documentation/tools.py b/contributing/samples/adk_documentation/tools.py index bc3b8d8c42..c6fd4c2f4d 100644 --- a/contributing/samples/adk_documentation/tools.py +++ b/contributing/samples/adk_documentation/tools.py @@ -548,3 +548,114 @@ def _git_grep( check=False, # Don't raise error on non-zero exit code (1 means no match) ) return grep_process + + +def get_file_diff_for_release( + repo_owner: str, + repo_name: str, + start_tag: str, + end_tag: str, + file_path: str, +) -> Dict[str, Any]: + """Gets the diff/patch for a specific file between two release tags. + + This is useful for incremental processing where you want to analyze + one file at a time instead of loading all changes at once. + + Args: + repo_owner: The name of the repository owner. + repo_name: The name of the repository. + start_tag: The older tag (base) for the comparison. + end_tag: The newer tag (head) for the comparison. + file_path: The relative path of the file to get the diff for. + + Returns: + A dictionary containing the status and the file diff details. + """ + url = f"{GITHUB_BASE_URL}/repos/{repo_owner}/{repo_name}/compare/{start_tag}...{end_tag}" + + try: + comparison_data = get_request(url) + changed_files = comparison_data.get("files", []) + + for file_data in changed_files: + if file_data.get("filename") == file_path: + return { + "status": "success", + "file": { + "relative_path": file_data.get("filename"), + "status": file_data.get("status"), + "additions": file_data.get("additions"), + "deletions": file_data.get("deletions"), + "changes": file_data.get("changes"), + "patch": file_data.get("patch", "No patch available."), + }, + } + + return error_response(f"File {file_path} not found in the comparison.") + except requests.exceptions.HTTPError as e: + return error_response(f"HTTP Error: {e}") + except requests.exceptions.RequestException as e: + return error_response(f"Request Error: {e}") + + +def get_changed_files_summary( + repo_owner: str, repo_name: str, start_tag: str, end_tag: str +) -> Dict[str, Any]: + """Gets a summary of changed files between two releases without patches. + + This is a lighter-weight version of get_changed_files_between_releases + that only returns file paths and metadata, without the actual diff content. + Use this for planning which files to analyze. + + Args: + repo_owner: The name of the repository owner. + repo_name: The name of the repository. + start_tag: The older tag (base) for the comparison. + end_tag: The newer tag (head) for the comparison. + + Returns: + A dictionary containing the status and a summary of changed files. + """ + url = f"{GITHUB_BASE_URL}/repos/{repo_owner}/{repo_name}/compare/{start_tag}...{end_tag}" + + try: + comparison_data = get_request(url) + changed_files = comparison_data.get("files", []) + + # Group files by directory for easier processing + files_by_dir: Dict[str, List[Dict[str, Any]]] = {} + formatted_files = [] + + for file_data in changed_files: + file_info = { + "relative_path": file_data.get("filename"), + "status": file_data.get("status"), + "additions": file_data.get("additions"), + "deletions": file_data.get("deletions"), + "changes": file_data.get("changes"), + } + formatted_files.append(file_info) + + # Group by top-level directory + path = file_data.get("filename", "") + parts = path.split("/") + top_dir = parts[0] if parts else "root" + if top_dir not in files_by_dir: + files_by_dir[top_dir] = [] + files_by_dir[top_dir].append(file_info) + + return { + "status": "success", + "total_files": len(formatted_files), + "files": formatted_files, + "files_by_directory": files_by_dir, + "compare_url": ( + f"https://github.com/{repo_owner}/{repo_name}" + f"/compare/{start_tag}...{end_tag}" + ), + } + except requests.exceptions.HTTPError as e: + return error_response(f"HTTP Error: {e}") + except requests.exceptions.RequestException as e: + return error_response(f"Request Error: {e}") diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index 960b6f40c2..f6e3bb66f9 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -24,11 +24,11 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: 5. `get_job_info` Fetches metadata about a BigQuery job. -5. `execute_sql` +6. `execute_sql` Runs or dry-runs a SQL query in BigQuery. -6. `ask_data_insights` +7. `ask_data_insights` Natural language-in, natural language-out tool that answers questions about structured data in BigQuery. Provides a one-stop solution for generating @@ -38,18 +38,18 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: the official [Conversational Analytics API documentation](https://cloud.google.com/gemini/docs/conversational-analytics-api/overview) for instructions. -7. `forecast` +8. `forecast` Perform time series forecasting using BigQuery's `AI.FORECAST` function, leveraging the TimesFM 2.0 model. -8. `analyze_contribution` +9. `analyze_contribution` Perform contribution analysis in BigQuery by creating a temporary `CONTRIBUTION_ANALYSIS` model and then querying it with `ML.GET_INSIGHTS` to find top contributors for a given metric. -9. `detect_anomalies` +10. `detect_anomalies` Perform time series anomaly detection in BigQuery by creating a temporary `ARIMA_PLUS` model and then querying it with diff --git a/contributing/samples/bigquery_mcp/README.md b/contributing/samples/bigquery_mcp/README.md new file mode 100644 index 0000000000..bce19976ca --- /dev/null +++ b/contributing/samples/bigquery_mcp/README.md @@ -0,0 +1,55 @@ +# BigQuery MCP Toolset Sample + +## Introduction + +This sample agent demonstrates using ADK's `McpToolset` to interact with +BigQuery's official MCP endpoint, allowing an agent to access and execute +toole by leveraging the Model Context Protocol (MCP). These tools include: + + +1. `list_dataset_ids` + + Fetches BigQuery dataset ids present in a GCP project. + +2. `get_dataset_info` + + Fetches metadata about a BigQuery dataset. + +3. `list_table_ids` + + Fetches table ids present in a BigQuery dataset. + +4. `get_table_info` + + Fetches metadata about a BigQuery table. + +5. `execute_sql` + + Runs or dry-runs a SQL query in BigQuery. + +## How to use + +Set up your project and local authentication by following the guide +[Use the BigQuery remote MCP server](https://docs.cloud.google.com/bigquery/docs/use-bigquery-mcp). +This agent uses Application Default Credentials (ADC) to authenticate with the +BigQuery MCP endpoint. + +Set up environment variables in your `.env` file for using +[Google AI Studio](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-ai-studio) +or +[Google Cloud Vertex AI](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai) +for the LLM service for your agent. For example, for using Google AI Studio you +would set: + +* GOOGLE_GENAI_USE_VERTEXAI=FALSE +* GOOGLE_API_KEY={your api key} + +Then run the agent using `adk run .` or `adk web .` in this directory. + +## Sample prompts + +* which weather datasets exist in bigquery public data? +* tell me more about noaa_lightning +* which tables exist in the ml_datasets dataset? +* show more details about the penguins table +* compute penguins population per island. diff --git a/contributing/samples/bigquery_mcp/__init__.py b/contributing/samples/bigquery_mcp/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/bigquery_mcp/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/bigquery_mcp/agent.py b/contributing/samples/bigquery_mcp/agent.py new file mode 100644 index 0000000000..4116bc6cf4 --- /dev/null +++ b/contributing/samples/bigquery_mcp/agent.py @@ -0,0 +1,51 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +import google.auth + +BIGQUERY_AGENT_NAME = "adk_sample_bigquery_mcp_agent" +BIGQUERY_MCP_ENDPOINT = "https://bigquery.googleapis.com/mcp" +BIGQUERY_SCOPE = "https://www.googleapis.com/auth/bigquery" + +# Initialize the tools to use the application default credentials. +# https://cloud.google.com/docs/authentication/provide-credentials-adc +credentials, project_id = google.auth.default(scopes=[BIGQUERY_SCOPE]) +credentials.refresh(google.auth.transport.requests.Request()) +oauth_token = credentials.token + +bigquery_mcp_toolset = McpToolset( + connection_params=StreamableHTTPConnectionParams( + url=BIGQUERY_MCP_ENDPOINT, + headers={"Authorization": f"Bearer {oauth_token}"}, + ) +) + +# The variable name `root_agent` determines what your root agent is for the +# debug CLI +root_agent = LlmAgent( + model="gemini-2.5-flash", + name=BIGQUERY_AGENT_NAME, + description=( + "Agent to answer questions about BigQuery data and models and execute" + " SQL queries using MCP." + ), + instruction="""\ + You are a data science agent with access to several BigQuery tools provided via MCP. + Make use of those tools to answer the user's questions. + """, + tools=[bigquery_mcp_toolset], +) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index dfe6f4a0a2..21428b6381 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -40,6 +40,9 @@ A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response' A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = 'code_execution_result' A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = 'executable_code' +A2A_DATA_PART_TEXT_MIME_TYPE = 'text/plain' +A2A_DATA_PART_START_TAG = b'' +A2A_DATA_PART_END_TAG = b'' A2APartToGenAIPartConverter = Callable[ @@ -130,7 +133,16 @@ def convert_a2a_part_to_genai_part( part.data, by_alias=True ) ) - return genai_types.Part(text=json.dumps(part.data)) + return genai_types.Part( + inline_data=genai_types.Blob( + data=A2A_DATA_PART_START_TAG + + part.model_dump_json(by_alias=True, exclude_none=True).encode( + 'utf-8' + ) + + A2A_DATA_PART_END_TAG, + mime_type=A2A_DATA_PART_TEXT_MIME_TYPE, + ) + ) logger.warning( 'Cannot convert unsupported part type: %s for A2A part: %s', @@ -163,6 +175,20 @@ def convert_genai_part_to_a2a_part( ) if part.inline_data: + if ( + part.inline_data.mime_type == A2A_DATA_PART_TEXT_MIME_TYPE + and part.inline_data.data is not None + and part.inline_data.data.startswith(A2A_DATA_PART_START_TAG) + and part.inline_data.data.endswith(A2A_DATA_PART_END_TAG) + ): + return a2a_types.Part( + root=a2a_types.DataPart.model_validate_json( + part.inline_data.data[ + len(A2A_DATA_PART_START_TAG) : -len(A2A_DATA_PART_END_TAG) + ] + ) + ) + # The default case for inline_data is to convert it to FileWithBytes. a2a_part = a2a_types.FilePart( file=a2a_types.FileWithBytes( bytes=base64.b64encode(part.inline_data.data).decode('utf-8'), diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 91b4a07b5d..5d7611f217 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -1291,7 +1291,6 @@ async def _lifespan(app: FastAPI): host=host, port=port, reload=reload, - log_level=log_level.lower(), ) server = uvicorn.Server(config) @@ -1368,7 +1367,6 @@ def cli_api_server( host=host, port=port, reload=reload, - log_level=log_level.lower(), ) server = uvicorn.Server(config) server.run() diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 327157e2a6..158a5cabc1 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -20,6 +20,7 @@ from google.genai import types +from ..utils.content_utils import filter_audio_parts from ..utils.context_utils import Aclosing from ..utils.variant_utils import GoogleLLMVariant from .base_llm_connection import BaseLlmConnection @@ -63,15 +64,22 @@ async def send_history(self, history: list[types.Content]): # TODO: Remove this filter and translate unary contents to streaming # contents properly. - # We ignore any audio from user during the agent transfer phase + # Filter out audio parts from history because: + # 1. audio has already been transcribed. + # 2. sending audio via connection.send or connection.send_live_content is + # not supported by LIVE API (session will be corrupted). + # This method is called when: + # 1. Agent transfer to a new agent + # 2. Establishing a new live connection with previous ADK session history + contents = [ - content + filtered for content in history - if content.parts and content.parts[0].text + if (filtered := filter_audio_parts(content)) is not None ] - logger.debug('Sending history to live connection: %s', contents) if contents: + logger.debug('Sending history to live connection: %s', contents) await self._gemini_session.send( input=types.LiveClientContent( turns=contents, diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index c38f854c93..c243f56a6a 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -378,6 +378,13 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: types.Part.from_text(text=llm_request.config.system_instruction) ], ) + + logger.info( + 'Trying to connect to live model: %s with api backend: %s', + llm_request.model, + self._api_backend, + ) + if ( llm_request.live_connect_config.session_resumption and llm_request.live_connect_config.session_resumption.transparent @@ -386,17 +393,13 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: 'session resumption config: %s', llm_request.live_connect_config.session_resumption, ) - logger.debug( - 'self._api_backend: %s', - self._api_backend, - ) + if self._api_backend == GoogleLLMVariant.GEMINI_API: raise ValueError( 'Transparent session resumption is only supported for Vertex AI' ' backend. Please use Vertex AI backend.' ) llm_request.live_connect_config.tools = llm_request.config.tools - logger.info('Connecting to live for model: %s', llm_request.model) logger.debug('Connecting to live with llm_request:%s', llm_request) logger.debug('Live connect config: %s', llm_request.live_connect_config) async with self._live_api_client.aio.live.connect( diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 384d76da88..f6705c1de9 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -110,6 +110,18 @@ ) +def _map_finish_reason( + finish_reason: Any, +) -> types.FinishReason | None: + """Maps a LiteLLM finish_reason value to a google-genai FinishReason enum.""" + if not finish_reason: + return None + if isinstance(finish_reason, types.FinishReason): + return finish_reason + finish_reason_str = str(finish_reason).lower() + return _FINISH_REASON_MAPPING.get(finish_reason_str, types.FinishReason.OTHER) + + def _get_provider_from_model(model: str) -> str: """Extracts the provider name from a LiteLLM model string. @@ -1840,6 +1852,9 @@ async def generate_content_async( else None, ) ) + aggregated_llm_response_with_tool_call.finish_reason = ( + _map_finish_reason(finish_reason) + ) text = "" reasoning_parts = [] function_calls.clear() @@ -1854,6 +1869,9 @@ async def generate_content_async( if reasoning_parts else None, ) + aggregated_llm_response.finish_reason = _map_finish_reason( + finish_reason + ) text = "" reasoning_parts = [] diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py index 92df88718a..2b00c79917 100644 --- a/src/google/adk/tools/_automatic_function_calling_util.py +++ b/src/google/adk/tools/_automatic_function_calling_util.py @@ -14,12 +14,15 @@ from __future__ import annotations +import collections.abc import inspect from types import FunctionType import typing from typing import Any from typing import Callable from typing import Dict +from typing import get_args +from typing import get_origin from typing import Optional from typing import Union @@ -391,6 +394,20 @@ def from_function_with_options( return_annotation = inspect.signature(func).return_annotation + # Handle AsyncGenerator and Generator return types (streaming tools) + # AsyncGenerator[YieldType, SendType] -> use YieldType as response schema + # Generator[YieldType, SendType, ReturnType] -> use YieldType as response schema + origin = get_origin(return_annotation) + if origin is not None and ( + origin is collections.abc.AsyncGenerator + or origin is collections.abc.Generator + ): + type_args = get_args(return_annotation) + if type_args: + # First type argument is the yield type + yield_type = type_args[0] + return_annotation = yield_type + # Handle functions with no return annotation if return_annotation is inspect._empty: # Functions with no return annotation can return any type diff --git a/src/google/adk/tools/base_authenticated_tool.py b/src/google/adk/tools/base_authenticated_tool.py index 862d1cef5a..92e395d4ac 100644 --- a/src/google/adk/tools/base_authenticated_tool.py +++ b/src/google/adk/tools/base_authenticated_tool.py @@ -66,6 +66,7 @@ def __init__( name=name, description=description, ) + self._auth_config = auth_config if auth_config and auth_config.auth_scheme: self._credentials_manager = CredentialManager(auth_config=auth_config) diff --git a/src/google/adk/tools/load_artifacts_tool.py b/src/google/adk/tools/load_artifacts_tool.py index 0e91380517..dbdc1f26f2 100644 --- a/src/google/adk/tools/load_artifacts_tool.py +++ b/src/google/adk/tools/load_artifacts_tool.py @@ -14,6 +14,8 @@ from __future__ import annotations +import base64 +import binascii import json import logging from typing import Any @@ -24,6 +26,19 @@ from .base_tool import BaseTool +# MIME types Gemini accepts for inline data in requests. +_GEMINI_SUPPORTED_INLINE_MIME_PREFIXES = ( + 'image/', + 'audio/', + 'video/', +) +_GEMINI_SUPPORTED_INLINE_MIME_TYPES = frozenset({'application/pdf'}) +_TEXT_LIKE_MIME_TYPES = frozenset({ + 'application/csv', + 'application/json', + 'application/xml', +}) + if TYPE_CHECKING: from ..models.llm_request import LlmRequest from .tool_context import ToolContext @@ -31,6 +46,79 @@ logger = logging.getLogger('google_adk.' + __name__) +def _normalize_mime_type(mime_type: str | None) -> str | None: + """Returns the normalized MIME type, without parameters like charset.""" + if not mime_type: + return None + return mime_type.split(';', 1)[0].strip() + + +def _is_inline_mime_type_supported(mime_type: str | None) -> bool: + """Returns True if Gemini accepts this MIME type as inline data.""" + normalized = _normalize_mime_type(mime_type) + if not normalized: + return False + return normalized.startswith(_GEMINI_SUPPORTED_INLINE_MIME_PREFIXES) or ( + normalized in _GEMINI_SUPPORTED_INLINE_MIME_TYPES + ) + + +def _maybe_base64_to_bytes(data: str) -> bytes | None: + """Best-effort base64 decode for both std and urlsafe formats.""" + try: + return base64.b64decode(data, validate=True) + except (binascii.Error, ValueError): + try: + return base64.urlsafe_b64decode(data) + except (binascii.Error, ValueError): + return None + + +def _as_safe_part_for_llm( + artifact: types.Part, artifact_name: str +) -> types.Part: + """Returns a Part that is safe to send to Gemini.""" + inline_data = artifact.inline_data + if inline_data is None: + return artifact + + if _is_inline_mime_type_supported(inline_data.mime_type): + return artifact + + mime_type = _normalize_mime_type(inline_data.mime_type) or ( + 'application/octet-stream' + ) + data = inline_data.data + if data is None: + return types.Part.from_text( + text=( + f'[Artifact: {artifact_name}, type: {mime_type}. ' + 'No inline data was provided.]' + ) + ) + + if isinstance(data, str): + decoded = _maybe_base64_to_bytes(data) + if decoded is None: + return types.Part.from_text(text=data) + data = decoded + + if mime_type.startswith('text/') or mime_type in _TEXT_LIKE_MIME_TYPES: + try: + return types.Part.from_text(text=data.decode('utf-8')) + except UnicodeDecodeError: + return types.Part.from_text(text=data.decode('utf-8', errors='replace')) + + size_kb = len(data) / 1024 + return types.Part.from_text( + text=( + f'[Binary artifact: {artifact_name}, ' + f'type: {mime_type}, size: {size_kb:.1f} KB. ' + 'Content cannot be displayed inline.]' + ) + ) + + class LoadArtifactsTool(BaseTool): """A tool that loads the artifacts and adds them to the session.""" @@ -108,7 +196,8 @@ async def _append_artifacts_to_llm_request( if llm_request.contents and llm_request.contents[-1].parts: function_response = llm_request.contents[-1].parts[0].function_response if function_response and function_response.name == 'load_artifacts': - artifact_names = function_response.response['artifact_names'] + response = function_response.response or {} + artifact_names = response.get('artifact_names', []) for artifact_name in artifact_names: # Try session-scoped first (default behavior) artifact = await tool_context.load_artifact(artifact_name) @@ -122,6 +211,18 @@ async def _append_artifacts_to_llm_request( if artifact is None: logger.warning('Artifact "%s" not found, skipping', artifact_name) continue + + artifact_part = _as_safe_part_for_llm(artifact, artifact_name) + if artifact_part is not artifact: + mime_type = ( + artifact.inline_data.mime_type if artifact.inline_data else None + ) + logger.debug( + 'Converted artifact "%s" (mime_type=%s) to text Part', + artifact_name, + mime_type, + ) + llm_request.contents.append( types.Content( role='user', @@ -129,7 +230,7 @@ async def _append_artifacts_to_llm_request( types.Part.from_text( text=f'Artifact {artifact_name} is:' ), - artifact, + artifact_part, ], ) ) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 89f0145727..ebd91dc354 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -41,6 +41,8 @@ from pydantic import BaseModel from pydantic import ConfigDict +from .session_context import SessionContext + logger = logging.getLogger('google_adk.' + __name__) @@ -385,29 +387,27 @@ async def create_session( if hasattr(self._connection_params, 'timeout') else None ) + sse_read_timeout_in_seconds = ( + self._connection_params.sse_read_timeout + if hasattr(self._connection_params, 'sse_read_timeout') + else None + ) try: client = self._create_client(merged_headers) - - transports = await asyncio.wait_for( - exit_stack.enter_async_context(client), + is_stdio = isinstance(self._connection_params, StdioConnectionParams) + + session = await asyncio.wait_for( + exit_stack.enter_async_context( + SessionContext( + client=client, + timeout=timeout_in_seconds, + sse_read_timeout=sse_read_timeout_in_seconds, + is_stdio=is_stdio, + ) + ), timeout=timeout_in_seconds, ) - # The streamable http client returns a GetSessionCallback in addition to the - # read/write MemoryObjectStreams needed to build the ClientSession, we limit - # then to the two first values to be compatible with all clients. - if isinstance(self._connection_params, StdioConnectionParams): - session = await exit_stack.enter_async_context( - ClientSession( - *transports[:2], - read_timeout_seconds=timedelta(seconds=timeout_in_seconds), - ) - ) - else: - session = await exit_stack.enter_async_context( - ClientSession(*transports[:2]) - ) - await asyncio.wait_for(session.initialize(), timeout=timeout_in_seconds) # Store session and exit stack in the pool self._sessions[session_key] = (session, exit_stack) diff --git a/src/google/adk/tools/mcp_tool/session_context.py b/src/google/adk/tools/mcp_tool/session_context.py new file mode 100644 index 0000000000..ca637d0489 --- /dev/null +++ b/src/google/adk/tools/mcp_tool/session_context.py @@ -0,0 +1,194 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import AsyncExitStack +from datetime import timedelta +import logging +from typing import AsyncContextManager +from typing import Optional + +from mcp import ClientSession + +logger = logging.getLogger('google_adk.' + __name__) + + +class SessionContext: + """Represents the context of a single MCP session within a dedicated task. + + AnyIO's TaskGroup/CancelScope requires that the start and end of a scope + occur within the same task. Since MCP clients use AnyIO internally, we need + to ensure that the client's entire lifecycle (creation, usage, and cleanup) + happens within a single dedicated task. + + This class spawns a background task that: + 1. Enters the MCP client's async context and initializes the session + 2. Signals readiness via an asyncio.Event + 3. Waits for a close signal + 4. Cleans up the client within the same task + + This ensures CancelScope constraints are satisfied regardless of which + task calls start() or close(). + + Can be used in two ways: + 1. Direct method calls: start() and close() + 2. As an async context manager: async with lifecycle as session: ... + """ + + def __init__( + self, + client: AsyncContextManager, + timeout: Optional[float], + sse_read_timeout: Optional[float], + is_stdio: bool = False, + ): + """ + Args: + client: An MCP client context manager (e.g., from streamablehttp_client, + sse_client, or stdio_client). + timeout: Timeout in seconds for connection and initialization. + sse_read_timeout: Timeout in seconds for reading data from the MCP SSE + server. + is_stdio: Whether this is a stdio connection (affects read timeout). + """ + self._client = client + self._timeout = timeout + self._sse_read_timeout = sse_read_timeout + self._is_stdio = is_stdio + self._session: Optional[ClientSession] = None + self._ready_event = asyncio.Event() + self._close_event = asyncio.Event() + self._task: Optional[asyncio.Task] = None + self._task_lock = asyncio.Lock() + + @property + def session(self) -> Optional[ClientSession]: + """Get the managed ClientSession, if available.""" + return self._session + + async def start(self) -> ClientSession: + """Start the runner and wait for the session to be ready. + + Returns: + The initialized ClientSession. + + Raises: + ConnectionError: If session creation fails. + """ + async with self._task_lock: + if self._session: + logger.debug( + 'Session has already been created, returning existing session' + ) + return self._session + + if self._close_event.is_set(): + raise ConnectionError( + 'Failed to create MCP session: session already closed' + ) + + if not self._task: + self._task = asyncio.create_task(self._run()) + + await self._ready_event.wait() + + if self._task.cancelled(): + raise ConnectionError('Failed to create MCP session: task cancelled') + + if self._task.done() and self._task.exception(): + raise ConnectionError( + f'Failed to create MCP session: {self._task.exception()}' + ) from self._task.exception() + + return self._session + + async def close(self): + """Signal the context task to close and wait for cleanup.""" + # Set the close event to signal the task to close. + # Even if start has not been called, we need to set the close event + # to signal the task to close right away. + async with self._task_lock: + self._close_event.set() + + # If start has not been called, only set the close event and return + if not self._task: + return + + if not self._ready_event.is_set(): + self._task.cancel() + + try: + await asyncio.wait_for(self._task, timeout=self._timeout) + except asyncio.TimeoutError: + logger.warning('Failed to close MCP session: task timed out') + self._task.cancel() + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning(f'Failed to close MCP session: {e}') + + async def __aenter__(self) -> ClientSession: + return await self.start() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def _run(self): + """Run the complete session context within a single task.""" + try: + async with AsyncExitStack() as exit_stack: + transports = await asyncio.wait_for( + exit_stack.enter_async_context(self._client), + timeout=self._timeout, + ) + # The streamable http client returns a GetSessionCallback in addition + # to the read/write MemoryObjectStreams needed to build the + # ClientSession. We limit to the first two values to be compatible + # with all clients. + if self._is_stdio: + session = await exit_stack.enter_async_context( + ClientSession( + *transports[:2], + read_timeout_seconds=timedelta(seconds=self._timeout) + if self._timeout is not None + else None, + ) + ) + else: + # For SSE and Streamable HTTP clients, use the sse_read_timeout + # instead of the connection timeout as the read_timeout for the session. + session = await exit_stack.enter_async_context( + ClientSession( + *transports[:2], + read_timeout_seconds=timedelta(seconds=self._sse_read_timeout) + if self._sse_read_timeout is not None + else None, + ) + ) + await asyncio.wait_for(session.initialize(), timeout=self._timeout) + logger.debug('Session has been successfully initialized') + + self._session = session + self._ready_event.set() + + # Wait for close signal - the session remains valid while we wait + await self._close_event.wait() + except BaseException as e: + logger.warning(f'Error on session runner task: {e}') + raise + finally: + self._ready_event.set() + self._close_event.set() diff --git a/src/google/adk/utils/content_utils.py b/src/google/adk/utils/content_utils.py new file mode 100644 index 0000000000..379c31ec96 --- /dev/null +++ b/src/google/adk/utils/content_utils.py @@ -0,0 +1,38 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.genai import types + + +def is_audio_part(part: types.Part) -> bool: + return ( + part.inline_data + and part.inline_data.mime_type + and part.inline_data.mime_type.startswith('audio/') + ) or ( + part.file_data + and part.file_data.mime_type + and part.file_data.mime_type.startswith('audio/') + ) + + +def filter_audio_parts(content: types.Content) -> types.Content | None: + if not content.parts: + return None + filtered_parts = [part for part in content.parts if not is_audio_part(part)] + if not filtered_parts: + return None + return types.Content(role=content.role, parts=filtered_parts) diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 541ab7709d..00c9ddc5e0 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -17,11 +17,14 @@ from unittest.mock import patch from a2a import types as a2a_types +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_END_TAG from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_START_TAG +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_TEXT_MIME_TYPE from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part from google.adk.a2a.converters.utils import _get_adk_metadata_key @@ -154,12 +157,43 @@ def test_convert_data_part_function_response(self): "data": [1, 2, 3], } - def test_convert_data_part_without_special_metadata(self): - """Test conversion of A2A DataPart without special metadata to text.""" + @pytest.mark.parametrize( + "test_name, data, metadata", + [ + ( + "without_special_metadata", + {"key": "value", "number": 123}, + {"other": "metadata"}, + ), + ( + "no_metadata", + {"key": "value", "array": [1, 2, 3]}, + None, + ), + ( + "complex_data", + { + "nested": { + "array": [1, 2, {"inner": "value"}], + "boolean": True, + "null_value": None, + }, + "unicode": "Hello 世界 🌍", + }, + None, + ), + ( + "empty_metadata", + {"key": "value"}, + {}, + ), + ], + ) + def test_convert_data_part_to_inline_data(self, test_name, data, metadata): + """Test conversion of A2A DataPart to GenAI inline_data Part.""" # Arrange - data = {"key": "value", "number": 123} a2a_part = a2a_types.Part( - root=a2a_types.DataPart(data=data, metadata={"other": "metadata"}) + root=a2a_types.DataPart(data=data, metadata=metadata) ) # Act @@ -168,21 +202,17 @@ def test_convert_data_part_without_special_metadata(self): # Assert assert result is not None assert isinstance(result, genai_types.Part) - assert result.text == json.dumps(data) - - def test_convert_data_part_no_metadata(self): - """Test conversion of A2A DataPart with no metadata to text.""" - # Arrange - data = {"key": "value", "array": [1, 2, 3]} - a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data)) - - # Act - result = convert_a2a_part_to_genai_part(a2a_part) - - # Assert - assert result is not None - assert isinstance(result, genai_types.Part) - assert result.text == json.dumps(data) + assert result.inline_data is not None + assert result.inline_data.mime_type == A2A_DATA_PART_TEXT_MIME_TYPE + assert result.inline_data.data.startswith(A2A_DATA_PART_START_TAG) + assert result.inline_data.data.endswith(A2A_DATA_PART_END_TAG) + converted_data_part = a2a_types.DataPart.model_validate_json( + result.inline_data.data[ + len(A2A_DATA_PART_START_TAG) : -len(A2A_DATA_PART_END_TAG) + ] + ) + assert converted_data_part.data == data + assert converted_data_part.metadata == metadata def test_convert_unsupported_file_type(self): """Test handling of unsupported file types.""" @@ -325,6 +355,32 @@ def test_convert_inline_data_part_with_video_metadata(self): assert result.root.metadata is not None assert _get_adk_metadata_key("video_metadata") in result.root.metadata + def test_convert_inline_data_part_to_data_part(self): + """Test conversion of GenAI inline_data Part to A2A DataPart.""" + # Arrange + data = {"key": "value"} + metadata = {"meta": "data"} + a2a_part_to_convert = a2a_types.DataPart(data=data, metadata=metadata) + json_data = a2a_part_to_convert.model_dump_json( + by_alias=True, exclude_none=True + ).encode("utf-8") + genai_part = genai_types.Part( + inline_data=genai_types.Blob( + data=A2A_DATA_PART_START_TAG + json_data + A2A_DATA_PART_END_TAG, + mime_type=A2A_DATA_PART_TEXT_MIME_TYPE, + ) + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + assert result.root.data == data + assert result.root.metadata == metadata + def test_convert_function_call_part(self): """Test conversion of GenAI function_call Part to A2A Part.""" # Arrange @@ -596,6 +652,47 @@ def test_executable_code_round_trip(self): ) assert result_genai_part.executable_code.code == executable_code.code + def test_data_part_round_trip(self): + """Test round-trip conversion for data parts.""" + # Arrange + data = {"key": "value"} + metadata = {"meta": "data"} + a2a_part = a2a_types.Part( + root=a2a_types.DataPart(data=data, metadata=metadata) + ) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.DataPart) + assert result_a2a_part.root.data == data + assert result_a2a_part.root.metadata == metadata + + def test_data_part_with_mime_type_metadata_round_trip(self): + """Test round-trip conversion for data parts with 'mime_type' in metadata.""" + # Arrange + data = {"content": "some data"} + metadata = {"meta": "data", "mime_type": "application/json"} + a2a_part = a2a_types.Part( + root=a2a_types.DataPart(data=data, metadata=metadata) + ) + + # Act + genai_part = convert_a2a_part_to_genai_part(a2a_part) + result_a2a_part = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result_a2a_part is not None + assert isinstance(result_a2a_part, a2a_types.Part) + assert isinstance(result_a2a_part.root, a2a_types.DataPart) + assert result_a2a_part.root.data == data + # The 'mime_type' key in the metadata should be preserved as is + assert result_a2a_part.root.metadata == metadata + class TestEdgeCases: """Test cases for edge cases and error conditions.""" @@ -612,6 +709,37 @@ def test_empty_text_part(self): assert result is not None assert result.text == "" + def test_genai_inline_data_with_mimetype_to_a2a(self): + """Test conversion of GenAI inline_data with 'mimeType' in DataPart metadata to A2A. + + This tests if 'mimeType' in metadata of a DataPart wrapped in inline_data + is correctly handled, ensuring the key casing is preserved. + """ + # Arrange + data = {"key": "value"} + metadata = {"adk_type": "some_type", "mimeType": "image/png"} + a2a_part_inner = a2a_types.DataPart(data=data, metadata=metadata) + json_data = a2a_part_inner.model_dump_json( + by_alias=True, exclude_none=True + ).encode("utf-8") + genai_part = genai_types.Part( + inline_data=genai_types.Blob( + data=A2A_DATA_PART_START_TAG + json_data + A2A_DATA_PART_END_TAG, + mime_type=A2A_DATA_PART_TEXT_MIME_TYPE, + ) + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result, a2a_types.Part) + assert isinstance(result.root, a2a_types.DataPart) + assert result.root.data == data + # The key casing should be preserved from the JSON + assert result.root.metadata == metadata + def test_none_input_a2a_to_genai(self): """Test handling of None input for A2A to GenAI conversion.""" # This test depends on how the function handles None input @@ -626,39 +754,6 @@ def test_none_input_genai_to_a2a(self): with pytest.raises(AttributeError): convert_genai_part_to_a2a_part(None) - def test_data_part_with_complex_data(self): - """Test conversion of DataPart with complex nested data.""" - # Arrange - complex_data = { - "nested": { - "array": [1, 2, {"inner": "value"}], - "boolean": True, - "null_value": None, - }, - "unicode": "Hello 世界 🌍", - } - a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=complex_data)) - - # Act - result = convert_a2a_part_to_genai_part(a2a_part) - - # Assert - assert result is not None - assert result.text == json.dumps(complex_data) - - def test_data_part_with_empty_metadata(self): - """Test conversion of DataPart with empty metadata dict.""" - # Arrange - data = {"key": "value"} - a2a_part = a2a_types.Part(root=a2a_types.DataPart(data=data, metadata={})) - - # Act - result = convert_a2a_part_to_genai_part(a2a_part) - - # Assert - assert result is not None - assert result.text == json.dumps(data) - class TestNewConstants: """Test cases for new constants and functionality.""" diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index de8f4f9dad..ac65b2ac2a 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -600,3 +600,177 @@ async def mock_receive_generator(): assert responses[2].output_transcription.text == 'How can I help?' assert responses[2].output_transcription.finished is True assert responses[2].partial is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'audio_part', + [ + types.Part( + inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm') + ), + types.Part( + file_data=types.FileData( + file_uri='artifact://app/user/session/_adk_live/audio.pcm#1', + mime_type='audio/pcm', + ) + ), + ], +) +async def test_send_history_filters_audio(mock_gemini_session, audio_part): + """Test that audio parts (inline or file_data) are filtered out.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[audio_part], + ), + types.Content( + role='model', parts=[types.Part.from_text(text='I heard you')] + ), + ] + + await connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + sent_contents = call_args['input'].turns + # Only the model response should be sent (user audio filtered out) + assert len(sent_contents) == 1 + assert sent_contents[0].role == 'model' + assert sent_contents[0].parts == [types.Part.from_text(text='I heard you')] + + +@pytest.mark.asyncio +async def test_send_history_keeps_image_data(mock_gemini_session): + """Test that image data is NOT filtered out.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + image_blob = types.Blob(data=b'\x89PNG\r\n', mime_type='image/png') + history = [ + types.Content( + role='user', + parts=[types.Part(inline_data=image_blob)], + ), + types.Content( + role='model', parts=[types.Part.from_text(text='Nice image!')] + ), + ] + + await connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + sent_contents = call_args['input'].turns + # Both contents should be sent (image is not filtered) + assert len(sent_contents) == 2 + assert sent_contents[0].parts[0].inline_data == image_blob + + +@pytest.mark.asyncio +async def test_send_history_mixed_content_filters_only_audio( + mock_gemini_session, +): + """Test that mixed content keeps non-audio parts.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[ + types.Part( + inline_data=types.Blob( + data=b'\x00\xFF', mime_type='audio/wav' + ) + ), + types.Part.from_text(text='transcribed text'), + ], + ), + ] + + await connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + sent_contents = call_args['input'].turns + # Content should be sent but only with the text part + assert len(sent_contents) == 1 + assert len(sent_contents[0].parts) == 1 + assert sent_contents[0].parts[0].text == 'transcribed text' + + +@pytest.mark.asyncio +async def test_send_history_all_audio_content_not_sent(mock_gemini_session): + """Test that content with only audio parts is completely removed.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[ + types.Part( + inline_data=types.Blob( + data=b'\x00\xFF', mime_type='audio/pcm' + ) + ), + types.Part( + file_data=types.FileData( + file_uri='artifact://audio.pcm#1', + mime_type='audio/wav', + ) + ), + ], + ), + ] + + await connection.send_history(history) + + # No content should be sent since all parts are audio + mock_gemini_session.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_send_history_empty_history_not_sent(mock_gemini_session): + """Test that empty history does not call send.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + + await connection.send_history([]) + + mock_gemini_session.send.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'audio_mime_type', + ['audio/pcm', 'audio/wav', 'audio/mp3', 'audio/ogg'], +) +async def test_send_history_filters_various_audio_mime_types( + mock_gemini_session, + audio_mime_type, +): + """Test that various audio mime types are all filtered.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[ + types.Part( + inline_data=types.Blob(data=b'', mime_type=audio_mime_type) + ) + ], + ), + ] + + await connection.send_history(history) + + # No content should be sent since the only part is audio + mock_gemini_session.send.assert_not_called() diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index c687ceb0cb..f6428087b0 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -2880,6 +2880,7 @@ async def test_generate_content_async_stream( "test_arg": "test_value" } assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" + assert responses[3].finish_reason == types.FinishReason.STOP assert responses[3].model_version == "test_model" mock_completion.assert_called_once() @@ -2900,6 +2901,55 @@ async def test_generate_content_async_stream( ) +@pytest.mark.asyncio +async def test_generate_content_async_stream_sets_finish_reason( + mock_completion, lite_llm_instance +): + mock_completion.return_value = iter([ + ModelResponse( + model="test_model", + choices=[ + StreamingChoices( + finish_reason=None, + delta=Delta(role="assistant", content="Hello "), + ) + ], + ), + ModelResponse( + model="test_model", + choices=[ + StreamingChoices( + finish_reason=None, + delta=Delta(role="assistant", content="world"), + ) + ], + ), + ModelResponse( + model="test_model", + choices=[StreamingChoices(finish_reason="stop", delta=Delta())], + ), + ]) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ) + ], + ) + + responses = [ + response + async for response in lite_llm_instance.generate_content_async( + llm_request, stream=True + ) + ] + + assert responses[-1].partial is False + assert responses[-1].finish_reason == types.FinishReason.STOP + assert responses[-1].content.parts[0].text == "Hello world" + + @pytest.mark.asyncio async def test_generate_content_async_stream_with_usage_metadata( mock_completion, lite_llm_instance @@ -2944,6 +2994,7 @@ async def test_generate_content_async_stream_with_usage_metadata( "test_arg": "test_value" } assert responses[3].content.parts[-1].function_call.id == "test_tool_call_id" + assert responses[3].finish_reason == types.FinishReason.STOP assert responses[3].usage_metadata.prompt_token_count == 10 assert responses[3].usage_metadata.candidates_token_count == 5 diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 7e18d9d457..ae91bed13d 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -56,6 +56,33 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): pass +class MockSessionContext: + """Mock SessionContext for testing.""" + + def __init__(self, session=None): + """Initialize MockSessionContext. + + Args: + session: The mock session to return from __aenter__ and session property. + """ + self._session = session + self._aenter_mock = AsyncMock(return_value=session) + self._aexit_mock = AsyncMock(return_value=False) + + @property + def session(self): + """Get the mock session.""" + return self._session + + async def __aenter__(self): + """Enter the async context manager.""" + return await self._aenter_mock() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the async context manager.""" + return await self._aexit_mock(exc_type, exc_val, exc_tb) + + class TestMCPSessionManager: """Test suite for MCPSessionManager class.""" @@ -241,7 +268,6 @@ async def test_create_session_stdio_new(self): """Test creating a new stdio session.""" manager = MCPSessionManager(self.mock_stdio_connection_params) - mock_session = MockClientSession() mock_exit_stack = MockAsyncExitStack() with patch( @@ -251,17 +277,19 @@ async def test_create_session_stdio_new(self): "google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack" ) as mock_exit_stack_class: with patch( - "google.adk.tools.mcp_tool.mcp_session_manager.ClientSession" - ) as mock_session_class: + "google.adk.tools.mcp_tool.mcp_session_manager.SessionContext" + ) as mock_session_context_class: # Setup mocks mock_exit_stack_class.return_value = mock_exit_stack mock_stdio.return_value = AsyncMock() - mock_exit_stack.enter_async_context.side_effect = [ - ("read", "write"), # First call returns transports - mock_session, # Second call returns session - ] - mock_session_class.return_value = mock_session + + # Mock SessionContext using MockSessionContext + # Create a mock session that will be returned by SessionContext + mock_session = AsyncMock() + mock_session_context = MockSessionContext(session=mock_session) + mock_session_context_class.return_value = mock_session_context + mock_exit_stack.enter_async_context.return_value = mock_session # Create session session = await manager.create_session() @@ -271,8 +299,10 @@ async def test_create_session_stdio_new(self): assert len(manager._sessions) == 1 assert "stdio_session" in manager._sessions - # Verify session was initialized - mock_session.initialize.assert_called_once() + # Verify SessionContext was created + mock_session_context_class.assert_called_once() + # Verify enter_async_context was called (which internally calls __aenter__) + mock_exit_stack.enter_async_context.assert_called_once() @pytest.mark.asyncio async def test_create_session_reuse_existing(self): @@ -300,39 +330,37 @@ async def test_create_session_reuse_existing(self): @pytest.mark.asyncio @patch("google.adk.tools.mcp_tool.mcp_session_manager.stdio_client") @patch("google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack") - @patch("google.adk.tools.mcp_tool.mcp_session_manager.ClientSession") + @patch("google.adk.tools.mcp_tool.mcp_session_manager.SessionContext") async def test_create_session_timeout( - self, mock_session_class, mock_exit_stack_class, mock_stdio + self, mock_session_context_class, mock_exit_stack_class, mock_stdio ): """Test session creation timeout.""" manager = MCPSessionManager(self.mock_stdio_connection_params) - mock_session = MockClientSession() mock_exit_stack = MockAsyncExitStack() mock_exit_stack_class.return_value = mock_exit_stack mock_stdio.return_value = AsyncMock() - mock_exit_stack.enter_async_context.side_effect = [ - ("read", "write"), # First call returns transports - mock_session, # Second call returns session - ] - mock_session_class.return_value = mock_session - # Simulate timeout during session initialization - mock_session.initialize.side_effect = asyncio.TimeoutError("Test timeout") + # Mock SessionContext + mock_session_context = AsyncMock() + mock_session_context.__aenter__ = AsyncMock( + return_value=MockClientSession() + ) + mock_session_context.__aexit__ = AsyncMock(return_value=False) + mock_session_context_class.return_value = mock_session_context + + # Mock enter_async_context to raise TimeoutError (simulating asyncio.wait_for timeout) + mock_exit_stack.enter_async_context = AsyncMock( + side_effect=asyncio.TimeoutError("Test timeout") + ) # Expect ConnectionError due to timeout with pytest.raises(ConnectionError, match="Failed to create MCP session"): await manager.create_session() - # Verify ClientSession called with timeout - mock_session_class.assert_called_with( - "read", - "write", - read_timeout_seconds=timedelta( - seconds=manager._connection_params.timeout - ), - ) + # Verify SessionContext was created + mock_session_context_class.assert_called_once() # Verify session was not added to pool assert not manager._sessions # Verify cleanup was called @@ -390,6 +418,36 @@ async def test_close_with_errors(self): assert "Warning: Error during MCP session cleanup" in error_output assert "Close error 1" in error_output + @pytest.mark.asyncio + @patch("google.adk.tools.mcp_tool.mcp_session_manager.stdio_client") + @patch("google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack") + @patch("google.adk.tools.mcp_tool.mcp_session_manager.SessionContext") + async def test_create_and_close_session_in_different_tasks( + self, mock_session_context_class, mock_exit_stack_class, mock_stdio + ): + """Test creating and closing a session in different tasks.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + mock_exit_stack_class.return_value = MockAsyncExitStack() + mock_stdio.return_value = AsyncMock() + + # Mock SessionContext + mock_session_context = AsyncMock() + mock_session_context.__aenter__ = AsyncMock( + return_value=MockClientSession() + ) + mock_session_context.__aexit__ = AsyncMock(return_value=False) + mock_session_context_class.return_value = mock_session_context + + # Create session in a new task + await asyncio.create_task(manager.create_session()) + + # Close session in another task + await asyncio.create_task(manager.close()) + + # Verify session was closed + assert not manager._sessions + @pytest.mark.asyncio async def test_retry_on_errors_decorator(): diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 5809efe56f..f6d002ed17 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -17,6 +17,7 @@ import sys import unittest from unittest.mock import AsyncMock +from unittest.mock import MagicMock from unittest.mock import Mock from unittest.mock import patch @@ -28,6 +29,7 @@ from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset from mcp import StdioServerParameters import pytest @@ -302,3 +304,52 @@ async def test_get_tools_retry_decorator(self): # Check that the method has the retry decorator assert hasattr(toolset.get_tools, "__wrapped__") + + @pytest.mark.asyncio + async def test_mcp_toolset_with_prefix(self): + """Test that McpToolset correctly applies the tool_name_prefix.""" + # Mock the connection parameters + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + + # Mock the MCPSessionManager and its create_session method + mock_session_manager = MagicMock() + mock_session = MagicMock() + + # Mock the list_tools response from the MCP server + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "tool 1 desc" + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = "tool 2 desc" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool1, mock_tool2] + mock_session.list_tools = AsyncMock(return_value=list_tools_result) + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + + # Create an instance of McpToolset with a prefix + toolset = McpToolset( + connection_params=mock_connection_params, + tool_name_prefix="my_prefix", + ) + + # Replace the internal session manager with our mock + toolset._mcp_session_manager = mock_session_manager + + # Get the tools from the toolset + tools = await toolset.get_tools() + + # The get_tools method in McpToolset returns MCPTool objects, which are + # instances of BaseTool. The prefixing is handled by the BaseToolset, + # so we need to call get_tools_with_prefix to get the prefixed tools. + prefixed_tools = await toolset.get_tools_with_prefix() + + # Assert that the tools are prefixed correctly + assert len(prefixed_tools) == 2 + assert prefixed_tools[0].name == "my_prefix_tool1" + assert prefixed_tools[1].name == "my_prefix_tool2" + + # Assert that the original tools are not modified + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" diff --git a/tests/unittests/tools/mcp_tool/test_session_context.py b/tests/unittests/tools/mcp_tool/test_session_context.py new file mode 100644 index 0000000000..161cd1aba3 --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_session_context.py @@ -0,0 +1,550 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import AsyncExitStack +from datetime import timedelta +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.tools.mcp_tool.session_context import SessionContext +from mcp import ClientSession +import pytest + + +class MockClientSession: + """Mock ClientSession for testing.""" + + def __init__(self, *args, **kwargs): + self._initialized = False + self._args = args + self._kwargs = kwargs + + async def initialize(self): + """Mock initialize method.""" + self._initialized = True + return self + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + +class MockClient: + """Mock MCP client.""" + + def __init__( + self, + transports=None, + raise_on_enter=None, + delay_on_enter=0, + ): + self._transports = transports or ('read_stream', 'write_stream') + self._raise_on_enter = raise_on_enter + self._delay_on_enter = delay_on_enter + self._entered = False + self._exited = False + + async def __aenter__(self): + if self._delay_on_enter > 0: + await asyncio.sleep(self._delay_on_enter) + if self._raise_on_enter: + raise self._raise_on_enter + self._entered = True + return self._transports + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._exited = True + return False + + +class TestSessionContext: + """Test suite for SessionContext class.""" + + @pytest.mark.asyncio + async def test_start_success_ready_event_set_and_session_returned(self): + """Test that start() sets _ready_event and returns session.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Mock ClientSession + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + session = await session_context.start() + + # Verify ready_event was set + assert session_context._ready_event.is_set() + + # Verify session was returned + assert session == mock_session + assert session_context.session == mock_session + + # Verify initialize was called + assert mock_session._initialized + + # Verify task was created and is still running (waiting for close) + assert session_context._task is not None + assert not session_context._task.done() + + # Clean up + await session_context.close() + + @pytest.mark.asyncio + async def test_start_raises_connection_error_on_exception(self): + """Test that start() raises ConnectionError when exception occurs.""" + test_exception = ValueError('Connection failed') + mock_client = MockClient(raise_on_enter=test_exception) + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + # Verify ConnectionError message contains original exception + assert 'Failed to create MCP session' in str(exc_info.value) + assert 'Connection failed' in str(exc_info.value) + + # Verify ready_event was set (in finally block) + assert session_context._ready_event.is_set() + + @pytest.mark.asyncio + async def test_start_raises_connection_error_on_cancelled_error(self): + """Test that start() raises ConnectionError when CancelledError occurs.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Mock session that will cause cancellation + mock_session = MockClientSession() + + # Make initialize raise CancelledError + async def cancelled_initialize(): + raise asyncio.CancelledError('Task cancelled') + + mock_session.initialize = cancelled_initialize + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + # Should raise ConnectionError (not CancelledError directly) + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + # Verify it's a ConnectionError about cancellation + assert 'Failed to create MCP session' in str(exc_info.value) + assert 'task cancelled' in str(exc_info.value) + + # Verify ready_event was set + assert session_context._ready_event.is_set() + + @pytest.mark.asyncio + async def test_close_cleans_up_task(self): + """Test that close() properly cleans up the task.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Mock ClientSession + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + # Start the session context + await session_context.start() + + # Verify task is running + assert session_context._task is not None + assert not session_context._task.done() + + # Close the session context + await session_context.close() + + # Wait a bit for cleanup + await asyncio.sleep(0.1) + + # Verify close_event was set + assert session_context._close_event.is_set() + + # Verify task completed (may take a moment) + # The task should finish after close_event is set + assert session_context._task.done() + + @pytest.mark.asyncio + async def test_session_exception_does_not_break_event_loop(self): + """Test that session exceptions don't break the event loop.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Mock ClientSession that raises exception during use + mock_session = MockClientSession() + + async def failing_operation(): + raise RuntimeError('Session operation failed') + + mock_session.failing_operation = failing_operation + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + # Start the session context + session = await session_context.start() + + # Use session and trigger exception + with pytest.raises(RuntimeError, match='Session operation failed'): + await session.failing_operation() + + # Close the session context - should not break event loop + await session_context.close() + + # Verify event loop is still healthy by running another task + result = await asyncio.sleep(0.01) + assert result is None + + @pytest.mark.asyncio + async def test_async_context_manager(self): + """Test using SessionContext as async context manager.""" + mock_client = MockClient() + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + async with SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) as session: + assert session == mock_session + # Verify initialize was called by checking _initialized flag + assert session._initialized + + @pytest.mark.asyncio + async def test_timeout_during_connection(self): + """Test timeout during client connection.""" + # Client that takes longer than timeout + mock_client = MockClient(delay_on_enter=10.0) + session_context = SessionContext( + mock_client, timeout=0.1, sse_read_timeout=None + ) + + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + assert 'Failed to create MCP session' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_timeout_during_initialization(self): + """Test timeout during session initialization.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=0.1, sse_read_timeout=None + ) + + # Mock ClientSession with slow initialize + mock_session = MockClientSession() + + async def slow_initialize(): + await asyncio.sleep(1.0) + return mock_session + + mock_session.initialize = slow_initialize + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + assert 'Failed to create MCP session' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_stdio_client_with_read_timeout(self): + """Test stdio client includes read_timeout_seconds parameter.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None, is_stdio=True + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Verify ClientSession was called with read_timeout_seconds for stdio + call_args = mock_session_class.call_args + assert 'read_timeout_seconds' in call_args.kwargs + assert call_args.kwargs['read_timeout_seconds'] == timedelta(seconds=5.0) + + await session_context.close() + + @pytest.mark.asyncio + async def test_non_stdio_client_without_read_timeout(self): + """Test non-stdio client does not include read_timeout_seconds.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None, is_stdio=False + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Verify ClientSession was called with read_timeout_seconds=None for non-stdio + # when sse_read_timeout is None + call_args = mock_session_class.call_args + assert 'read_timeout_seconds' in call_args.kwargs + assert call_args.kwargs['read_timeout_seconds'] is None + + await session_context.close() + + @pytest.mark.asyncio + async def test_sse_read_timeout_passed_to_client_session(self): + """Test that sse_read_timeout is passed to ClientSession for non-stdio.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=300.0, is_stdio=False + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Verify ClientSession was called with sse_read_timeout + call_args = mock_session_class.call_args + assert 'read_timeout_seconds' in call_args.kwargs + assert call_args.kwargs['read_timeout_seconds'] == timedelta( + seconds=300.0 + ) + + await session_context.close() + + @pytest.mark.asyncio + async def test_close_multiple_times(self): + """Test that close() can be called multiple times safely.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Close multiple times + await session_context.close() + await session_context.close() + await session_context.close() + + # Should not raise exception + assert session_context._close_event.is_set() + + @pytest.mark.asyncio + async def test_close_before_start(self): + """Test that close() works even if start() was never called.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Close before starting should not raise + await session_context.close() + + assert session_context._close_event.is_set() + + @pytest.mark.asyncio + async def test_close_before_start_ends(self): + """Test that close() before start() ends the task.""" + # Client has enough time to delay the start task + mock_client = MockClient(delay_on_enter=10.0) + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + start_task = asyncio.create_task(session_context.start()) + await asyncio.sleep(0.1) + assert not start_task.done() + + # Call close before start() ends the task + await session_context.close() + await asyncio.sleep(0.1) + + assert start_task.done() + assert isinstance( + start_task.exception(), ConnectionError + ) and 'task cancelled' in str(start_task.exception()) + + @pytest.mark.asyncio + async def test_close_before_start_called(self): + """Test that close() before start() called sets the close event.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Call close() before start() called + await session_context.close() + await asyncio.sleep(0.1) + + assert session_context._task is None + assert session_context._close_event.is_set() + + with pytest.raises(ConnectionError) as exc_info: + await session_context.start() + + assert 'session already closed' in str(exc_info.value) + assert session_context._task is None + + @pytest.mark.asyncio + async def test_session_property(self): + """Test that session property returns the managed session.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Initially None + assert session_context.session is None + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Should return the session + assert session_context.session == mock_session + + await session_context.close() + + @pytest.mark.asyncio + async def test_client_cleanup_on_exception(self): + """Test that client is properly cleaned up even when exception occurs.""" + test_exception = RuntimeError('Test error') + mock_client = MockClient(raise_on_enter=test_exception) + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + with pytest.raises(ConnectionError): + await session_context.start() + + # Wait a bit for cleanup + await asyncio.sleep(0.1) + + # Verify task completed + assert session_context._task.done() + + @pytest.mark.asyncio + async def test_close_handles_cancelled_error(self): + """Test that close() handles CancelledError gracefully.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + await session_context.start() + + # Cancel the task + if session_context._task: + session_context._task.cancel() + + # Close should handle CancelledError gracefully + await session_context.close() + + # Should not raise exception + assert session_context._close_event.is_set() + + @pytest.mark.asyncio + async def test_close_handles_exception_during_cleanup(self): + """Test that close() handles exceptions during cleanup gracefully.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + # Create a mock session that raises during exit + class FailingMockSession(MockClientSession): + + async def __aexit__(self, exc_type, exc_val, exc_tb): + raise RuntimeError('Cleanup failed') + + failing_session = FailingMockSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = failing_session + + await session_context.start() + + # Close should handle the exception gracefully + await session_context.close() + + # Should not raise exception + assert session_context._close_event.is_set() diff --git a/tests/unittests/tools/test_base_authenticated_tool.py b/tests/unittests/tools/test_base_authenticated_tool.py index 55454224d8..5f7bf53f7d 100644 --- a/tests/unittests/tools/test_base_authenticated_tool.py +++ b/tests/unittests/tools/test_base_authenticated_tool.py @@ -90,6 +90,7 @@ def test_init_with_auth_config(self): assert tool.description == "Test description" assert tool._credentials_manager is not None assert tool._response_for_auth_required == unauthenticated_response + assert tool._auth_config == auth_config def test_init_with_no_auth_config(self): """Test initialization without auth_config.""" @@ -99,6 +100,7 @@ def test_init_with_no_auth_config(self): assert tool.description == "Test authenticated tool" assert tool._credentials_manager is None assert tool._response_for_auth_required is None + assert tool._auth_config is None def test_init_with_empty_auth_scheme(self): """Test initialization with auth_config but no auth_scheme.""" diff --git a/tests/unittests/tools/test_from_function_with_options.py b/tests/unittests/tools/test_from_function_with_options.py index 61670a2678..eae164538f 100644 --- a/tests/unittests/tools/test_from_function_with_options.py +++ b/tests/unittests/tools/test_from_function_with_options.py @@ -14,7 +14,9 @@ from collections.abc import Sequence from typing import Any +from typing import AsyncGenerator from typing import Dict +from typing import Generator from google.adk.tools import _automatic_function_calling_util from google.adk.utils.variant_utils import GoogleLLMVariant @@ -242,3 +244,78 @@ def test_function( assert declaration.name == 'test_function' assert declaration.response.type == types.Type.ARRAY assert declaration.response.items.type == types.Type.STRING + + +def test_from_function_with_async_generator_return_vertex(): + """Test from_function_with_options with AsyncGenerator return for VERTEX_AI.""" + + async def test_function(param: str) -> AsyncGenerator[str, None]: + """A streaming function that yields strings.""" + yield param + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # VERTEX_AI should extract yield type (str) from AsyncGenerator[str, None] + assert declaration.response is not None + assert declaration.response.type == types.Type.STRING + + +def test_from_function_with_async_generator_return_gemini(): + """Test from_function_with_options with AsyncGenerator return for GEMINI_API.""" + + async def test_function(param: str) -> AsyncGenerator[str, None]: + """A streaming function that yields strings.""" + yield param + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.GEMINI_API + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # GEMINI_API should not have response schema + assert declaration.response is None + + +def test_from_function_with_generator_return_vertex(): + """Test from_function_with_options with Generator return for VERTEX_AI.""" + + def test_function(param: str) -> Generator[int, None, None]: + """A streaming function that yields integers.""" + yield 42 + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # VERTEX_AI should extract yield type (int) from Generator[int, None, None] + assert declaration.response is not None + assert declaration.response.type == types.Type.INTEGER + + +def test_from_function_with_async_generator_complex_yield_type_vertex(): + """Test from_function_with_options with AsyncGenerator yielding dict.""" + + async def test_function(param: str) -> AsyncGenerator[Dict[str, str], None]: + """A streaming function that yields dicts.""" + yield {'result': param} + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # VERTEX_AI should extract yield type (Dict[str, str]) from AsyncGenerator + assert declaration.response is not None + assert declaration.response.type == types.Type.OBJECT diff --git a/tests/unittests/tools/test_load_artifacts_tool.py b/tests/unittests/tools/test_load_artifacts_tool.py new file mode 100644 index 0000000000..1ea50bb33c --- /dev/null +++ b/tests/unittests/tools/test_load_artifacts_tool.py @@ -0,0 +1,162 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 + +from google.adk.models.llm_request import LlmRequest +from google.adk.tools.load_artifacts_tool import _maybe_base64_to_bytes +from google.adk.tools.load_artifacts_tool import load_artifacts_tool +from google.genai import types +from pytest import mark + + +class _StubToolContext: + """Minimal ToolContext stub for LoadArtifactsTool tests.""" + + def __init__(self, artifacts_by_name: dict[str, types.Part]): + self._artifacts_by_name = artifacts_by_name + + async def list_artifacts(self) -> list[str]: + return list(self._artifacts_by_name.keys()) + + async def load_artifact(self, name: str) -> types.Part | None: + return self._artifacts_by_name.get(name) + + +@mark.asyncio +async def test_load_artifacts_converts_unsupported_mime_to_text(): + """Unsupported inline MIME types are converted to text parts.""" + artifact_name = 'test.csv' + csv_bytes = b'col1,col2\n1,2\n' + artifact = types.Part( + inline_data=types.Blob(data=csv_bytes, mime_type='application/csv') + ) + + tool_context = _StubToolContext({artifact_name: artifact}) + llm_request = LlmRequest( + contents=[ + types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='load_artifacts', + response={'artifact_names': [artifact_name]}, + ) + ) + ], + ) + ] + ) + + await load_artifacts_tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.contents[-1].parts[0].text == ( + f'Artifact {artifact_name} is:' + ) + artifact_part = llm_request.contents[-1].parts[1] + assert artifact_part.inline_data is None + assert artifact_part.text == csv_bytes.decode('utf-8') + + +@mark.asyncio +async def test_load_artifacts_converts_base64_unsupported_mime_to_text(): + """Unsupported base64 string data is converted to text parts.""" + artifact_name = 'test.csv' + csv_bytes = b'col1,col2\n1,2\n' + csv_base64 = base64.b64encode(csv_bytes).decode('ascii') + artifact = types.Part( + inline_data=types.Blob(data=csv_base64, mime_type='application/csv') + ) + + tool_context = _StubToolContext({artifact_name: artifact}) + llm_request = LlmRequest( + contents=[ + types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='load_artifacts', + response={'artifact_names': [artifact_name]}, + ) + ) + ], + ) + ] + ) + + await load_artifacts_tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + artifact_part = llm_request.contents[-1].parts[1] + assert artifact_part.inline_data is None + assert artifact_part.text == csv_bytes.decode('utf-8') + + +@mark.asyncio +async def test_load_artifacts_keeps_supported_mime_types(): + """Supported inline MIME types are passed through unchanged.""" + artifact_name = 'test.pdf' + artifact = types.Part( + inline_data=types.Blob(data=b'%PDF-1.4', mime_type='application/pdf') + ) + + tool_context = _StubToolContext({artifact_name: artifact}) + llm_request = LlmRequest( + contents=[ + types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='load_artifacts', + response={'artifact_names': [artifact_name]}, + ) + ) + ], + ) + ] + ) + + await load_artifacts_tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + artifact_part = llm_request.contents[-1].parts[1] + assert artifact_part.inline_data is not None + assert artifact_part.inline_data.mime_type == 'application/pdf' + + +def test_maybe_base64_to_bytes_decodes_standard_base64(): + """Standard base64 encoded strings are decoded correctly.""" + original = b'hello world' + encoded = base64.b64encode(original).decode('ascii') + assert _maybe_base64_to_bytes(encoded) == original + + +def test_maybe_base64_to_bytes_decodes_urlsafe_base64(): + """URL-safe base64 encoded strings are decoded correctly.""" + original = b'\xfb\xff\xfe' # bytes that produce +/ in std but -_ in urlsafe + encoded = base64.urlsafe_b64encode(original).decode('ascii') + assert _maybe_base64_to_bytes(encoded) == original + + +def test_maybe_base64_to_bytes_returns_none_for_invalid(): + """Invalid base64 strings return None.""" + # Single character is invalid (base64 requires length % 4 == 0 after padding) + assert _maybe_base64_to_bytes('x') is None diff --git a/tests/unittests/tools/test_mcp_toolset.py b/tests/unittests/tools/test_mcp_toolset.py deleted file mode 100644 index 7bfd912669..0000000000 --- a/tests/unittests/tools/test_mcp_toolset.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for McpToolset.""" - -from unittest.mock import AsyncMock -from unittest.mock import MagicMock - -from google.adk.tools.mcp_tool.mcp_toolset import McpToolset -import pytest - - -@pytest.mark.asyncio -async def test_mcp_toolset_with_prefix(): - """Test that McpToolset correctly applies the tool_name_prefix.""" - # Mock the connection parameters - mock_connection_params = MagicMock() - mock_connection_params.timeout = None - - # Mock the MCPSessionManager and its create_session method - mock_session_manager = MagicMock() - mock_session = MagicMock() - - # Mock the list_tools response from the MCP server - mock_tool1 = MagicMock() - mock_tool1.name = "tool1" - mock_tool1.description = "tool 1 desc" - mock_tool2 = MagicMock() - mock_tool2.name = "tool2" - mock_tool2.description = "tool 2 desc" - list_tools_result = MagicMock() - list_tools_result.tools = [mock_tool1, mock_tool2] - mock_session.list_tools = AsyncMock(return_value=list_tools_result) - mock_session_manager.create_session = AsyncMock(return_value=mock_session) - - # Create an instance of McpToolset with a prefix - toolset = McpToolset( - connection_params=mock_connection_params, - tool_name_prefix="my_prefix", - ) - - # Replace the internal session manager with our mock - toolset._mcp_session_manager = mock_session_manager - - # Get the tools from the toolset - tools = await toolset.get_tools() - - # The get_tools method in McpToolset returns MCPTool objects, which are - # instances of BaseTool. The prefixing is handled by the BaseToolset, - # so we need to call get_tools_with_prefix to get the prefixed tools. - prefixed_tools = await toolset.get_tools_with_prefix() - - # Assert that the tools are prefixed correctly - assert len(prefixed_tools) == 2 - assert prefixed_tools[0].name == "my_prefix_tool1" - assert prefixed_tools[1].name == "my_prefix_tool2" - - # Assert that the original tools are not modified - assert tools[0].name == "tool1" - assert tools[1].name == "tool2"