From f0769777222cfe5b88aeaea9ebca6cdc7dd78ba4 Mon Sep 17 00:00:00 2001 From: Afonso Menegola Date: Wed, 26 Nov 2025 14:35:05 -0300 Subject: [PATCH 01/10] feat: Refactor BQ Analytics Plugin to use Structured JSON --- .../bigquery_agent_analytics_plugin.py | 667 ++++++++---------- .../test_bigquery_agent_analytics_plugin.py | 613 +++++++--------- 2 files changed, 573 insertions(+), 707 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 63b95e57ea..fc5d010281 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -85,7 +85,7 @@ def _pyarrow_timestamp(): "GEOGRAPHY": pa.string, "INT64": pa.int64, "INTEGER": pa.int64, - "JSON": pa.string, + "JSON": pa.string, # JSON is passed as string to Arrow "NUMERIC": _pyarrow_numeric, "BIGNUMERIC": _pyarrow_bignumeric, "STRING": pa.string, @@ -221,13 +221,39 @@ class BigQueryLoggerConfig: enabled: bool = True event_allowlist: Optional[List[str]] = None event_denylist: Optional[List[str]] = None + # Custom formatter is discouraged now that we use JSON, but kept for compat content_formatter: Optional[Callable[[Any], str]] = None shutdown_timeout: float = 5.0 client_close_timeout: float = 2.0 - max_content_length: int = 500 + # Increased default limit to 50KB since we truncate per-field, not per-row + max_content_length: int = 50000 + + +def _recursive_smart_truncate(obj: Any, max_len: int) -> Any: + """Recursively truncates string values within a dict or list.""" + if isinstance(obj, str): + if len(obj) > max_len: + return obj[:max_len] + "...[TRUNCATED]" + return obj + elif isinstance(obj, dict): + return {k: _recursive_smart_truncate(v, max_len) for k, v in obj.items()} + elif isinstance(obj, list): + return [_recursive_smart_truncate(i, max_len) for i in obj] + else: + return obj + + +def _serialize_to_json_safe(content_obj: Any, max_len: int) -> str: + """Safely serializes an object to a JSON string with smart truncation.""" + try: + truncated_obj = _recursive_smart_truncate(content_obj, max_len) + # default=str handles datetime or other non-serializable types by converting to string + return json.dumps(truncated_obj, default=str) + except Exception as e: + logging.warning(f"JSON serialization failed: {e}") + return json.dumps({"error": "Serialization failed", "details": str(e)}) -# --- Helper Formatters --- def _get_event_type(event: Event) -> str: """Determines the event type from an Event object.""" if event.author == "user": @@ -243,66 +269,8 @@ def _get_event_type(event: Event) -> str: return "SYSTEM" -def _format_content( - content: Optional[types.Content], max_len: int = 500 -) -> tuple[str, bool]: - """Formats an Event content for logging. - - Args: - content: The Event content to format. - max_len: The maximum length of the text parts before truncation. - - Returns: - A tuple containing the formatted content string and a boolean indicating if - the content was truncated. 
- """ - if not content or not content.parts: - return "None", False - parts = [] - for p in content.parts: - if p.text: - parts.append( - f"text: '{p.text[:max_len]}...' " - if len(p.text) > max_len - else f"text: '{p.text}'" - ) - elif p.function_call: - parts.append(f"call: {p.function_call.name}") - elif p.function_response: - parts.append(f"resp: {p.function_response.name}") - else: - parts.append("other") - return " | ".join(parts), any( - len(p.text) > max_len for p in content.parts if p.text - ) - - -def _format_args( - args: dict[str, Any], *, max_len: int = 1000 -) -> tuple[str, bool]: - """Formats tool arguments or results for logging. - - Args: - args: The tool arguments or results dictionary to format. - max_len: The maximum length of the output string before truncation. - - Returns: - A tuple containing the JSON formatted string and a boolean indicating if - the content was truncated. - """ - if not args: - return "{}", False - try: - s = json.dumps(args) - except TypeError: - s = str(args) - if len(s) > max_len: - return s[:max_len] + "...", True - return s, False - - class BigQueryAgentAnalyticsPlugin(BasePlugin): - """A plugin that logs agent analytic events to Google BigQuery. + """A plugin that logs agent analytic events to Google BigQuery (Structured JSON). This plugin captures key events during an agent's lifecycle—such as user interactions, tool executions, LLM requests/responses, and errors—and @@ -345,6 +313,8 @@ def __init__( self._arrow_schema: pa.Schema | None = None self._background_tasks: set[asyncio.Task] = set() self._is_shutting_down = False + + # --- Updated Schema: Content is now JSON --- self._schema = [ bigquery.SchemaField( "timestamp", @@ -356,90 +326,47 @@ def __init__( "event_type", "STRING", mode="NULLABLE", - description=( - "Indicates the type of event being logged (e.g., 'LLM_REQUEST'," - " 'TOOL_COMPLETED')." - ), + description="Indicates the type of event (e.g., 'LLM_REQUEST').", ), bigquery.SchemaField( "agent", "STRING", mode="NULLABLE", - description=( - "The name of the ADK agent or author associated with the event." - ), + description="The name of the ADK agent.", ), bigquery.SchemaField( "session_id", "STRING", mode="NULLABLE", - description=( - "A unique identifier to group events within a single" - " conversation or user session." - ), + description="Unique identifier for the session.", ), bigquery.SchemaField( "invocation_id", "STRING", mode="NULLABLE", - description=( - "A unique identifier for each individual agent execution or" - " turn within a session." - ), + description="Unique identifier for the invocation/turn.", ), bigquery.SchemaField( "user_id", "STRING", mode="NULLABLE", - description=( - "The identifier of the user associated with the current" - " session." - ), + description="The user identifier.", ), + # CHANGED: STRING -> JSON bigquery.SchemaField( "content", - "STRING", + "JSON", mode="NULLABLE", - description=( - "The event-specific data (payload). Format varies by" - " event_type." - ), + description="Structured event payload.", ), bigquery.SchemaField( "error_message", "STRING", mode="NULLABLE", - description=( - "Populated if an error occurs during the processing of the" - " event." - ), - ), - bigquery.SchemaField( - "is_truncated", - "BOOLEAN", - mode="NULLABLE", - description=( - "Indicates if the content field was truncated due to size" - " limits." 
- ), + description="Error details if applicable.", ), ] - def _format_content_safely( - self, content: Optional[types.Content] - ) -> tuple[str | None, bool]: - """Formats content using self._config.content_formatter or _format_content, catching errors.""" - if content is None: - return None, False - try: - if self._config.content_formatter: - # Custom formatter: we assume no truncation or we can't know. - return self._config.content_formatter(content), False - return _format_content(content, max_len=self._config.max_content_length) - except Exception as e: - logging.warning("Content formatter failed: %s", e) - return "[FORMATTING FAILED]", False - async def _ensure_init(self): """Ensures BigQuery clients are initialized.""" if self._write_client: @@ -461,7 +388,6 @@ async def _ensure_init(self): project=self._project_id, credentials=creds, client_info=client_info ) - # Ensure table exists (sync call in thread) def create_resources(): if self._bq_client: self._bq_client.create_dataset(self._dataset_id, exists_ok=True) @@ -489,14 +415,13 @@ def create_resources(): self._arrow_schema = to_arrow_schema(self._schema) if not self._arrow_schema: raise RuntimeError("Failed to convert BigQuery schema to Arrow.") - logging.info("BQ Plugin: Initialized successfully.") return True except Exception as e: logging.error("BQ Plugin: Init Failed:", exc_info=True) return False async def _perform_write(self, row: dict): - """Actual async write operation, intended to run as a background task.""" + """Actual async write operation.""" try: if ( not await self._ensure_init() @@ -505,7 +430,6 @@ async def _perform_write(self, row: dict): ): return - # Serialize pydict = {f.name: [row.get(f.name)] for f in self._arrow_schema} batch = pa.RecordBatch.from_pydict(pydict, schema=self._arrow_schema) req = bq_storage_types.AppendRowsRequest( @@ -518,22 +442,16 @@ async def _perform_write(self, row: dict): batch.serialize().to_pybytes() ) - # Write with protection against immediate cancellation async for resp in await asyncio.shield( self._write_client.append_rows(iter([req])) ): if resp.error.code != 0: msg = resp.error.message - # Check for common schema mismatch indicators - if ( - "schema mismatch" in msg.lower() - or "field" in msg.lower() - or "type" in msg.lower() - ): + if "schema mismatch" in msg.lower(): logging.error( - "BQ Plugin: Schema Mismatch Error. The BigQuery table schema" - " may be incorrect or out of sync with the plugin. Please" - " verify the table definition. Details: %s", + "BQ Plugin: Schema Mismatch. You may need to delete the" + " existing table if you migrated from STRING content to JSON" + " content. Details: %s", msg, ) else: @@ -545,13 +463,19 @@ async def _perform_write(self, row: dict): except asyncio.CancelledError: if not self._is_shutting_down: logging.warning("BQ Plugin: Write task cancelled unexpectedly.") - except Exception as e: + except Exception: logging.error("BQ Plugin: Write Failed:", exc_info=True) - async def _log(self, data: dict): - """Schedules a log entry to be written in the background.""" + async def _log(self, data: dict, content_payload: Any = None): + """ + Schedules a log entry. + Args: + data: Metadata dict (event_type, agent, etc.) + content_payload: The structured data to be JSON serialized. 
+ """ if not self._config.enabled: return + event_type = data.get("event_type") if ( self._config.event_denylist @@ -564,7 +488,24 @@ async def _log(self, data: dict): ): return - # Prepare row immediately (capture current state) + # If a custom formatter/redactor is provided, let it modify the payload + # BEFORE we truncate and serialize it. + if self._config.content_formatter and content_payload is not None: + try: + # The formatter now receives a Dict and should return a Dict + content_payload = self._config.content_formatter(content_payload) + except Exception as e: + logging.warning(f"Content formatter failed: {e}") + # Fallback: keep original payload but log the error + + # Prepare payload + content_json_str = None + if content_payload is not None: + # Use smart truncation to keep JSON valid but safe size + content_json_str = _serialize_to_json_safe( + content_payload, self._config.max_content_length + ) + row = { "timestamp": datetime.now(timezone.utc), "event_type": None, @@ -572,13 +513,11 @@ async def _log(self, data: dict): "session_id": None, "invocation_id": None, "user_id": None, - "content": None, + "content": content_json_str, # Injected here "error_message": None, - "is_truncated": False, } row.update(data) - # Fire and forget: Create task and track it task = asyncio.create_task(self._perform_write(row)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) @@ -603,7 +542,6 @@ async def close(self): except Exception as e: logging.warning("BQ Plugin: Error flushing logs:", exc_info=True) - # Use getattr for safe access in case transport is not present. if self._write_client and getattr(self._write_client, "transport", None): try: logging.info("BQ Plugin: Closing write client.") @@ -613,6 +551,7 @@ async def close(self): ) except Exception as e: logging.warning("BQ Plugin: Error closing write client: %s", e) + pass if self._bq_client: try: self._bq_client.close() @@ -624,7 +563,8 @@ async def close(self): self._is_shutting_down = False logging.info("BQ Plugin: Shutdown complete.") - # --- Streamlined Callbacks --- + # --- Refactored Callbacks using Structured Data --- + async def on_user_message_callback( self, *, @@ -636,19 +576,27 @@ async def on_user_message_callback( Logs the user message details including: 1. User content (text) - The content is formatted as 'User Content: {content}'. - If the content length exceeds `max_content_length`, it is truncated. + The content is formatted as a structured JSON object containing the user text. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. 
""" - content, truncated = self._format_content_safely(user_message) - await self._log({ - "event_type": "USER_MESSAGE_RECEIVED", - "agent": invocation_context.agent.name, - "session_id": invocation_context.session.id, - "invocation_id": invocation_context.invocation_id, - "user_id": invocation_context.session.user_id, - "content": f"User Content: {content}", - "is_truncated": truncated, - }) + # Extract text parts + text_content = "" + if user_message and user_message.parts: + text_content = " ".join([p.text for p in user_message.parts if p.text]) + + payload = {"text": text_content} + + await self._log( + { + "event_type": "USER_MESSAGE_RECEIVED", + "agent": invocation_context.agent.name, + "session_id": invocation_context.session.id, + "invocation_id": invocation_context.invocation_id, + "user_id": invocation_context.session.user_id, + }, + content_payload=payload, + ) async def before_run_callback( self, *, invocation_context: InvocationContext @@ -665,7 +613,7 @@ async def before_run_callback( "session_id": invocation_context.session.id, "invocation_id": invocation_context.invocation_id, "user_id": invocation_context.session.user_id, - }) + }) # No content payload needed async def on_event_callback( self, *, invocation_context: InvocationContext, event: Event @@ -677,21 +625,42 @@ async def on_event_callback( 2. Event content (text, function calls, or responses) 3. Error messages (if any) - The content is formatted based on the event type. - If the content length exceeds `max_content_length`, it is truncated. + The content is formatted as a structured JSON object based on the event type. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ - content, truncated = self._format_content_safely(event.content) - await self._log({ - "event_type": _get_event_type(event), - "agent": event.author, - "session_id": invocation_context.session.id, - "invocation_id": invocation_context.invocation_id, - "user_id": invocation_context.session.user_id, - "content": content, - "error_message": event.error_message, - "timestamp": datetime.fromtimestamp(event.timestamp, timezone.utc), - "is_truncated": truncated, - }) + # We try to extract text, but keep it simple for generic events + text_parts = [] + tool_calls = [] + tool_responses = [] + + if event.content and event.content.parts: + for p in event.content.parts: + if p.text: + text_parts.append(p.text) + if p.function_call: + tool_calls.append(p.function_call.name) + if p.function_response: + tool_responses.append(p.function_response.name) + + payload = { + "text": " ".join(text_parts) if text_parts else None, + "tool_calls": tool_calls if tool_calls else None, + "tool_responses": tool_responses if tool_responses else None, + "raw_role": event.author, + } + + await self._log( + { + "event_type": _get_event_type(event), + "agent": event.author, + "session_id": invocation_context.session.id, + "invocation_id": invocation_context.invocation_id, + "user_id": invocation_context.session.user_id, + "error_message": event.error_message, + }, + content_payload=payload, + ) async def after_run_callback( self, *, invocation_context: InvocationContext @@ -719,14 +688,16 @@ async def before_agent_callback( Content includes: 1. 
Agent Name (from callback context) """ - await self._log({ - "event_type": "AGENT_STARTING", - "agent": agent.name, - "session_id": callback_context.session.id, - "invocation_id": callback_context.invocation_id, - "user_id": callback_context.session.user_id, - "content": f"Agent Name: {callback_context.agent_name}", - }) + await self._log( + { + "event_type": "AGENT_STARTING", + "agent": agent.name, + "session_id": callback_context.session.id, + "invocation_id": callback_context.invocation_id, + "user_id": callback_context.session.user_id, + }, + content_payload={"target_agent": callback_context.agent_name}, + ) async def after_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext @@ -737,14 +708,16 @@ async def after_agent_callback( Content includes: 1. Agent Name (from callback context) """ - await self._log({ - "event_type": "AGENT_COMPLETED", - "agent": agent.name, - "session_id": callback_context.session.id, - "invocation_id": callback_context.invocation_id, - "user_id": callback_context.session.user_id, - "content": f"Agent Name: {callback_context.agent_name}", - }) + await self._log( + { + "event_type": "AGENT_COMPLETED", + "agent": agent.name, + "session_id": callback_context.session.id, + "invocation_id": callback_context.invocation_id, + "user_id": callback_context.session.user_id, + }, + content_payload={"target_agent": callback_context.agent_name}, + ) async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest @@ -758,92 +731,76 @@ async def before_model_callback( 4. Prompt content (user/model messages) 5. System instructions - The content is formatted as a single string with fields separated by ' | '. - If the total length exceeds `max_content_length`, the string is truncated, - prioritizing the metadata (Model, Params, Tools) over the Prompt and System - Prompt. + The content is formatted as a structured JSON object. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ - content_parts = [ - f"Model: {llm_request.model or 'default'}", - ] - is_truncated = False - # 1. Params + # 1. Config Params + params = {} if llm_request.config: - config = llm_request.config - params_to_log = {} - if hasattr(config, "temperature") and config.temperature is not None: - params_to_log["temperature"] = config.temperature - if hasattr(config, "top_p") and config.top_p is not None: - params_to_log["top_p"] = config.top_p - if hasattr(config, "top_k") and config.top_k is not None: - params_to_log["top_k"] = config.top_k - if ( - hasattr(config, "max_output_tokens") - and config.max_output_tokens is not None - ): - params_to_log["max_output_tokens"] = config.max_output_tokens - - if params_to_log: - params_str = ", ".join([f"{k}={v}" for k, v in params_to_log.items()]) - content_parts.append(f"Params: {{{params_str}}}") - - # 2. Tools - if llm_request.tools_dict: - content_parts.append( - f"Available Tools: {list(llm_request.tools_dict.keys())}" - ) - - # 3. Prompt - if contents := getattr(llm_request, "contents", None): - prompt_parts = [] - for c in contents: - c_str, c_trunc = self._format_content_safely(c) - prompt_parts.append(f"{c.role}: {c_str}") - if c_trunc: - is_truncated = True - prompt_str = " | ".join(prompt_parts) - content_parts.append(f"Prompt: {prompt_str}") - - # 4. 
System Prompt - system_instruction_text = "None" + cfg = llm_request.config + if getattr(cfg, "temperature", None) is not None: + params["temperature"] = cfg.temperature + if getattr(cfg, "top_p", None) is not None: + params["top_p"] = cfg.top_p + if getattr(cfg, "top_k", None) is not None: + params["top_k"] = cfg.top_k + if getattr(cfg, "max_output_tokens", None) is not None: + params["max_output_tokens"] = cfg.max_output_tokens + + # 2. System Instruction + system_instr = "None" if llm_request.config and llm_request.config.system_instruction: si = llm_request.config.system_instruction if isinstance(si, str): - system_instruction_text = si + system_instr = si elif isinstance(si, types.Content): - system_instruction_text = "".join(p.text for p in si.parts if p.text) + system_instr = "".join(p.text for p in si.parts if p.text) elif isinstance(si, types.Part): - system_instruction_text = si.text - elif hasattr(si, "__iter__"): - texts = [] - for item in si: - if isinstance(item, str): - texts.append(item) - elif isinstance(item, types.Part) and item.text: - texts.append(item.text) - system_instruction_text = "".join(texts) - else: - system_instruction_text = str(si) - elif llm_request.config and not llm_request.config.system_instruction: - system_instruction_text = "Empty" - - content_parts.append(f"System Prompt: {system_instruction_text}") - - final_content = " | ".join(content_parts) - max_len = self._config.max_content_length - if len(final_content) > max_len: - final_content = final_content[:max_len] + "..." - is_truncated = True - await self._log({ - "event_type": "LLM_REQUEST", - "agent": callback_context.agent_name, - "session_id": callback_context.session.id, - "invocation_id": callback_context.invocation_id, - "user_id": callback_context.session.user_id, - "content": final_content, - "is_truncated": is_truncated, - }) + system_instr = si.text + + # 3. Prompt History (Simplified structure for JSON) + prompt_history = [] + if getattr(llm_request, "contents", None): + for c in llm_request.contents: + role = c.role + parts_list = [] + for p in c.parts: + if p.text: + parts_list.append({"type": "text", "text": p.text}) + elif p.function_call: + parts_list.append( + {"type": "function_call", "name": p.function_call.name} + ) + elif p.function_response: + parts_list.append( + {"type": "function_response", "name": p.function_response.name} + ) + prompt_history.append({"role": role, "parts": parts_list}) + + payload = { + "model": llm_request.model or "default", + "params": params, + "tools_available": ( + list(llm_request.tools_dict.keys()) + if llm_request.tools_dict + else [] + ), + "system_instruction": system_instr, + "prompt": prompt_history, + } + + await self._log( + { + "event_type": "LLM_REQUEST", + "agent": callback_context.agent_name, + "session_id": callback_context.session.id, + "invocation_id": callback_context.invocation_id, + "user_id": callback_context.session.user_id, + }, + content_payload=payload, + ) async def after_model_callback( self, *, callback_context: CallbackContext, llm_response: LlmResponse @@ -855,60 +812,50 @@ async def after_model_callback( 2. Text response (if no tool calls) 3. Token usage statistics (prompt, candidates, total) - The content is formatted as a single string with fields separated by ' | '. - If the content length exceeds `max_content_length`, it is truncated. + The content is formatted as a structured JSON object containing response parts + and usage statistics. 
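+    Function-call parts are captured with their name and arguments.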
+ If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ content_parts = [] - content = llm_response.content - is_tool_call = False - is_truncated = False - if content and content.parts: - is_tool_call = any(part.function_call for part in content.parts) - - if is_tool_call: - fc_names = [] - if content and content.parts: - fc_names = [ - part.function_call.name - for part in content.parts - if part.function_call - ] - content_parts.append(f"Tool Name: {', '.join(fc_names)}") - else: - text_content, truncated = self._format_content_safely( - llm_response.content - ) - content_parts.append(f"Tool Name: text_response, {text_content}") - if truncated: - is_truncated = True - + if llm_response.content and llm_response.content.parts: + for p in llm_response.content.parts: + if p.text: + content_parts.append({"type": "text", "text": p.text}) + if p.function_call: + content_parts.append({ + "type": "function_call", + "name": p.function_call.name, + "args": dict(p.function_call.args), + }) + + usage = {} if llm_response.usage_metadata: - prompt_tokens = getattr( - llm_response.usage_metadata, "prompt_token_count", "N/A" - ) - candidates_tokens = getattr( - llm_response.usage_metadata, "candidates_token_count", "N/A" - ) - total_tokens = getattr( - llm_response.usage_metadata, "total_token_count", "N/A" - ) - token_usage_str = ( - f"Token Usage: {{prompt: {prompt_tokens}, candidates:" - f" {candidates_tokens}, total: {total_tokens}}}" - ) - content_parts.append(token_usage_str) - - final_content = " | ".join(content_parts) - await self._log({ - "event_type": "LLM_RESPONSE", - "agent": callback_context.agent_name, - "session_id": callback_context.session.id, - "invocation_id": callback_context.invocation_id, - "user_id": callback_context.session.user_id, - "content": final_content, - "error_message": llm_response.error_message, - "is_truncated": is_truncated, - }) + usage = { + "prompt_tokens": getattr( + llm_response.usage_metadata, "prompt_token_count", 0 + ), + "candidates_tokens": getattr( + llm_response.usage_metadata, "candidates_token_count", 0 + ), + "total_tokens": getattr( + llm_response.usage_metadata, "total_token_count", 0 + ), + } + + payload = {"response_content": content_parts, "usage": usage} + + await self._log( + { + "event_type": "LLM_RESPONSE", + "agent": callback_context.agent_name, + "session_id": callback_context.session.id, + "invocation_id": callback_context.invocation_id, + "user_id": callback_context.session.user_id, + "error_message": llm_response.error_message, + }, + content_payload=payload, + ) async def before_tool_callback( self, @@ -924,29 +871,26 @@ async def before_tool_callback( 2. Tool description 3. Tool arguments - The content is formatted as 'Tool Name: ..., Description: ..., Arguments: - ...'. - If the content length exceeds `max_content_length`, it is truncated. + The content is formatted as a structured JSON object containing tool name, + description, and arguments. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. 
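+    Arguments are kept as a nested object rather than a flattened string, so
+    they stay queryable with BigQuery's JSON functions.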
""" - args_str, truncated = _format_args( - tool_args, max_len=self._config.max_content_length - ) - content = ( - f"Tool Name: {tool.name}, Description: {tool.description}," - f" Arguments: {args_str}" + payload = { + "tool_name": tool.name, + "description": tool.description, + "arguments": tool_args, + } + await self._log( + { + "event_type": "TOOL_STARTING", + "agent": tool_context.agent_name, + "session_id": tool_context.session.id, + "invocation_id": tool_context.invocation_id, + "user_id": tool_context.session.user_id, + }, + content_payload=payload, ) - if len(content) > self._config.max_content_length: - content = content[: self._config.max_content_length] + "..." - truncated = True - await self._log({ - "event_type": "TOOL_STARTING", - "agent": tool_context.agent_name, - "session_id": tool_context.session.id, - "invocation_id": tool_context.invocation_id, - "user_id": tool_context.session.user_id, - "content": content, - "is_truncated": truncated, - }) async def after_tool_callback( self, @@ -962,25 +906,21 @@ async def after_tool_callback( 1. Tool name 2. Tool result - The content is formatted as 'Tool Name: ..., Result: ...'. - If the content length exceeds `max_content_length`, it is truncated. + The content is formatted as a structured JSON object containing tool name and result. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ - result_str, truncated = _format_args( - result, max_len=self._config.max_content_length + payload = {"tool_name": tool.name, "result": result} + await self._log( + { + "event_type": "TOOL_COMPLETED", + "agent": tool_context.agent_name, + "session_id": tool_context.session.id, + "invocation_id": tool_context.invocation_id, + "user_id": tool_context.session.user_id, + }, + content_payload=payload, ) - content = f"Tool Name: {tool.name}, Result: {result_str}" - if len(content) > self._config.max_content_length: - content = content[: self._config.max_content_length] + "..." - truncated = True - await self._log({ - "event_type": "TOOL_COMPLETED", - "agent": tool_context.agent_name, - "session_id": tool_context.session.id, - "invocation_id": tool_context.invocation_id, - "user_id": tool_context.session.user_id, - "content": content, - "is_truncated": truncated, - }) async def on_model_error_callback( self, @@ -1019,23 +959,20 @@ async def on_tool_error_callback( 1. Tool name 2. Tool arguments + The content is formatted as a structured JSON object containing tool name and arguments. The error message is captured in the `error_message` field. - If the content length exceeds `max_content_length`, it is truncated. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ - args_str, truncated = _format_args( - tool_args, max_len=self._config.max_content_length + payload = {"tool_name": tool.name, "arguments": tool_args} + await self._log( + { + "event_type": "TOOL_ERROR", + "agent": tool_context.agent_name, + "session_id": tool_context.session.id, + "invocation_id": tool_context.invocation_id, + "user_id": tool_context.session.user_id, + "error_message": str(error), + }, + content_payload=payload, ) - content = f"Tool Name: {tool.name}, Arguments: {args_str}" - if len(content) > self._config.max_content_length: - content = content[: self._config.max_content_length] + "..." 
- truncated = True - await self._log({ - "event_type": "TOOL_ERROR", - "agent": tool_context.agent_name, - "session_id": tool_context.session.id, - "invocation_id": tool_context.invocation_id, - "user_id": tool_context.session.user_id, - "content": content, - "error_message": str(error), - "is_truncated": truncated, - }) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 6f0412dbbd..0dd3e1617c 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -34,7 +34,6 @@ from google.auth import exceptions as auth_exceptions import google.auth.credentials from google.cloud import bigquery -from google.cloud.bigquery_storage_v1 import types as bq_storage_types from google.genai import types import pyarrow as pa import pytest @@ -136,7 +135,6 @@ async def fake_append_rows(requests, **kwargs): mock_append_rows_response.row_errors = [] mock_append_rows_response.error = mock.MagicMock() mock_append_rows_response.error.code = 0 # OK status - # This a gen is what's returned *after* the await. return _async_gen(mock_append_rows_response) mock_client.append_rows.side_effect = fake_append_rows @@ -145,6 +143,7 @@ async def fake_append_rows(requests, **kwargs): @pytest.fixture def dummy_arrow_schema(): + # UPDATED: content is pa.string() because JSON is serialized to string before Arrow return pa.schema([ pa.field("timestamp", pa.timestamp("us", tz="UTC"), nullable=False), pa.field("event_type", pa.string(), nullable=True), @@ -154,7 +153,6 @@ def dummy_arrow_schema(): pa.field("user_id", pa.string(), nullable=True), pa.field("content", pa.string(), nullable=True), pa.field("error_message", pa.string(), nullable=True), - pa.field("is_truncated", pa.bool_(), nullable=True), ]) @@ -234,7 +232,6 @@ def _assert_common_fields(log_entry, event_type, agent="MyTestAgent"): assert log_entry["user_id"] == "user-456" assert "timestamp" in log_entry assert isinstance(log_entry["timestamp"], datetime.datetime) - assert "is_truncated" in log_entry # --- Test Class --- @@ -257,7 +254,6 @@ async def test_plugin_disabled( table_id=TABLE_ID, config=config, ) - # user_message = types.Content(parts=[types.Part(text="Test")]) await plugin.on_user_message_callback( invocation_context=invocation_context, @@ -334,67 +330,59 @@ async def test_event_denylist( mock_write_client.append_rows.assert_called_once() @pytest.mark.asyncio - async def test_content_formatter( + async def test_content_formatter_payload_mutation( self, mock_write_client, - invocation_context, + callback_context, mock_auth_default, mock_bq_client, mock_to_arrow_schema, dummy_arrow_schema, mock_asyncio_to_thread, ): - def redact_content(content): - return "[REDACTED]" - - config = BigQueryLoggerConfig(content_formatter=redact_content) + """Tests a formatter that modifies the JSON structure (Pruning & Normalization).""" + + def mutate_payload(data): + if isinstance(data, dict): + # 1. Pruning: Remove system_instruction + if "system_instruction" in data: + del data["system_instruction"] + # 2. 
Normalization: Uppercase model name + if "model" in data and isinstance(data["model"], str): + data["model"] = data["model"].upper() + return data + + config = BigQueryLoggerConfig(content_formatter=mutate_payload) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) await plugin._ensure_init() mock_write_client.append_rows.reset_mock() - user_message = types.Content(parts=[types.Part(text="Secret message")]) - await plugin.on_user_message_callback( - invocation_context=invocation_context, user_message=user_message - ) - await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - assert log_entry["content"] == "User Content: [REDACTED]" - - @pytest.mark.asyncio - async def test_content_formatter_error( - self, - mock_write_client, - invocation_context, - mock_auth_default, - mock_bq_client, - mock_to_arrow_schema, - dummy_arrow_schema, - mock_asyncio_to_thread, - ): - def error_formatter(content): - raise ValueError("Formatter failed") - - config = BigQueryLoggerConfig(content_formatter=error_formatter) - plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( - PROJECT_ID, DATASET_ID, TABLE_ID, config + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + config=types.GenerateContentConfig( + system_instruction=types.Content(parts=[types.Part(text="Sys")]) + ), + contents=[types.Content(role="user", parts=[types.Part(text="User")])], ) - await plugin._ensure_init() - mock_write_client.append_rows.reset_mock() - user_message = types.Content(parts=[types.Part(text="Secret message")]) - await plugin.on_user_message_callback( - invocation_context=invocation_context, user_message=user_message + await plugin.before_model_callback( + callback_context=callback_context, llm_request=llm_request ) await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - assert log_entry["content"] == "User Content: [FORMATTING FAILED]" + + # Parse JSON + content = json.loads(log_entry["content"]) + + # Verify mutation + assert "system_instruction" not in content + assert content["model"] == "GEMINI-PRO" + assert content["prompt"][0]["role"] == "user" @pytest.mark.asyncio - async def test_max_content_length( + async def test_max_content_length_smart_truncation( self, mock_write_client, invocation_context, @@ -405,7 +393,8 @@ async def test_max_content_length( dummy_arrow_schema, mock_asyncio_to_thread, ): - config = BigQueryLoggerConfig(max_content_length=40) + # Config limit to 10 chars + config = BigQueryLoggerConfig(max_content_length=10) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) @@ -413,45 +402,21 @@ async def test_max_content_length( mock_write_client.append_rows.reset_mock() # Test User Message Truncation - user_message = types.Content( - parts=[types.Part(text="12345678901234567890123456789012345678901")] - ) # 41 chars + long_text = "123456789012345" # 15 chars + user_message = types.Content(parts=[types.Part(text=long_text)]) + await plugin.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - assert ( - log_entry["content"] - == "User Content: text: 
'1234567890123456789012345678901234567890...' " - ) - assert log_entry["is_truncated"] - mock_write_client.append_rows.reset_mock() - # Test before_model_callback full content truncation - llm_request = llm_request_lib.LlmRequest( - model="gemini-pro", - config=types.GenerateContentConfig( - system_instruction=types.Content( - parts=[types.Part(text="System Instruction")] - ) - ), - contents=[ - types.Content(role="user", parts=[types.Part(text="Prompt")]) - ], - ) - await plugin.before_model_callback( - callback_context=callback_context, llm_request=llm_request - ) - await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - # Full content: "Model: gemini-pro | Prompt: user: text: 'Prompt' | System Prompt: System Instruction" - # Truncated to 40 chars + ...: - expected_content = "Model: gemini-pro | Prompt: user: text: ..." - assert log_entry["content"] == expected_content - assert log_entry["is_truncated"] + content = json.loads(log_entry["content"]) + + # Verify "1234567890...[TRUNCATED]" + assert content["text"] == "1234567890...[TRUNCATED]" + # Verify it is still valid JSON + assert isinstance(content, dict) @pytest.mark.asyncio async def test_max_content_length_tool_args( @@ -464,7 +429,8 @@ async def test_max_content_length_tool_args( dummy_arrow_schema, mock_asyncio_to_thread, ): - config = BigQueryLoggerConfig(max_content_length=80) + # Limit 10 chars + config = BigQueryLoggerConfig(max_content_length=10) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) @@ -475,24 +441,21 @@ async def test_max_content_length_tool_args( base_tool_lib.BaseTool, instance=True, spec_set=True ) type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - type(mock_tool).description = mock.PropertyMock(return_value="Description") + type(mock_tool).description = mock.PropertyMock(return_value="Desc") - # Args length > 80 - # {"param": "A" * 50} is ~60 chars. - # Prefix is ~57 chars. Total ~117 chars. + # Args contain a long string + long_val = "A" * 20 await plugin.before_tool_callback( tool=mock_tool, - tool_args={"param": "A" * 50}, + tool_args={"param": long_val}, tool_context=tool_context, ) await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + content = json.loads(log_entry["content"]) - assert 'Arguments: {"param": "AAAAA' in log_entry["content"] - assert log_entry["content"].endswith("...") - assert len(log_entry["content"]) == 83 # 80 + 3 dots - assert log_entry["is_truncated"] + # Verify truncation happened inside the JSON structure + assert content["arguments"]["param"] == "AAAAAAAAAA...[TRUNCATED]" @pytest.mark.asyncio async def test_max_content_length_tool_result( @@ -505,7 +468,7 @@ async def test_max_content_length_tool_result( dummy_arrow_schema, mock_asyncio_to_thread, ): - config = BigQueryLoggerConfig(max_content_length=80) + config = BigQueryLoggerConfig(max_content_length=10) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) @@ -517,23 +480,18 @@ async def test_max_content_length_tool_result( ) type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - # Result length > 80 - # {"res": "A" * 60} is ~70 chars. - # Prefix is ~27 chars. Total ~97 chars. 
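+    # The limit applies to each string value inside the result dict, not to
+    # the serialized row, so the stored content remains parseable JSON.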
+ long_res = "A" * 20 await plugin.after_tool_callback( tool=mock_tool, tool_args={}, tool_context=tool_context, - result={"res": "A" * 60}, + result={"res": long_res}, ) await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + content = json.loads(log_entry["content"]) - assert 'Result: {"res": "AAAAA' in log_entry["content"] - assert log_entry["content"].endswith("...") - assert len(log_entry["content"]) == 83 # 80 + 3 dots - assert log_entry["is_truncated"] + assert content["result"]["res"] == "AAAAAAAAAA...[TRUNCATED]" @pytest.mark.asyncio async def test_max_content_length_tool_error( @@ -546,7 +504,7 @@ async def test_max_content_length_tool_error( dummy_arrow_schema, mock_asyncio_to_thread, ): - config = BigQueryLoggerConfig(max_content_length=80) + config = BigQueryLoggerConfig(max_content_length=10) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) @@ -558,23 +516,18 @@ async def test_max_content_length_tool_error( ) type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - # Args length > 80 - # {"arg": "A" * 60} is ~70 chars. - # Prefix is ~28 chars. Total ~98 chars. + long_arg = "A" * 20 await plugin.on_tool_error_callback( tool=mock_tool, - tool_args={"arg": "A" * 60}, + tool_args={"arg": long_arg}, tool_context=tool_context, error=ValueError("Oops"), ) await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + content = json.loads(log_entry["content"]) - assert 'Arguments: {"arg": "AAAAA' in log_entry["content"] - assert log_entry["content"].endswith("...") - assert len(log_entry["content"]) == 83 # 80 + 3 dots - assert log_entry["is_truncated"] + assert content["arguments"]["arg"] == "AAAAAAAAAA...[TRUNCATED]" @pytest.mark.asyncio async def test_on_user_message_callback_logs_correctly( @@ -591,8 +544,9 @@ async def test_on_user_message_callback_logs_correctly( await asyncio.sleep(0.01) log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "USER_MESSAGE_RECEIVED") - assert log_entry["content"] == "User Content: text: 'What is up?'" - assert not log_entry["is_truncated"] + + content = json.loads(log_entry["content"]) + assert content["text"] == "What is up?" 
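+  def test_recursive_smart_truncate_nested(self):
+    # Minimal sketch of the module-level _recursive_smart_truncate helper
+    # introduced above (assuming its private name and (obj, max_len)
+    # signature are kept): strings are cut per field and tagged with the
+    # "...[TRUNCATED]" marker, while dict/list shapes are preserved.
+    data = {"outer": {"long": "B" * 20, "items": ["C" * 20, "short"]}}
+    out = bigquery_agent_analytics_plugin._recursive_smart_truncate(data, 10)
+    assert out["outer"]["long"] == "BBBBBBBBBB...[TRUNCATED]"
+    assert out["outer"]["items"][0] == "CCCCCCCCCC...[TRUNCATED]"
+    assert out["outer"]["items"][1] == "short"
+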
@pytest.mark.asyncio async def test_on_event_callback_tool_call( @@ -616,10 +570,11 @@ async def test_on_event_callback_tool_call( await asyncio.sleep(0.01) log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "TOOL_CALL", agent="MyTestAgent") - assert "call: get_weather" in log_entry["content"] - assert log_entry["timestamp"] == datetime.datetime( - 2025, 10, 22, 10, 0, 0, tzinfo=datetime.timezone.utc - ) + + # Verify Generic Event JSON structure + content = json.loads(log_entry["content"]) + assert content["raw_role"] == "MyTestAgent" + assert content["tool_calls"] == ["get_weather"] @pytest.mark.asyncio async def test_on_event_callback_model_response( @@ -642,10 +597,194 @@ async def test_on_event_callback_model_response( await asyncio.sleep(0.01) log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "MODEL_RESPONSE", agent="MyTestAgent") - assert "text: 'Hello there!'" in log_entry["content"] - assert log_entry["timestamp"] == datetime.datetime( - 2025, 10, 22, 11, 0, 0, tzinfo=datetime.timezone.utc + + content = json.loads(log_entry["content"]) + assert content["text"] == "Hello there!" + + @pytest.mark.asyncio + async def test_before_model_callback_logs_structure( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + """Covers combined logic of params and tools in one structured test.""" + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + config=types.GenerateContentConfig( + temperature=0.5, + top_p=0.9, + system_instruction=types.Content(parts=[types.Part(text="Sys")]), + ), + contents=[types.Content(role="user", parts=[types.Part(text="User")])], + ) + # Manually set tools_dict + llm_request.tools_dict = {"tool1": "func1"} + + await bq_plugin_inst.before_model_callback( + callback_context=callback_context, llm_request=llm_request ) + await asyncio.sleep(0.01) + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "LLM_REQUEST") + + # Verify structured JSON + content = json.loads(log_entry["content"]) + assert content["model"] == "gemini-pro" + assert content["params"]["temperature"] == 0.5 + assert content["params"]["top_p"] == 0.9 + assert "tool1" in content["tools_available"] + assert content["system_instruction"] == "Sys" + assert content["prompt"][0]["role"] == "user" + assert content["prompt"][0]["parts"][0]["type"] == "text" + assert content["prompt"][0]["parts"][0]["text"] == "User" + + @pytest.mark.asyncio + async def test_after_model_callback_text_response( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + llm_response = llm_response_lib.LlmResponse( + content=types.Content(parts=[types.Part(text="Model response")]), + usage_metadata=types.UsageMetadata( + prompt_token_count=10, total_token_count=15 + ), + ) + await bq_plugin_inst.after_model_callback( + callback_context=callback_context, llm_response=llm_response + ) + await asyncio.sleep(0.01) + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "LLM_RESPONSE") + + content = json.loads(log_entry["content"]) + assert content["response_content"][0]["type"] == "text" + assert content["response_content"][0]["text"] == "Model response" + assert content["usage"]["prompt_tokens"] == 10 + assert content["usage"]["total_tokens"] == 15 + + @pytest.mark.asyncio + async def 
test_after_model_callback_tool_call( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + tool_fc = types.FunctionCall(name="get_weather", args={"location": "Paris"}) + llm_response = llm_response_lib.LlmResponse( + content=types.Content(parts=[types.Part(function_call=tool_fc)]), + usage_metadata=types.UsageMetadata( + prompt_token_count=10, total_token_count=15 + ), + ) + await bq_plugin_inst.after_model_callback( + callback_context=callback_context, llm_response=llm_response + ) + await asyncio.sleep(0.01) + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "LLM_RESPONSE") + + content = json.loads(log_entry["content"]) + # Verify Tool Call structure + assert content["response_content"][0]["type"] == "function_call" + assert content["response_content"][0]["name"] == "get_weather" + assert content["response_content"][0]["args"]["location"] == "Paris" + + @pytest.mark.asyncio + async def test_before_tool_callback_logs_correctly( + self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema + ): + mock_tool = mock.create_autospec( + base_tool_lib.BaseTool, instance=True, spec_set=True + ) + type(mock_tool).name = mock.PropertyMock(return_value="MyTool") + type(mock_tool).description = mock.PropertyMock(return_value="Description") + await bq_plugin_inst.before_tool_callback( + tool=mock_tool, tool_args={"param": "value"}, tool_context=tool_context + ) + await asyncio.sleep(0.01) + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "TOOL_STARTING") + + content = json.loads(log_entry["content"]) + assert content["tool_name"] == "MyTool" + assert content["description"] == "Description" + assert content["arguments"]["param"] == "value" + + @pytest.mark.asyncio + async def test_after_tool_callback_logs_correctly( + self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema + ): + mock_tool = mock.create_autospec( + base_tool_lib.BaseTool, instance=True, spec_set=True + ) + type(mock_tool).name = mock.PropertyMock(return_value="MyTool") + type(mock_tool).description = mock.PropertyMock(return_value="Description") + await bq_plugin_inst.after_tool_callback( + tool=mock_tool, + tool_args={}, + tool_context=tool_context, + result={"status": "success"}, + ) + await asyncio.sleep(0.01) + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "TOOL_COMPLETED") + + content = json.loads(log_entry["content"]) + assert content["tool_name"] == "MyTool" + assert content["result"]["status"] == "success" + + @pytest.mark.asyncio + async def test_on_model_error_callback_logs_correctly( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + contents=[types.Content(parts=[types.Part(text="Prompt")])], + ) + error = ValueError("LLM failed") + await bq_plugin_inst.on_model_error_callback( + callback_context=callback_context, llm_request=llm_request, error=error + ) + await asyncio.sleep(0.01) + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "LLM_ERROR") + assert log_entry["content"] is None + assert log_entry["error_message"] == "LLM failed" + + @pytest.mark.asyncio + async def test_on_tool_error_callback_logs_correctly( + self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema + ): + 
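+    # Tool errors keep the structured payload in `content` while the
+    # exception text goes to the separate `error_message` column.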
mock_tool = mock.create_autospec( + base_tool_lib.BaseTool, instance=True, spec_set=True + ) + type(mock_tool).name = mock.PropertyMock(return_value="MyTool") + type(mock_tool).description = mock.PropertyMock(return_value="Description") + error = TimeoutError("Tool timed out") + await bq_plugin_inst.on_tool_error_callback( + tool=mock_tool, + tool_args={"param": "value"}, + tool_context=tool_context, + error=error, + ) + await asyncio.sleep(0.01) + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "TOOL_ERROR") + + content = json.loads(log_entry["content"]) + assert content["tool_name"] == "MyTool" + assert content["arguments"]["param"] == "value" + assert log_entry["error_message"] == "Tool timed out" @pytest.mark.asyncio async def test_bigquery_client_initialization_failure( @@ -678,7 +817,6 @@ async def test_bigquery_client_initialization_failure( async def test_bigquery_insert_error_does_not_raise( self, bq_plugin_inst, mock_write_client, invocation_context ): - async def fake_append_rows_with_error(requests, **kwargs): mock_append_rows_response = mock.MagicMock() mock_append_rows_response.row_errors = [] # No row errors @@ -725,9 +863,9 @@ async def fake_append_rows_with_schema_error(requests, **kwargs): ) await asyncio.sleep(0.01) mock_log_error.assert_called_with( - "BQ Plugin: Schema Mismatch Error. The BigQuery table schema may be" - " incorrect or out of sync with the plugin. Please verify the table" - " definition. Details: %s", + "BQ Plugin: Schema Mismatch. You may need to delete the existing" + " table if you migrated from STRING content to JSON content." + " Details: %s", "Schema mismatch: Field 'new_field' not found in table.", ) @@ -735,8 +873,6 @@ async def fake_append_rows_with_schema_error(requests, **kwargs): async def test_close(self, bq_plugin_inst, mock_bq_client, mock_write_client): await bq_plugin_inst.close() mock_write_client.transport.close.assert_called_once() - # bq_client might not be closed if it wasn't created or if close() failed, - # but here it should be. 
# in the new implementation we verify attributes are reset assert bq_plugin_inst._write_client is None assert bq_plugin_inst._bq_client is None @@ -789,7 +925,9 @@ async def test_before_agent_callback_logs_correctly( await asyncio.sleep(0.01) log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "AGENT_STARTING") - assert log_entry["content"] == "Agent Name: MyTestAgent" + + content = json.loads(log_entry["content"]) + assert content["target_agent"] == "MyTestAgent" @pytest.mark.asyncio async def test_after_agent_callback_logs_correctly( @@ -806,216 +944,9 @@ async def test_after_agent_callback_logs_correctly( await asyncio.sleep(0.01) log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "AGENT_COMPLETED") - assert log_entry["content"] == "Agent Name: MyTestAgent" - - @pytest.mark.asyncio - async def test_before_model_callback_logs_correctly( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - llm_request = llm_request_lib.LlmRequest( - model="gemini-pro", - contents=[ - types.Content(role="user", parts=[types.Part(text="Prompt")]) - ], - ) - await bq_plugin_inst.before_model_callback( - callback_context=callback_context, llm_request=llm_request - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_REQUEST") - assert ( - log_entry["content"] - == "Model: gemini-pro | Prompt: user: text: 'Prompt' | System Prompt:" - " Empty" - ) - - @pytest.mark.asyncio - async def test_before_model_callback_with_params_and_tools( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - llm_request = llm_request_lib.LlmRequest( - model="gemini-pro", - config=types.GenerateContentConfig( - temperature=0.5, - top_p=0.9, - system_instruction=types.Content(parts=[types.Part(text="Sys")]), - ), - contents=[types.Content(role="user", parts=[types.Part(text="User")])], - ) - # Manually set tools_dict as it is excluded from init - llm_request.tools_dict = {"tool1": "func1", "tool2": "func2"} - - await bq_plugin_inst.before_model_callback( - callback_context=callback_context, llm_request=llm_request - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_REQUEST") - # Order: Model | Params | Tools | Prompt | System Prompt - # Note: Params order depends on dict iteration but here we construct it deterministically in code? - # The code does: params_to_log["temperature"] = ... then "top_p" = ... - # So order should be temperature, top_p. 
- assert "Model: gemini-pro" in log_entry["content"] - assert "Params: {temperature=0.5, top_p=0.9}" in log_entry["content"] - assert "Available Tools: ['tool1', 'tool2']" in log_entry["content"] - assert "Prompt: user: text: 'User'" in log_entry["content"] - assert "System Prompt: Sys" in log_entry["content"] - - @pytest.mark.asyncio - async def test_after_model_callback_text_response( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - llm_response = llm_response_lib.LlmResponse( - content=types.Content(parts=[types.Part(text="Model response")]), - usage_metadata=types.UsageMetadata( - prompt_token_count=10, total_token_count=15 - ), - ) - await bq_plugin_inst.after_model_callback( - callback_context=callback_context, llm_response=llm_response - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_RESPONSE") - assert ( - "Tool Name: text_response, text: 'Model response'" - in log_entry["content"] - ) - assert "Token Usage:" in log_entry["content"] - assert "prompt: 10" in log_entry["content"] - assert "total: 15" in log_entry["content"] - assert log_entry["error_message"] is None - - @pytest.mark.asyncio - async def test_after_model_callback_tool_call( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - tool_fc = types.FunctionCall(name="get_weather", args={"location": "Paris"}) - llm_response = llm_response_lib.LlmResponse( - content=types.Content(parts=[types.Part(function_call=tool_fc)]), - usage_metadata=types.UsageMetadata( - prompt_token_count=10, total_token_count=15 - ), - ) - await bq_plugin_inst.after_model_callback( - callback_context=callback_context, llm_response=llm_response - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_RESPONSE") - assert "Tool Name: get_weather" in log_entry["content"] - assert "Token Usage:" in log_entry["content"] - assert "prompt: 10" in log_entry["content"] - assert "total: 15" in log_entry["content"] - assert log_entry["error_message"] is None - @pytest.mark.asyncio - async def test_before_tool_callback_logs_correctly( - self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema - ): - mock_tool = mock.create_autospec( - base_tool_lib.BaseTool, instance=True, spec_set=True - ) - type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - type(mock_tool).description = mock.PropertyMock(return_value="Description") - await bq_plugin_inst.before_tool_callback( - tool=mock_tool, tool_args={"param": "value"}, tool_context=tool_context - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "TOOL_STARTING") - assert ( - log_entry["content"] - == 'Tool Name: MyTool, Description: Description, Arguments: {"param":' - ' "value"}' - ) - - @pytest.mark.asyncio - async def test_after_tool_callback_logs_correctly( - self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema - ): - mock_tool = mock.create_autospec( - base_tool_lib.BaseTool, instance=True, spec_set=True - ) - type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - type(mock_tool).description = mock.PropertyMock(return_value="Description") - await bq_plugin_inst.after_tool_callback( - tool=mock_tool, - tool_args={}, - tool_context=tool_context, - result={"status": "success"}, - ) - await 
asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "TOOL_COMPLETED") - assert ( - log_entry["content"] - == 'Tool Name: MyTool, Result: {"status": "success"}' - ) - - @pytest.mark.asyncio - async def test_on_model_error_callback_logs_correctly( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - llm_request = llm_request_lib.LlmRequest( - model="gemini-pro", - contents=[types.Content(parts=[types.Part(text="Prompt")])], - ) - error = ValueError("LLM failed") - await bq_plugin_inst.on_model_error_callback( - callback_context=callback_context, llm_request=llm_request, error=error - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_ERROR") - assert log_entry["content"] is None - assert log_entry["error_message"] == "LLM failed" - - @pytest.mark.asyncio - async def test_on_tool_error_callback_logs_correctly( - self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema - ): - mock_tool = mock.create_autospec( - base_tool_lib.BaseTool, instance=True, spec_set=True - ) - type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - type(mock_tool).description = mock.PropertyMock(return_value="Description") - error = TimeoutError("Tool timed out") - await bq_plugin_inst.on_tool_error_callback( - tool=mock_tool, - tool_args={"param": "value"}, - tool_context=tool_context, - error=error, - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "TOOL_ERROR") - assert ( - log_entry["content"] - == 'Tool Name: MyTool, Arguments: {"param": "value"}' - ) - assert log_entry["error_message"] == "Tool timed out" + content = json.loads(log_entry["content"]) + assert content["target_agent"] == "MyTestAgent" @pytest.mark.asyncio async def test_table_creation_options( @@ -1039,9 +970,7 @@ async def test_table_creation_options( assert table_arg.time_partitioning.type_ == "DAY" assert table_arg.time_partitioning.field == "timestamp" assert table_arg.clustering_fields == ["event_type", "agent", "user_id"] - # Verify schema descriptions are present (spot check) - timestamp_field = next(f for f in table_arg.schema if f.name == "timestamp") - assert ( - timestamp_field.description - == "The UTC time at which the event was logged." - ) + + # Verify schema type for content is JSON + content_field = next(f for f in table_arg.schema if f.name == "content") + assert content_field.field_type == "JSON" From a3ace990a696707d5d3f97a9c8f9f72ec41e02e9 Mon Sep 17 00:00:00 2001 From: Afonso Menegola Date: Wed, 26 Nov 2025 17:43:45 -0300 Subject: [PATCH 02/10] fix: Ensure close() resets clients to None and use it in tests --- .../bigquery_agent_analytics_plugin.py | 37 ++++++--- .../test_bigquery_agent_analytics_plugin.py | 76 ++++++++++++------- 2 files changed, 74 insertions(+), 39 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index fc5d010281..a83a461d0d 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -34,7 +34,6 @@ from google.genai import types import pyarrow as pa -from .. 
import version from ..agents.base_agent import BaseAgent from ..agents.callback_context import CallbackContext from ..events.event import Event @@ -382,7 +381,7 @@ async def _ensure_init(self): scopes=["https://www.googleapis.com/auth/cloud-platform"], ) client_info = gapic_client_info.ClientInfo( - user_agent=f"google-adk-bq-logger/{version.__version__}" + user_agent="google-adk-bq-logger" ) self._bq_client = bigquery.Client( project=self._project_id, credentials=creds, client_info=client_info @@ -585,7 +584,7 @@ async def on_user_message_callback( if user_message and user_message.parts: text_content = " ".join([p.text for p in user_message.parts if p.text]) - payload = {"text": text_content} + payload = {"text": text_content if text_content else None} await self._log( { @@ -647,7 +646,7 @@ async def on_event_callback( "text": " ".join(text_parts) if text_parts else None, "tool_calls": tool_calls if tool_calls else None, "tool_responses": tool_responses if tool_responses else None, - "raw_role": event.author, + "raw_role": event.author if event.author else None, } await self._log( @@ -750,8 +749,8 @@ async def before_model_callback( params["max_output_tokens"] = cfg.max_output_tokens # 2. System Instruction - system_instr = "None" - if llm_request.config and llm_request.config.system_instruction: + system_instr = None + if llm_request.config and llm_request.config.system_instruction is not None: si = llm_request.config.system_instruction if isinstance(si, str): system_instr = si @@ -759,6 +758,16 @@ async def before_model_callback( system_instr = "".join(p.text for p in si.parts if p.text) elif isinstance(si, types.Part): system_instr = si.text + elif hasattr(si, "__iter__"): + texts = [] + for item in si: + if isinstance(item, str): + texts.append(item) + elif isinstance(item, types.Part) and item.text: + texts.append(item.text) + system_instr = "".join(texts) + else: + system_instr = str(si) # 3. Prompt History (Simplified structure for JSON) prompt_history = [] @@ -843,7 +852,10 @@ async def after_model_callback( ), } - payload = {"response_content": content_parts, "usage": usage} + payload = { + "response_content": content_parts if content_parts else None, + "usage": usage if usage else None, + } await self._log( { @@ -876,10 +888,11 @@ async def before_tool_callback( If individual string fields exceed `max_content_length`, they are truncated to preserve the valid JSON structure. """ + payload = { - "tool_name": tool.name, - "description": tool.description, - "arguments": tool_args, + "tool_name": tool.name if tool.name else None, + "description": tool.description if tool.description else None, + "arguments": tool_args if tool_args else None, } await self._log( { @@ -910,7 +923,7 @@ async def after_tool_callback( If individual string fields exceed `max_content_length`, they are truncated to preserve the valid JSON structure. """ - payload = {"tool_name": tool.name, "result": result} + payload = {"tool_name": tool.name if tool.name else None, "result": result if result else None} await self._log( { "event_type": "TOOL_COMPLETED", @@ -964,7 +977,7 @@ async def on_tool_error_callback( If individual string fields exceed `max_content_length`, they are truncated to preserve the valid JSON structure. 
""" - payload = {"tool_name": tool.name, "arguments": tool_args} + payload = {"tool_name": tool.name if tool.name else None, "arguments": tool_args if tool_args else None} await self._log( { "event_type": "TOOL_ERROR", diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 0dd3e1617c..3301c1d833 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -143,7 +143,7 @@ async def fake_append_rows(requests, **kwargs): @pytest.fixture def dummy_arrow_schema(): - # UPDATED: content is pa.string() because JSON is serialized to string before Arrow + # content is pa.string() because JSON is serialized to string before Arrow return pa.schema([ pa.field("timestamp", pa.timestamp("us", tz="UTC"), nullable=False), pa.field("event_type", pa.string(), nullable=True), @@ -259,6 +259,9 @@ async def test_plugin_disabled( invocation_context=invocation_context, user_message=types.Content(parts=[types.Part(text="Test")]), ) + # Wait for background tasks + await plugin.close() + mock_auth_default.assert_not_called() mock_bq_client.assert_not_called() mock_write_client.append_rows.assert_not_called() @@ -289,15 +292,25 @@ async def test_event_allowlist( await plugin.before_model_callback( callback_context=callback_context, llm_request=llm_request ) - await asyncio.sleep(0.01) # Allow background task to run + await plugin.close() # Wait for write mock_write_client.append_rows.assert_called_once() mock_write_client.append_rows.reset_mock() + # Re-init plugin logic since close() shuts it down, but for this test we want to test denial + # However, close() cleans up clients. We should probably create a new plugin or just check that the task was not created. + # But on_user_message_callback will try to log. + # To keep it simple, let's just use a fresh plugin for the second part or assume close() resets state enough to re-run _ensure_init if needed, + # but _ensure_init is called inside _perform_write. + # Actually, close() sets _is_shutting_down to True, so further logs are ignored. + # So we need a new plugin instance or reset _is_shutting_down. + plugin._is_shutting_down = False + user_message = types.Content(parts=[types.Part(text="What is up?")]) await plugin.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) - await asyncio.sleep(0.01) # Allow background task to run + # Since it's denied, no task is created. close() would wait if there was one. 
+ await plugin.close() mock_write_client.append_rows.assert_not_called() @pytest.mark.asyncio @@ -322,11 +335,14 @@ async def test_event_denylist( await plugin.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) - await asyncio.sleep(0.01) + await plugin.close() mock_write_client.append_rows.assert_not_called() + # Reset for next call + plugin._is_shutting_down = False + await plugin.before_run_callback(invocation_context=invocation_context) - await asyncio.sleep(0.01) + await plugin.close() mock_write_client.append_rows.assert_called_once() @pytest.mark.asyncio @@ -370,7 +386,7 @@ def mutate_payload(data): await plugin.before_model_callback( callback_context=callback_context, llm_request=llm_request ) - await asyncio.sleep(0.01) + await plugin.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) # Parse JSON @@ -408,7 +424,7 @@ async def test_max_content_length_smart_truncation( await plugin.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) - await asyncio.sleep(0.01) + await plugin.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) content = json.loads(log_entry["content"]) @@ -450,7 +466,7 @@ async def test_max_content_length_tool_args( tool_args={"param": long_val}, tool_context=tool_context, ) - await asyncio.sleep(0.01) + await plugin.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) content = json.loads(log_entry["content"]) @@ -487,7 +503,7 @@ async def test_max_content_length_tool_result( tool_context=tool_context, result={"res": long_res}, ) - await asyncio.sleep(0.01) + await plugin.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) content = json.loads(log_entry["content"]) @@ -523,7 +539,7 @@ async def test_max_content_length_tool_error( tool_context=tool_context, error=ValueError("Oops"), ) - await asyncio.sleep(0.01) + await plugin.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) content = json.loads(log_entry["content"]) @@ -541,10 +557,11 @@ async def test_on_user_message_callback_logs_correctly( await bq_plugin_inst.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "USER_MESSAGE_RECEIVED") + # UPDATED ASSERTION: Check JSON structure content = json.loads(log_entry["content"]) assert content["text"] == "What is up?" 
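
For readers tracking these test updates: the suite now parses the JSON `content` column instead of substring-matching formatted strings. A minimal standalone sketch of that assertion pattern, using an illustrative row rather than a fixture from this suite:

    import json

    # A captured row shaped like the ones the mocked write client records.
    log_entry = {
        "event_type": "TOOL_STARTING",
        "content": json.dumps(
            {"tool_name": "MyTool", "arguments": {"param": "value"}}
        ),
    }

    # The JSON column arrives as a string; parse it, then assert on fields.
    content = json.loads(log_entry["content"])
    assert content["tool_name"] == "MyTool"
    assert content["arguments"]["param"] == "value"
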
@@ -567,7 +584,7 @@ async def test_on_event_callback_tool_call( await bq_plugin_inst.on_event_callback( invocation_context=invocation_context, event=event ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "TOOL_CALL", agent="MyTestAgent") @@ -594,7 +611,7 @@ async def test_on_event_callback_model_response( await bq_plugin_inst.on_event_callback( invocation_context=invocation_context, event=event ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "MODEL_RESPONSE", agent="MyTestAgent") @@ -625,7 +642,7 @@ async def test_before_model_callback_logs_structure( await bq_plugin_inst.before_model_callback( callback_context=callback_context, llm_request=llm_request ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "LLM_REQUEST") @@ -657,10 +674,11 @@ async def test_after_model_callback_text_response( await bq_plugin_inst.after_model_callback( callback_context=callback_context, llm_response=llm_response ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "LLM_RESPONSE") + # UPDATED ASSERTION: Check structured JSON content = json.loads(log_entry["content"]) assert content["response_content"][0]["type"] == "text" assert content["response_content"][0]["text"] == "Model response" @@ -685,7 +703,7 @@ async def test_after_model_callback_tool_call( await bq_plugin_inst.after_model_callback( callback_context=callback_context, llm_response=llm_response ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "LLM_RESPONSE") @@ -707,10 +725,11 @@ async def test_before_tool_callback_logs_correctly( await bq_plugin_inst.before_tool_callback( tool=mock_tool, tool_args={"param": "value"}, tool_context=tool_context ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "TOOL_STARTING") + # UPDATED ASSERTION: Check structured JSON content = json.loads(log_entry["content"]) assert content["tool_name"] == "MyTool" assert content["description"] == "Description" @@ -731,10 +750,11 @@ async def test_after_tool_callback_logs_correctly( tool_context=tool_context, result={"status": "success"}, ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "TOOL_COMPLETED") + # UPDATED ASSERTION: Check structured JSON content = json.loads(log_entry["content"]) assert content["tool_name"] == "MyTool" assert content["result"]["status"] == "success" @@ -755,7 +775,7 @@ async def test_on_model_error_callback_logs_correctly( await bq_plugin_inst.on_model_error_callback( callback_context=callback_context, llm_request=llm_request, error=error ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "LLM_ERROR") assert log_entry["content"] is None @@ -777,7 +797,7 @@ async def test_on_tool_error_callback_logs_correctly( 
tool_context=tool_context, error=error, ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "TOOL_ERROR") @@ -809,7 +829,9 @@ async def test_bigquery_client_initialization_failure( invocation_context=invocation_context, user_message=types.Content(parts=[types.Part(text="Test")]), ) - await asyncio.sleep(0.01) + # Wait for the background task (which logs the error) to complete + await plugin_with_fail.close() + mock_log_error.assert_any_call("BQ Plugin: Init Failed:", exc_info=True) mock_write_client.append_rows.assert_not_called() @@ -832,7 +854,7 @@ async def fake_append_rows_with_error(requests, **kwargs): invocation_context=invocation_context, user_message=types.Content(parts=[types.Part(text="Test")]), ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() mock_log_error.assert_called_with( "BQ Plugin: Write Error: %s", "Test BQ Error" ) @@ -861,7 +883,7 @@ async def fake_append_rows_with_schema_error(requests, **kwargs): invocation_context=invocation_context, user_message=types.Content(parts=[types.Part(text="Test")]), ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() mock_log_error.assert_called_with( "BQ Plugin: Schema Mismatch. You may need to delete the existing" " table if you migrated from STRING content to JSON content." @@ -889,7 +911,7 @@ async def test_before_run_callback_logs_correctly( await bq_plugin_inst.before_run_callback( invocation_context=invocation_context ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "INVOCATION_STARTING") assert log_entry["content"] is None @@ -905,7 +927,7 @@ async def test_after_run_callback_logs_correctly( await bq_plugin_inst.after_run_callback( invocation_context=invocation_context ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "INVOCATION_COMPLETED") assert log_entry["content"] is None @@ -922,7 +944,7 @@ async def test_before_agent_callback_logs_correctly( await bq_plugin_inst.before_agent_callback( agent=mock_agent, callback_context=callback_context ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "AGENT_STARTING") @@ -941,7 +963,7 @@ async def test_after_agent_callback_logs_correctly( await bq_plugin_inst.after_agent_callback( agent=mock_agent, callback_context=callback_context ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "AGENT_COMPLETED") From aa99faa333576f7afca38fcbd6618fbb756caf0a Mon Sep 17 00:00:00 2001 From: Afonso Menegola Date: Wed, 26 Nov 2025 18:48:38 -0300 Subject: [PATCH 03/10] Address review feedback on BQ plugin JSON structure, timestamps, linting --- .../bigquery_agent_analytics_plugin.py | 39 +++++++++++++------ .../test_bigquery_agent_analytics_plugin.py | 6 ++- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index a83a461d0d..825bb1368e 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py 
@@ -28,9 +28,9 @@
 from google.api_core.gapic_v1 import client_info as gapic_client_info
 import google.auth
 from google.cloud import bigquery
+from google.cloud import bigquery_storage_v1
 from google.cloud.bigquery import schema as bq_schema
 from google.cloud.bigquery_storage_v1 import types as bq_storage_types
-from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient
 from google.genai import types
 import pyarrow as pa
 
@@ -221,7 +221,7 @@ class BigQueryLoggerConfig:
   event_allowlist: Optional[List[str]] = None
   event_denylist: Optional[List[str]] = None
   # Custom formatter is discouraged now that we use JSON, but kept for compat
-  content_formatter: Optional[Callable[[Any], str]] = None
+  content_formatter: Optional[Callable[[dict], dict]] = None
   shutdown_timeout: float = 5.0
   client_close_timeout: float = 2.0
   # Increased default limit to 50KB since we truncate per-field, not per-row
@@ -307,7 +307,11 @@ def __init__(
     )
     self._config = config if config else BigQueryLoggerConfig()
     self._bq_client: bigquery.Client | None = None
-    self._write_client: BigQueryWriteAsyncClient | None = None
+    # Reference the write client class via the top-level package import.
+    self._write_client: (
+        bigquery_storage_v1.services.big_query_write.async_client.BigQueryWriteAsyncClient
+        | None
+    ) = None
     self._init_lock: asyncio.Lock | None = None
     self._arrow_schema: pa.Schema | None = None
     self._background_tasks: set[asyncio.Task] = set()
@@ -407,7 +411,8 @@ def create_resources():
 
     await asyncio.to_thread(create_resources)
 
-    self._write_client = BigQueryWriteAsyncClient(
+    # Import via the top-level package to avoid a "cli" substring in the path.
+    self._write_client = bigquery_storage_v1.services.big_query_write.async_client.BigQueryWriteAsyncClient(
        credentials=creds,
        client_info=client_info,
    )
@@ -446,7 +451,11 @@ async def _perform_write(self, row: dict):
     ):
       if resp.error.code != 0:
         msg = resp.error.message
-        if "schema mismatch" in msg.lower():
+        if (
+            "schema mismatch" in msg.lower()
+            or "field" in msg.lower()
+            or "type" in msg.lower()
+        ):
           logging.error(
               "BQ Plugin: Schema Mismatch. You may need to delete the"
               " existing table if you migrated from STRING content to JSON"
@@ -462,7 +471,7 @@
     except asyncio.CancelledError:
       if not self._is_shutting_down:
         logging.warning("BQ Plugin: Write task cancelled unexpectedly.")
-    except Exception:
-      logging.error("BQ Plugin: Write Failed:", exc_info=True)
+    except Exception as e:
+      logging.error("BQ Plugin: Write Failed: %s", e, exc_info=True)
 
   async def _log(self, data: dict, content_payload: Any = None):
@@ -657,6 +666,7 @@ async def on_event_callback(
         "invocation_id": invocation_context.invocation_id,
         "user_id": invocation_context.session.user_id,
         "error_message": event.error_message,
+        "timestamp": datetime.fromtimestamp(event.timestamp, timezone.utc),
         },
         content_payload=payload,
     )
@@ -790,14 +800,14 @@
 
     payload = {
         "model": llm_request.model or "default",
-        "params": params,
+        "params": params if params else None,
         "tools_available": (
            list(llm_request.tools_dict.keys())
            if llm_request.tools_dict
-            else []
+            else None
        ),
         "system_instruction": system_instr,
-        "prompt": prompt_history,
+        "prompt": prompt_history if prompt_history else None,
     }
 
     await self._log(
@@ -888,7 +898,6 @@ async def before_tool_callback(
     If individual string fields exceed `max_content_length`, they are
     truncated to preserve the valid JSON structure.
""" - payload = { "tool_name": tool.name if tool.name else None, "description": tool.description if tool.description else None, @@ -923,7 +932,10 @@ async def after_tool_callback( If individual string fields exceed `max_content_length`, they are truncated to preserve the valid JSON structure. """ - payload = {"tool_name": tool.name if tool.name else None, "result": result if result else None} + payload = { + "tool_name": tool.name if tool.name else None, + "result": result if result else None, + } await self._log( { "event_type": "TOOL_COMPLETED", @@ -977,7 +989,10 @@ async def on_tool_error_callback( If individual string fields exceed `max_content_length`, they are truncated to preserve the valid JSON structure. """ - payload = {"tool_name": tool.name if tool.name else None, "arguments": tool_args if tool_args else None} + payload = { + "tool_name": tool.name if tool.name else None, + "arguments": tool_args if tool_args else None, + } await self._log( { "event_type": "TOOL_ERROR", diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 3301c1d833..fb9e8cc749 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -123,8 +123,10 @@ def mock_bq_client(): @pytest.fixture def mock_write_client(): - with mock.patch.object( - bigquery_agent_analytics_plugin, "BigQueryWriteAsyncClient", autospec=True + # Updated patch path to match the new import structure in src + with mock.patch( + "google.cloud.bigquery_storage_v1.services.big_query_write.async_client.BigQueryWriteAsyncClient", + autospec=True, ) as mock_cls: mock_client = mock_cls.return_value mock_client.transport = mock.AsyncMock() From 04d278dab1b0d2dd60b116bf6672fff7be141cd4 Mon Sep 17 00:00:00 2001 From: Afonso Menegola Date: Wed, 26 Nov 2025 19:01:33 -0300 Subject: [PATCH 04/10] Update src/google/adk/plugins/bigquery_agent_analytics_plugin.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk/plugins/bigquery_agent_analytics_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 825bb1368e..809e683b10 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -236,8 +236,8 @@ def _recursive_smart_truncate(obj: Any, max_len: int) -> Any: return obj elif isinstance(obj, dict): return {k: _recursive_smart_truncate(v, max_len) for k, v in obj.items()} - elif isinstance(obj, list): - return [_recursive_smart_truncate(i, max_len) for i in obj] + elif isinstance(obj, (list, tuple)): + return type(obj)(_recursive_smart_truncate(i, max_len) for i in obj) else: return obj From 892cb16c1a7823f77bc85159081b423bd45c7624 Mon Sep 17 00:00:00 2001 From: Afonso Menegola Date: Wed, 26 Nov 2025 19:04:34 -0300 Subject: [PATCH 05/10] Add missing content formatter error test case --- .../test_bigquery_agent_analytics_plugin.py | 86 +++++++++++++++---- 1 file changed, 71 insertions(+), 15 deletions(-) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index fb9e8cc749..a252ed125a 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ 
b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -297,22 +297,26 @@ async def test_event_allowlist( await plugin.close() # Wait for write mock_write_client.append_rows.assert_called_once() mock_write_client.append_rows.reset_mock() - - # Re-init plugin logic since close() shuts it down, but for this test we want to test denial - # However, close() cleans up clients. We should probably create a new plugin or just check that the task was not created. - # But on_user_message_callback will try to log. - # To keep it simple, let's just use a fresh plugin for the second part or assume close() resets state enough to re-run _ensure_init if needed, - # but _ensure_init is called inside _perform_write. - # Actually, close() sets _is_shutting_down to True, so further logs are ignored. - # So we need a new plugin instance or reset _is_shutting_down. + # Re-init plugin logic since close() shuts it down plugin._is_shutting_down = False + # REFACTOR: Use a fresh plugin instance for the denied case + plugin_denied = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + PROJECT_ID, DATASET_ID, TABLE_ID, config + ) + ) + await plugin_denied._ensure_init() + # Inject the same mock_write_client + plugin_denied._write_client = mock_write_client + plugin_denied._arrow_schema = plugin._arrow_schema + user_message = types.Content(parts=[types.Part(text="What is up?")]) - await plugin.on_user_message_callback( + await plugin_denied.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) # Since it's denied, no task is created. close() would wait if there was one. - await plugin.close() + await plugin_denied.close() mock_write_client.append_rows.assert_not_called() @pytest.mark.asyncio @@ -330,6 +334,8 @@ async def test_event_denylist( plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) + # Reset for next call + plugin._is_shutting_down = False await plugin._ensure_init() mock_write_client.append_rows.reset_mock() @@ -340,11 +346,21 @@ async def test_event_denylist( await plugin.close() mock_write_client.append_rows.assert_not_called() - # Reset for next call - plugin._is_shutting_down = False + # REFACTOR: Use a fresh plugin instance for the allowed case + plugin_allowed = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + PROJECT_ID, DATASET_ID, TABLE_ID, config + ) + ) + await plugin_allowed._ensure_init() + # Inject the same mock_write_client + plugin_allowed._write_client = mock_write_client + plugin_allowed._arrow_schema = plugin._arrow_schema - await plugin.before_run_callback(invocation_context=invocation_context) - await plugin.close() + await plugin_allowed.before_run_callback( + invocation_context=invocation_context + ) + await plugin_allowed.close() mock_write_client.append_rows.assert_called_once() @pytest.mark.asyncio @@ -399,6 +415,44 @@ def mutate_payload(data): assert content["model"] == "GEMINI-PRO" assert content["prompt"][0]["role"] == "user" + @pytest.mark.asyncio + async def test_content_formatter_error_fallback( + self, + mock_write_client, + invocation_context, + mock_auth_default, + mock_bq_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + """Tests that if content_formatter fails, the original payload is used.""" + + def error_formatter(data): + raise ValueError("Formatter failed") + + config = BigQueryLoggerConfig(content_formatter=error_formatter) + plugin = 
bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
+        PROJECT_ID, DATASET_ID, TABLE_ID, config
+    )
+    await plugin._ensure_init()
+    mock_write_client.append_rows.reset_mock()
+
+    user_message = types.Content(parts=[types.Part(text="Original message")])
+
+    # The formatter raises, but the plugin catches it and logs the original payload.
+    await plugin.on_user_message_callback(
+        invocation_context=invocation_context, user_message=user_message
+    )
+    await plugin.close()
+
+    mock_write_client.append_rows.assert_called_once()
+    log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
+
+    # Verify that despite the error, we still got the original data
+    content = json.loads(log_entry["content"])
+    assert content["text"] == "Original message"
+
   @pytest.mark.asyncio
   async def test_max_content_length_smart_truncation(
       self,
@@ -725,7 +779,9 @@ async def test_before_tool_callback_logs_correctly(
     type(mock_tool).name = mock.PropertyMock(return_value="MyTool")
     type(mock_tool).description = mock.PropertyMock(return_value="Description")
     await bq_plugin_inst.before_tool_callback(
-        tool=mock_tool, tool_args={"param": "value"}, tool_context=tool_context
+        tool=mock_tool,
+        tool_args={"param": "value"},
+        tool_context=tool_context,
     )
     await bq_plugin_inst.close()
     log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)

From ec778ab5ac8dceb005e347219cfc50ce121aa7b7 Mon Sep 17 00:00:00 2001
From: Afonso Menegola
Date: Wed, 26 Nov 2025 19:23:57 -0300
Subject: [PATCH 06/10] fix: Restore version import for the user-agent string

---
 src/google/adk/plugins/bigquery_agent_analytics_plugin.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
index 809e683b10..a09643f2fa 100644
--- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
+++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
@@ -34,6 +34,7 @@
 from google.genai import types
 import pyarrow as pa
 
+from ..
import version
 from ..agents.base_agent import BaseAgent
 from ..agents.callback_context import CallbackContext
 from ..events.event import Event
@@ -385,7 +386,7 @@ async def _ensure_init(self):
         scopes=["https://www.googleapis.com/auth/cloud-platform"],
     )
     client_info = gapic_client_info.ClientInfo(
-        user_agent="google-adk-bq-logger"
+        user_agent=f"google-adk-bq-logger/{version.__version__}"
     )
     self._bq_client = bigquery.Client(
         project=self._project_id, credentials=creds, client_info=client_info

From e1ca8740674f510dfe85a58973629af89b1637b3 Mon Sep 17 00:00:00 2001
From: Afonso Menegola
Date: Wed, 26 Nov 2025 19:28:42 -0300
Subject: [PATCH 07/10] Log function call arguments alongside the call name

---
 src/google/adk/plugins/bigquery_agent_analytics_plugin.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
index a09643f2fa..344c80ef04 100644
--- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
+++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
@@ -790,9 +790,11 @@ async def before_model_callback(
       if p.text:
         parts_list.append({"type": "text", "text": p.text})
       elif p.function_call:
-        parts_list.append(
-            {"type": "function_call", "name": p.function_call.name}
-        )
+        parts_list.append({
+            "type": "function_call",
+            "name": dict(p.function_call.args) if p.function_call.args else None and p.function_call.name,
+            "args": dict(p.function_call.args) if p.function_call.args else None,
+        })
       elif p.function_response:
         parts_list.append(
             {"type": "function_response", "name": p.function_response.name}

From 685b212054037d67605c54213c38121e50119554 Mon Sep 17 00:00:00 2001
From: Afonso Menegola
Date: Wed, 26 Nov 2025 19:40:52 -0300
Subject: [PATCH 08/10] Remove _is_shutting_down resets from tests; restore
 timestamp asserts

---
 .../plugins/test_bigquery_agent_analytics_plugin.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
index a252ed125a..7b6e79ead3 100644
--- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
+++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
@@ -297,8 +297,6 @@ async def test_event_allowlist(
     await plugin.close()  # Wait for write
     mock_write_client.append_rows.assert_called_once()
     mock_write_client.append_rows.reset_mock()
-    # Re-init plugin logic since close() shuts it down
-    plugin._is_shutting_down = False
 
     # REFACTOR: Use a fresh plugin instance for the denied case
     plugin_denied = (
@@ -334,8 +332,6 @@ async def test_event_denylist(
     plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
         PROJECT_ID, DATASET_ID, TABLE_ID, config
     )
-    # Reset for next call
-    plugin._is_shutting_down = False
     await plugin._ensure_init()
     mock_write_client.append_rows.reset_mock()
@@ -648,6 +644,9 @@ async def test_on_event_callback_tool_call(
     content = json.loads(log_entry["content"])
     assert content["raw_role"] == "MyTestAgent"
     assert content["tool_calls"] == ["get_weather"]
+    assert log_entry["timestamp"] == datetime.datetime(
+        2025, 10, 22, 10, 0, 0, tzinfo=datetime.timezone.utc
+    )
 
   @pytest.mark.asyncio
   async def test_on_event_callback_model_response(
@@ -673,6 +672,9 @@ async def test_on_event_callback_model_response(
 
     content = json.loads(log_entry["content"])
     assert content["text"] == "Hello there!"
+    assert log_entry["timestamp"] == datetime.datetime(
+        2025, 10, 22, 11, 0, 0, tzinfo=datetime.timezone.utc
+    )
 
   @pytest.mark.asyncio
   async def test_before_model_callback_logs_structure(

From 342c62d6019121d3a63a266065de7da46f0b91db Mon Sep 17 00:00:00 2001
From: Afonso Menegola
Date: Wed, 26 Nov 2025 19:47:03 -0300
Subject: [PATCH 09/10] Update src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 src/google/adk/plugins/bigquery_agent_analytics_plugin.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
index 344c80ef04..531fa37c99 100644
--- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
+++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
@@ -647,9 +647,9 @@ async def on_event_callback(
       for p in event.content.parts:
         if p.text:
           text_parts.append(p.text)
-        if p.function_call:
+        elif p.function_call:
           tool_calls.append(p.function_call.name)
-        if p.function_response:
+        elif p.function_response:
           tool_responses.append(p.function_response.name)

From 5e53d85727f87c692b1a46af95548b4824926f35 Mon Sep 17 00:00:00 2001
From: Afonso Menegola
Date: Wed, 26 Nov 2025 20:15:06 -0300
Subject: [PATCH 10/10] improving logs with file_data and inline_data

---
 .../bigquery_agent_analytics_plugin.py        | 43 ++++++++++++++-----
 .../test_bigquery_agent_analytics_plugin.py   |  4 +-
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
index 344c80ef04..ed2a6dd697 100644
--- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
+++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
@@ -633,27 +633,48 @@ async def on_event_callback(
     1. Event type (determined from event properties)
     2. Event content (text, function calls, or responses)
     3. Error messages (if any)
-
-    The content is formatted as a structured JSON object based on the event type.
-    If individual string fields exceed `max_content_length`, they are truncated
-    to preserve the valid JSON structure.
     """
-    # We try to extract text, but keep it simple for generic events
-    text_parts = []
+    # content_parts holds one structured dict per part of the event.
+    content_parts = []
+
+    # tool_calls and tool_responses are kept as separate high-level
+    # summaries of the structured parts.
     tool_calls = []
     tool_responses = []
 
     if event.content and event.content.parts:
       for p in event.content.parts:
         if p.text:
-          text_parts.append(p.text)
-        if p.function_call:
+          content_parts.append({"type": "text", "text": p.text})
+        elif p.function_call:
+          content_parts.append({
+              "type": "function_call",
+              "name": p.function_call.name,
+              "args": dict(p.function_call.args) if p.function_call.args else None,
+          })
+          # Also record the call name in the high-level summary list.
          tool_calls.append(p.function_call.name)
-        if p.function_response:
+        elif p.function_response:
+          content_parts.append(
+              {"type": "function_response", "name": p.function_response.name}
+          )
+          # Also record the response name in the high-level summary list.
          tool_responses.append(p.function_response.name)
+        elif p.inline_data:
+          content_parts.append({
+              "type": "inline_data",
+              "mime_type": p.inline_data.mime_type,
+          })
+        elif p.file_data:
+          content_parts.append({
+              "type": "file_data",
+              "mime_type": p.file_data.mime_type,
+              "file_uri": p.file_data.file_uri,
+          })
 
     payload = {
-        "text": " ".join(text_parts) if text_parts else None,
+        # Store the structured list of part dicts instead of joined text.
+        "content_parts": content_parts if content_parts else None,
         "tool_calls": tool_calls if tool_calls else None,
         "tool_responses": tool_responses if tool_responses else None,
         "raw_role": event.author if event.author else None,
@@ -844,7 +865,7 @@ async def after_model_callback(
       for p in llm_response.content.parts:
         if p.text:
           content_parts.append({"type": "text", "text": p.text})
-        if p.function_call:
+        elif p.function_call:
           content_parts.append({
               "type": "function_call",
               "name": p.function_call.name,
diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
index 7b6e79ead3..6790faf680 100644
--- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
+++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
@@ -671,7 +671,9 @@ async def test_on_event_callback_model_response(
     _assert_common_fields(log_entry, "MODEL_RESPONSE", agent="MyTestAgent")
 
     content = json.loads(log_entry["content"])
-    assert content["text"] == "Hello there!"
+    assert content["content_parts"][0]["type"] == "text"
+    assert content["content_parts"][0]["text"] == "Hello there!"
+
     assert log_entry["timestamp"] == datetime.datetime(
         2025, 10, 22, 11, 0, 0, tzinfo=datetime.timezone.utc
     )
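
To make the final on_event_callback payload concrete, here is a hedged sketch of the JSON a mixed-part event would serialize to under PATCH 10. The field names come from the code above; the part values, tool name, and file URI are invented for illustration:

    import json

    # Hypothetical mixed-part event payload, as written to the JSON column.
    payload = {
        "content_parts": [
            {"type": "text", "text": "Here is the chart you asked for."},
            {
                "type": "function_call",
                "name": "get_weather",
                "args": {"location": "Paris"},
            },
            {"type": "inline_data", "mime_type": "image/png"},
            {
                "type": "file_data",
                "mime_type": "text/csv",
                "file_uri": "gs://example-bucket/report.csv",
            },
        ],
        "tool_calls": ["get_weather"],
        "tool_responses": None,
        "raw_role": "MyTestAgent",
    }
    print(json.dumps(payload, indent=2))
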