diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index dfb010c332..0c12d39a9c 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -50,6 +50,8 @@ from google.cloud.bigquery_storage_v1 import types as bq_storage_types from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient from google.genai import types +from opentelemetry import context +from opentelemetry import trace import pyarrow as pa from ..agents.callback_context import CallbackContext @@ -64,6 +66,9 @@ from ..agents.invocation_context import InvocationContext logger: logging.Logger = logging.getLogger("google_adk." + __name__) +tracer = trace.get_tracer( + "google.adk.plugins.bigquery_agent_analytics", __version__ +) # gRPC Error Codes @@ -105,69 +110,96 @@ def _format_content( return " | ".join(parts), truncated -def _recursive_smart_truncate(obj: Any, max_len: int) -> tuple[Any, bool]: +def _recursive_smart_truncate( + obj: Any, max_len: int, seen: Optional[set[int]] = None +) -> tuple[Any, bool]: """Recursively truncates string values within a dict or list. Args: obj: The object to truncate. max_len: Maximum length for string values. + seen: Set of object IDs visited in the current recursion stack. Returns: A tuple of (truncated_object, is_truncated). """ - if isinstance(obj, str): - if max_len != -1 and len(obj) > max_len: - return obj[:max_len] + "...[TRUNCATED]", True - return obj, False - elif isinstance(obj, dict): - truncated_any = False - # Use dict comprehension for potentially slightly better performance, - # but explicit loop is fine for clarity given recursive nature. - new_dict = {} - for k, v in obj.items(): - val, trunc = _recursive_smart_truncate(v, max_len) - if trunc: - truncated_any = True - new_dict[k] = val - return new_dict, truncated_any - elif isinstance(obj, (list, tuple)): - truncated_any = False - new_list = [] - # Explicit loop to handle flag propagation - for i in obj: - val, trunc = _recursive_smart_truncate(i, max_len) - if trunc: - truncated_any = True - new_list.append(val) - return type(obj)(new_list), truncated_any - elif dataclasses.is_dataclass(obj) and not isinstance(obj, type): - # Convert dataclasses to dicts so they become valid JSON objects - return _recursive_smart_truncate(dataclasses.asdict(obj), max_len) - elif hasattr(obj, "model_dump") and callable(obj.model_dump): - # Pydantic v2 - try: - return _recursive_smart_truncate(obj.model_dump(), max_len) - except Exception: - pass - elif hasattr(obj, "dict") and callable(obj.dict): - # Pydantic v1 - try: - return _recursive_smart_truncate(obj.dict(), max_len) - except Exception: - pass - elif hasattr(obj, "to_dict") and callable(obj.to_dict): - # Common pattern for custom objects - try: - return _recursive_smart_truncate(obj.to_dict(), max_len) - except Exception: - pass - elif obj is None or isinstance(obj, (int, float, bool)): - # Basic types are safe - return obj, False + if seen is None: + seen = set() + + obj_id = id(obj) + if obj_id in seen: + return "[CIRCULAR_REFERENCE]", False + + # Track compound objects to detect cycles + is_compound = ( + isinstance(obj, (dict, list, tuple)) + or (dataclasses.is_dataclass(obj) and not isinstance(obj, type)) + or hasattr(obj, "model_dump") + or hasattr(obj, "dict") + or hasattr(obj, "to_dict") + ) + + if is_compound: + seen.add(obj_id) + + try: + if isinstance(obj, str): + if max_len != -1 and len(obj) > max_len: + return obj[:max_len] + "...[TRUNCATED]", True + return obj, False + elif isinstance(obj, dict): + truncated_any = False + # Use dict comprehension for potentially slightly better performance, + # but explicit loop is fine for clarity given recursive nature. + new_dict = {} + for k, v in obj.items(): + val, trunc = _recursive_smart_truncate(v, max_len, seen) + if trunc: + truncated_any = True + new_dict[k] = val + return new_dict, truncated_any + elif isinstance(obj, (list, tuple)): + truncated_any = False + new_list = [] + # Explicit loop to handle flag propagation + for i in obj: + val, trunc = _recursive_smart_truncate(i, max_len, seen) + if trunc: + truncated_any = True + new_list.append(val) + return type(obj)(new_list), truncated_any + elif dataclasses.is_dataclass(obj) and not isinstance(obj, type): + # Manually iterate fields to preserve 'seen' context, avoiding dataclasses.asdict recursion + as_dict = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)} + return _recursive_smart_truncate(as_dict, max_len, seen) + elif hasattr(obj, "model_dump") and callable(obj.model_dump): + # Pydantic v2 + try: + return _recursive_smart_truncate(obj.model_dump(), max_len, seen) + except Exception: + pass + elif hasattr(obj, "dict") and callable(obj.dict): + # Pydantic v1 + try: + return _recursive_smart_truncate(obj.dict(), max_len, seen) + except Exception: + pass + elif hasattr(obj, "to_dict") and callable(obj.to_dict): + # Common pattern for custom objects + try: + return _recursive_smart_truncate(obj.to_dict(), max_len, seen) + except Exception: + pass + elif obj is None or isinstance(obj, (int, float, bool)): + # Basic types are safe + return obj, False - # Fallback for unknown types: Convert to string to ensure JSON validity - # We return string representation of the object, which is a valid JSON string value. - return str(obj), False + # Fallback for unknown types: Convert to string to ensure JSON validity + # We return string representation of the object, which is a valid JSON string value. + return str(obj), False + finally: + if is_compound: + seen.remove(obj_id) # --- PyArrow Helper Functions --- @@ -382,16 +414,26 @@ class BigQueryLoggerConfig: # HELPER: TRACE MANAGER (Async-Safe with ContextVars) # ============================================================================== -_trace_id_ctx = contextvars.ContextVar("_bq_analytics_trace_id", default=None) _root_agent_name_ctx = contextvars.ContextVar( "_bq_analytics_root_agent_name", default=None ) -_span_stack_ctx = contextvars.ContextVar("_bq_analytics_span_stack", default=()) -_span_times_ctx = contextvars.ContextVar( - "_bq_analytics_span_times", default=None +_span_stack_ctx: contextvars.ContextVar[list[trace.Span]] = ( + contextvars.ContextVar("_bq_analytics_span_stack", default=None) ) -_span_first_token_times_ctx = contextvars.ContextVar( - "_bq_analytics_span_first_token_times", default=None +_span_token_stack_ctx: contextvars.ContextVar[list[trace.Token]] = ( + contextvars.ContextVar("_bq_analytics_span_token_stack", default=None) +) +_span_first_token_times_ctx: contextvars.ContextVar[dict[str, float]] = ( + contextvars.ContextVar("_bq_analytics_span_first_token_times", default=None) +) +_span_map_ctx: contextvars.ContextVar[dict[str, trace.Span]] = ( + contextvars.ContextVar("_bq_analytics_span_map", default=None) +) +_span_id_stack_ctx: contextvars.ContextVar[list[str]] = contextvars.ContextVar( + "_bq_analytics_span_id_stack", default=None +) +_span_start_time_ctx: contextvars.ContextVar[dict[str, int]] = ( + contextvars.ContextVar("_bq_analytics_span_start_time", default=None) ) @@ -400,75 +442,176 @@ class TraceManager: @staticmethod def init_trace(callback_context: CallbackContext) -> None: - if _trace_id_ctx.get() is None: - _trace_id_ctx.set(callback_context.invocation_id) - # Extract root agent name from invocation context + # Extract root agent name from invocation context if not set + if _root_agent_name_ctx.get() is None: try: root_agent = callback_context._invocation_context.agent.root_agent _root_agent_name_ctx.set(root_agent.name) except (AttributeError, ValueError): pass - _span_stack_ctx.set(()) - _span_times_ctx.set({}) + + if _span_first_token_times_ctx.get() is None: _span_first_token_times_ctx.set({}) + if _span_map_ctx.get() is None: + _span_map_ctx.set({}) + + if _span_start_time_ctx.get() is None: + _span_start_time_ctx.set({}) + @staticmethod def get_trace_id(callback_context: CallbackContext) -> Optional[str]: - # Try contextvars first - if trace_id := _trace_id_ctx.get(): - return trace_id - # Fallback to callback_context for existing tests/legacy flows - return callback_context.state.get("_bq_analytics_trace_id") + """Gets the trace ID from the current span or invocation_id.""" + # Prefer internal stack if available + stack = _span_stack_ctx.get() + if stack: + current_span = stack[-1] + if current_span.get_span_context().is_valid: + return format(current_span.get_span_context().trace_id, "032x") + + # Fallback to OTel context to satisfy "Trace Context Extraction" requirement + current_span = trace.get_current_span() + if current_span.get_span_context().is_valid: + return format(current_span.get_span_context().trace_id, "032x") + + return callback_context.invocation_id @staticmethod def push_span( - callback_context: CallbackContext, span_id: Optional[str] = None + callback_context: CallbackContext, span_name: Optional[str] = "adk-span" ) -> str: - # Ensure trace is initialized - if _trace_id_ctx.get() is None: - TraceManager.init_trace(callback_context) + """Starts a new span and pushes it onto the stack. - span_id = span_id or str(uuid.uuid4()) + If OTel is not configured (returning non-recording spans), a UUID fallback + is generated to ensure span_id and parent_span_id are populated in logs. + """ + # Ensure init_trace logic (root agent name) runs if needed + TraceManager.init_trace(callback_context) - stack = _span_stack_ctx.get() - new_stack = stack + (span_id,) + span = tracer.start_span(span_name) + token = context.attach(trace.set_span_in_context(span)) + + stack = _span_stack_ctx.get() or [] + new_stack = list(stack) + new_stack.append(span) _span_stack_ctx.set(new_stack) - times = dict(_span_times_ctx.get() or {}) - times[span_id] = time.time() - _span_times_ctx.set(times) - return span_id + token_stack = _span_token_stack_ctx.get() or [] + new_token_stack = list(token_stack) + new_token_stack.append(token) + _span_token_stack_ctx.set(new_token_stack) + + if span.get_span_context().is_valid: + span_id_str = format(span.get_span_context().span_id, "016x") + else: + # Fallback: Generate a UUID-based ID if OTel span is invalid (NoOp) + # using 32-char hex to avoid collision, treated as string in BQ. + span_id_str = uuid.uuid4().hex + + id_stack = _span_id_stack_ctx.get() or [] + new_id_stack = list(id_stack) + new_id_stack.append(span_id_str) + _span_id_stack_ctx.set(new_id_stack) + + span_map = _span_map_ctx.get() or {} + new_span_map = span_map.copy() + new_span_map[span_id_str] = span + _span_map_ctx.set(new_span_map) + + # Record start time manually for fallback support (NoOpSpan lacks start_time) + start_times = _span_start_time_ctx.get() or {} + new_start_times = start_times.copy() + new_start_times[span_id_str] = time.time_ns() + _span_start_time_ctx.set(new_start_times) + + return span_id_str @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: - stack = list(_span_stack_ctx.get()) - if not stack: + """Ends the current span and pops it from the stack.""" + stack = _span_stack_ctx.get() + token_stack = _span_token_stack_ctx.get() + + if not stack or not token_stack: return None, None - span_id = stack.pop() - _span_stack_ctx.set(tuple(stack)) - times_dict = dict(_span_times_ctx.get() or {}) - start_time = times_dict.pop(span_id, None) - _span_times_ctx.set(times_dict) + new_stack = list(stack) + new_token_stack = list(token_stack) - ft_dict = dict(_span_first_token_times_ctx.get() or {}) - ft_dict.pop(span_id, None) - _span_first_token_times_ctx.set(ft_dict) + span = new_stack.pop() + token = new_token_stack.pop() + + _span_stack_ctx.set(new_stack) + _span_token_stack_ctx.set(new_token_stack) + + # Pop from ID stack regarding fallback support + id_stack = _span_id_stack_ctx.get() + if id_stack: + new_id_stack = list(id_stack) + span_id = new_id_stack.pop() + _span_id_stack_ctx.set(new_id_stack) + else: + # Should not happen if stacks are in sync, but robust fallback: + if span.get_span_context().is_valid: + span_id = format(span.get_span_context().span_id, "016x") + else: + span_id = "unknown-id" + + duration_ms = None + # Try getting start time from OTel span first, then fallback to manual tracking + if hasattr(span, "start_time") and span.start_time: + duration_ms = int((time.time_ns() - span.start_time) / 1_000_000) + else: + start_times = _span_start_time_ctx.get() + if start_times and span_id in start_times: + start_ns = start_times[span_id] + duration_ms = int((time.time_ns() - start_ns) / 1_000_000) + + span.end() + context.detach(token) + + first_tokens = _span_first_token_times_ctx.get() + if first_tokens: + # Copy to modify + new_first_tokens = first_tokens.copy() + new_first_tokens.pop(span_id, None) + _span_first_token_times_ctx.set(new_first_tokens) + + span_map = _span_map_ctx.get() + if span_map: + new_span_map = span_map.copy() + new_span_map.pop(span_id, None) + _span_map_ctx.set(new_span_map) + + start_times = _span_start_time_ctx.get() + if start_times: + new_start_times = start_times.copy() + new_start_times.pop(span_id, None) + _span_start_time_ctx.set(new_start_times) - duration_ms = int((time.time() - start_time) * 1000) if start_time else None return span_id, duration_ms @staticmethod def get_current_span_and_parent() -> tuple[Optional[str], Optional[str]]: - stack = _span_stack_ctx.get() - if not stack: - return None, None - return stack[-1], (stack[-2] if len(stack) > 1 else None) + """Gets current span_id and parent span_id from OTEL context or fallback stack.""" + # Use internal ID stack for robust resolution (handling both OTel and fallback IDs) + id_stack = _span_id_stack_ctx.get() + if id_stack: + span_id = id_stack[-1] + parent_id = None + if len(id_stack) > 1: + parent_id = id_stack[-2] + return span_id, parent_id + + return None, None @staticmethod def get_current_span_id() -> Optional[str]: - stack = _span_stack_ctx.get() - return stack[-1] if stack else None + """Gets current span_id from OTEL context or fallback stack.""" + id_stack = _span_id_stack_ctx.get() + if id_stack: + return id_stack[-1] + return None @staticmethod def get_root_agent_name() -> Optional[str]: @@ -476,25 +619,40 @@ def get_root_agent_name() -> Optional[str]: @staticmethod def get_start_time(span_id: str) -> Optional[float]: - times = _span_times_ctx.get() - return times.get(span_id) if times else None + """Gets start time of a span by ID.""" + # Try OTel Object first + span_map = _span_map_ctx.get() + if span_map: + span = span_map.get(span_id) + if ( + span + and span.get_span_context().is_valid + and hasattr(span, "start_time") + ): + return span.start_time / 1_000_000_000.0 + + # Fallback to manual start time + start_times = _span_start_time_ctx.get() + if start_times and span_id in start_times: + return start_times[span_id] / 1_000_000_000.0 + + return None @staticmethod def record_first_token(span_id: str) -> bool: - """Records the current time as first token time if not already recorded. + """Records the current time as first token time if not already recorded.""" + first_tokens = _span_first_token_times_ctx.get() - Returns: - True if this was the first token (newly recorded), False otherwise. - """ - first_tokens = dict(_span_first_token_times_ctx.get() or {}) if span_id not in first_tokens: - first_tokens[span_id] = time.time() - _span_first_token_times_ctx.set(first_tokens) + new_first_tokens = first_tokens.copy() + new_first_tokens[span_id] = time.time() + _span_first_token_times_ctx.set(new_first_tokens) return True return False @staticmethod def get_first_token_time(span_id: str) -> Optional[float]: + """Gets the recorded first token time.""" first_tokens = _span_first_token_times_ctx.get() return first_tokens.get(span_id) if first_tokens else None @@ -1800,7 +1958,7 @@ async def before_agent_callback( callback_context: The callback context. """ TraceManager.init_trace(callback_context) - TraceManager.push_span(callback_context) + TraceManager.push_span(callback_context, "agent") await self._log_event( "AGENT_STARTING", callback_context, @@ -1870,7 +2028,7 @@ async def before_model_callback( # Merge any additional kwargs into attributes attributes.update(kwargs) - TraceManager.push_span(callback_context) + TraceManager.push_span(callback_context, "llm") await self._log_event( "LLM_REQUEST", callback_context, @@ -2025,7 +2183,7 @@ async def before_tool_callback( tool_args, self.config.max_content_length ) content_dict = {"tool": tool.name, "args": args_truncated} - TraceManager.push_span(tool_context) + TraceManager.push_span(tool_context, "tool") await self._log_event( "TOOL_STARTING", tool_context, diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index ed19ee7256..5f2474c0a1 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -37,6 +37,7 @@ from google.cloud import bigquery from google.cloud import exceptions as cloud_exceptions from google.genai import types +from opentelemetry import trace import pyarrow as pa import pytest @@ -502,28 +503,34 @@ async def test_concurrent_span_management( bigquery_agent_analytics_plugin.TraceManager.init_trace(callback_context) async def branch_1(): - bigquery_agent_analytics_plugin.TraceManager.push_span( - callback_context, span_id="span-1" + s_id = bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, span_name="span-1" ) await asyncio.sleep(0.02) - s_id = bigquery_agent_analytics_plugin.TraceManager.get_current_span_id() + current_s_id = ( + bigquery_agent_analytics_plugin.TraceManager.get_current_span_id() + ) + assert s_id == current_s_id bigquery_agent_analytics_plugin.TraceManager.pop_span() return s_id async def branch_2(): - bigquery_agent_analytics_plugin.TraceManager.push_span( - callback_context, span_id="span-2" + s_id = bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, span_name="span-2" ) await asyncio.sleep(0.02) - s_id = bigquery_agent_analytics_plugin.TraceManager.get_current_span_id() + current_s_id = ( + bigquery_agent_analytics_plugin.TraceManager.get_current_span_id() + ) + assert s_id == current_s_id bigquery_agent_analytics_plugin.TraceManager.pop_span() return s_id # Run concurrently results = await asyncio.gather(branch_1(), branch_2()) # If they shared the same list/dict, they would interfere. - assert "span-1" in results - assert "span-2" in results + assert results[0] is not None + assert results[1] is not None assert results[0] != results[1] @pytest.mark.asyncio @@ -1953,3 +1960,104 @@ class LocalIncident: content_json = json.loads(log_entry["content"]) assert content_json["result"]["id"] == "inc-123" assert content_json["result"]["kpi_missed"][0]["kpi"] == "latency" + + @pytest.mark.asyncio + async def test_otel_integration( + self, + callback_context, + ): + """Verifies OpenTelemetry integration in TraceManager.""" + # Mock the tracer and span + mock_tracer = mock.Mock() + mock_span = mock.Mock() + mock_context = mock.Mock() + + # Setup mock IDs (128-bit trace_id, 64-bit span_id) + trace_id_int = 0x12345678123456781234567812345678 + span_id_int = 0x1234567812345678 + + mock_context.trace_id = trace_id_int + mock_context.span_id = span_id_int + mock_context.is_valid = True + + mock_span.get_span_context.return_value = mock_context + mock_span.start_time = 1234567890000000000 # Mock start time in ns + mock_tracer.start_span.return_value = mock_span + + # Patch the global tracer in the plugin module + with mock.patch( + "google.adk.plugins.bigquery_agent_analytics_plugin.tracer", mock_tracer + ): + # Test push_span + span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + + mock_tracer.start_span.assert_called_with("test_span") + assert span_id == format(span_id_int, "016x") + + # Test get_trace_id + # We need to mock trace.get_current_span() to return our mock span + # because push_span calls trace.attach(), which affects the global context + with mock.patch( + "opentelemetry.trace.get_current_span", return_value=mock_span + ): + trace_id = bigquery_agent_analytics_plugin.TraceManager.get_trace_id( + callback_context + ) + assert trace_id == format(trace_id_int, "032x") + + # Test pop_span + # pop_span calls span.end() + bigquery_agent_analytics_plugin.TraceManager.pop_span() + mock_span.end.assert_called_once() + + @pytest.mark.asyncio + async def test_otel_integration_real_provider(self, callback_context): + """Verifies TraceManager with a real OpenTelemetry TracerProvider.""" + # Setup OTEL with in-memory exporter + # pylint: disable=g-import-not-at-top + from opentelemetry.sdk import trace as trace_sdk + from opentelemetry.sdk.trace import export as trace_export + from opentelemetry.sdk.trace.export import in_memory_span_exporter + + # pylint: enable=g-import-not-at-top + + provider = trace_sdk.TracerProvider() + exporter = in_memory_span_exporter.InMemorySpanExporter() + processor = trace_export.SimpleSpanProcessor(exporter) + provider.add_span_processor(processor) + tracer = provider.get_tracer("test_tracer") + + # Patch the global tracer in the plugin module + with mock.patch( + "google.adk.plugins.bigquery_agent_analytics_plugin.tracer", tracer + ): + # 1. Start a span + span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + + # Verify a span was started but not ended + current_spans = exporter.get_finished_spans() + assert not current_spans + + # Verify we can retrieve the trace ID + trace_id = bigquery_agent_analytics_plugin.TraceManager.get_trace_id( + callback_context + ) + assert trace_id is not None + + # 2. End the span + popped_span_id, _ = ( + bigquery_agent_analytics_plugin.TraceManager.pop_span() + ) + + assert popped_span_id == span_id + + # Verify span is now finished and exported + finished_spans = exporter.get_finished_spans() + assert len(finished_spans) == 1 + assert finished_spans[0].name == "test_span" + assert format(finished_spans[0].context.span_id, "016x") == span_id + assert format(finished_spans[0].context.trace_id, "032x") == trace_id