From ab89d1283430041afb303834749869e9ee331721 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 21 Jan 2026 01:02:11 -0800 Subject: [PATCH] refactor(plugins)!: use OpenTelemetry for BigQuery plugin tracing This refactors the BigQueryAgentAnalyticsPlugin to use the standard OpenTelemetry API for trace and span ID generation and propagation, replacing the custom ContextVar implementation. Key changes: - Utilizes `opentelemetry.trace` for starting/ending spans. - Correctly uses `opentelemetry.context` for context attachment and detachment. - Span information is now derived from the OpenTelemetry context when available. - Added a fallback mechanism to ensure span_id and parent_span_id are still populated if the OpenTelemetry SDK is not initialized. To get standard OpenTelemetry trace information in BigQuery logs, users should install `opentelemetry-sdk` and initialize a global `TracerProvider` in their application *before* initializing ADK components. Example minimal initialization: ```python # Install: pip install opentelemetry-sdk from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider trace.set_tracer_provider(TracerProvider()) ``` PiperOrigin-RevId: 858965562 --- .../bigquery_agent_analytics_plugin.py | 376 +++++++++++++----- .../test_bigquery_agent_analytics_plugin.py | 124 +++++- 2 files changed, 383 insertions(+), 117 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index dfb010c332..0c12d39a9c 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -50,6 +50,8 @@ from google.cloud.bigquery_storage_v1 import types as bq_storage_types from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient from google.genai import types +from opentelemetry import context +from opentelemetry import trace import pyarrow as pa from ..agents.callback_context import CallbackContext @@ -64,6 +66,9 @@ from ..agents.invocation_context import InvocationContext logger: logging.Logger = logging.getLogger("google_adk." + __name__) +tracer = trace.get_tracer( + "google.adk.plugins.bigquery_agent_analytics", __version__ +) # gRPC Error Codes @@ -105,69 +110,96 @@ def _format_content( return " | ".join(parts), truncated -def _recursive_smart_truncate(obj: Any, max_len: int) -> tuple[Any, bool]: +def _recursive_smart_truncate( + obj: Any, max_len: int, seen: Optional[set[int]] = None +) -> tuple[Any, bool]: """Recursively truncates string values within a dict or list. Args: obj: The object to truncate. max_len: Maximum length for string values. + seen: Set of object IDs visited in the current recursion stack. Returns: A tuple of (truncated_object, is_truncated). """ - if isinstance(obj, str): - if max_len != -1 and len(obj) > max_len: - return obj[:max_len] + "...[TRUNCATED]", True - return obj, False - elif isinstance(obj, dict): - truncated_any = False - # Use dict comprehension for potentially slightly better performance, - # but explicit loop is fine for clarity given recursive nature. - new_dict = {} - for k, v in obj.items(): - val, trunc = _recursive_smart_truncate(v, max_len) - if trunc: - truncated_any = True - new_dict[k] = val - return new_dict, truncated_any - elif isinstance(obj, (list, tuple)): - truncated_any = False - new_list = [] - # Explicit loop to handle flag propagation - for i in obj: - val, trunc = _recursive_smart_truncate(i, max_len) - if trunc: - truncated_any = True - new_list.append(val) - return type(obj)(new_list), truncated_any - elif dataclasses.is_dataclass(obj) and not isinstance(obj, type): - # Convert dataclasses to dicts so they become valid JSON objects - return _recursive_smart_truncate(dataclasses.asdict(obj), max_len) - elif hasattr(obj, "model_dump") and callable(obj.model_dump): - # Pydantic v2 - try: - return _recursive_smart_truncate(obj.model_dump(), max_len) - except Exception: - pass - elif hasattr(obj, "dict") and callable(obj.dict): - # Pydantic v1 - try: - return _recursive_smart_truncate(obj.dict(), max_len) - except Exception: - pass - elif hasattr(obj, "to_dict") and callable(obj.to_dict): - # Common pattern for custom objects - try: - return _recursive_smart_truncate(obj.to_dict(), max_len) - except Exception: - pass - elif obj is None or isinstance(obj, (int, float, bool)): - # Basic types are safe - return obj, False + if seen is None: + seen = set() + + obj_id = id(obj) + if obj_id in seen: + return "[CIRCULAR_REFERENCE]", False + + # Track compound objects to detect cycles + is_compound = ( + isinstance(obj, (dict, list, tuple)) + or (dataclasses.is_dataclass(obj) and not isinstance(obj, type)) + or hasattr(obj, "model_dump") + or hasattr(obj, "dict") + or hasattr(obj, "to_dict") + ) + + if is_compound: + seen.add(obj_id) + + try: + if isinstance(obj, str): + if max_len != -1 and len(obj) > max_len: + return obj[:max_len] + "...[TRUNCATED]", True + return obj, False + elif isinstance(obj, dict): + truncated_any = False + # Use dict comprehension for potentially slightly better performance, + # but explicit loop is fine for clarity given recursive nature. + new_dict = {} + for k, v in obj.items(): + val, trunc = _recursive_smart_truncate(v, max_len, seen) + if trunc: + truncated_any = True + new_dict[k] = val + return new_dict, truncated_any + elif isinstance(obj, (list, tuple)): + truncated_any = False + new_list = [] + # Explicit loop to handle flag propagation + for i in obj: + val, trunc = _recursive_smart_truncate(i, max_len, seen) + if trunc: + truncated_any = True + new_list.append(val) + return type(obj)(new_list), truncated_any + elif dataclasses.is_dataclass(obj) and not isinstance(obj, type): + # Manually iterate fields to preserve 'seen' context, avoiding dataclasses.asdict recursion + as_dict = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)} + return _recursive_smart_truncate(as_dict, max_len, seen) + elif hasattr(obj, "model_dump") and callable(obj.model_dump): + # Pydantic v2 + try: + return _recursive_smart_truncate(obj.model_dump(), max_len, seen) + except Exception: + pass + elif hasattr(obj, "dict") and callable(obj.dict): + # Pydantic v1 + try: + return _recursive_smart_truncate(obj.dict(), max_len, seen) + except Exception: + pass + elif hasattr(obj, "to_dict") and callable(obj.to_dict): + # Common pattern for custom objects + try: + return _recursive_smart_truncate(obj.to_dict(), max_len, seen) + except Exception: + pass + elif obj is None or isinstance(obj, (int, float, bool)): + # Basic types are safe + return obj, False - # Fallback for unknown types: Convert to string to ensure JSON validity - # We return string representation of the object, which is a valid JSON string value. - return str(obj), False + # Fallback for unknown types: Convert to string to ensure JSON validity + # We return string representation of the object, which is a valid JSON string value. + return str(obj), False + finally: + if is_compound: + seen.remove(obj_id) # --- PyArrow Helper Functions --- @@ -382,16 +414,26 @@ class BigQueryLoggerConfig: # HELPER: TRACE MANAGER (Async-Safe with ContextVars) # ============================================================================== -_trace_id_ctx = contextvars.ContextVar("_bq_analytics_trace_id", default=None) _root_agent_name_ctx = contextvars.ContextVar( "_bq_analytics_root_agent_name", default=None ) -_span_stack_ctx = contextvars.ContextVar("_bq_analytics_span_stack", default=()) -_span_times_ctx = contextvars.ContextVar( - "_bq_analytics_span_times", default=None +_span_stack_ctx: contextvars.ContextVar[list[trace.Span]] = ( + contextvars.ContextVar("_bq_analytics_span_stack", default=None) ) -_span_first_token_times_ctx = contextvars.ContextVar( - "_bq_analytics_span_first_token_times", default=None +_span_token_stack_ctx: contextvars.ContextVar[list[trace.Token]] = ( + contextvars.ContextVar("_bq_analytics_span_token_stack", default=None) +) +_span_first_token_times_ctx: contextvars.ContextVar[dict[str, float]] = ( + contextvars.ContextVar("_bq_analytics_span_first_token_times", default=None) +) +_span_map_ctx: contextvars.ContextVar[dict[str, trace.Span]] = ( + contextvars.ContextVar("_bq_analytics_span_map", default=None) +) +_span_id_stack_ctx: contextvars.ContextVar[list[str]] = contextvars.ContextVar( + "_bq_analytics_span_id_stack", default=None +) +_span_start_time_ctx: contextvars.ContextVar[dict[str, int]] = ( + contextvars.ContextVar("_bq_analytics_span_start_time", default=None) ) @@ -400,75 +442,176 @@ class TraceManager: @staticmethod def init_trace(callback_context: CallbackContext) -> None: - if _trace_id_ctx.get() is None: - _trace_id_ctx.set(callback_context.invocation_id) - # Extract root agent name from invocation context + # Extract root agent name from invocation context if not set + if _root_agent_name_ctx.get() is None: try: root_agent = callback_context._invocation_context.agent.root_agent _root_agent_name_ctx.set(root_agent.name) except (AttributeError, ValueError): pass - _span_stack_ctx.set(()) - _span_times_ctx.set({}) + + if _span_first_token_times_ctx.get() is None: _span_first_token_times_ctx.set({}) + if _span_map_ctx.get() is None: + _span_map_ctx.set({}) + + if _span_start_time_ctx.get() is None: + _span_start_time_ctx.set({}) + @staticmethod def get_trace_id(callback_context: CallbackContext) -> Optional[str]: - # Try contextvars first - if trace_id := _trace_id_ctx.get(): - return trace_id - # Fallback to callback_context for existing tests/legacy flows - return callback_context.state.get("_bq_analytics_trace_id") + """Gets the trace ID from the current span or invocation_id.""" + # Prefer internal stack if available + stack = _span_stack_ctx.get() + if stack: + current_span = stack[-1] + if current_span.get_span_context().is_valid: + return format(current_span.get_span_context().trace_id, "032x") + + # Fallback to OTel context to satisfy "Trace Context Extraction" requirement + current_span = trace.get_current_span() + if current_span.get_span_context().is_valid: + return format(current_span.get_span_context().trace_id, "032x") + + return callback_context.invocation_id @staticmethod def push_span( - callback_context: CallbackContext, span_id: Optional[str] = None + callback_context: CallbackContext, span_name: Optional[str] = "adk-span" ) -> str: - # Ensure trace is initialized - if _trace_id_ctx.get() is None: - TraceManager.init_trace(callback_context) + """Starts a new span and pushes it onto the stack. - span_id = span_id or str(uuid.uuid4()) + If OTel is not configured (returning non-recording spans), a UUID fallback + is generated to ensure span_id and parent_span_id are populated in logs. + """ + # Ensure init_trace logic (root agent name) runs if needed + TraceManager.init_trace(callback_context) - stack = _span_stack_ctx.get() - new_stack = stack + (span_id,) + span = tracer.start_span(span_name) + token = context.attach(trace.set_span_in_context(span)) + + stack = _span_stack_ctx.get() or [] + new_stack = list(stack) + new_stack.append(span) _span_stack_ctx.set(new_stack) - times = dict(_span_times_ctx.get() or {}) - times[span_id] = time.time() - _span_times_ctx.set(times) - return span_id + token_stack = _span_token_stack_ctx.get() or [] + new_token_stack = list(token_stack) + new_token_stack.append(token) + _span_token_stack_ctx.set(new_token_stack) + + if span.get_span_context().is_valid: + span_id_str = format(span.get_span_context().span_id, "016x") + else: + # Fallback: Generate a UUID-based ID if OTel span is invalid (NoOp) + # using 32-char hex to avoid collision, treated as string in BQ. + span_id_str = uuid.uuid4().hex + + id_stack = _span_id_stack_ctx.get() or [] + new_id_stack = list(id_stack) + new_id_stack.append(span_id_str) + _span_id_stack_ctx.set(new_id_stack) + + span_map = _span_map_ctx.get() or {} + new_span_map = span_map.copy() + new_span_map[span_id_str] = span + _span_map_ctx.set(new_span_map) + + # Record start time manually for fallback support (NoOpSpan lacks start_time) + start_times = _span_start_time_ctx.get() or {} + new_start_times = start_times.copy() + new_start_times[span_id_str] = time.time_ns() + _span_start_time_ctx.set(new_start_times) + + return span_id_str @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: - stack = list(_span_stack_ctx.get()) - if not stack: + """Ends the current span and pops it from the stack.""" + stack = _span_stack_ctx.get() + token_stack = _span_token_stack_ctx.get() + + if not stack or not token_stack: return None, None - span_id = stack.pop() - _span_stack_ctx.set(tuple(stack)) - times_dict = dict(_span_times_ctx.get() or {}) - start_time = times_dict.pop(span_id, None) - _span_times_ctx.set(times_dict) + new_stack = list(stack) + new_token_stack = list(token_stack) - ft_dict = dict(_span_first_token_times_ctx.get() or {}) - ft_dict.pop(span_id, None) - _span_first_token_times_ctx.set(ft_dict) + span = new_stack.pop() + token = new_token_stack.pop() + + _span_stack_ctx.set(new_stack) + _span_token_stack_ctx.set(new_token_stack) + + # Pop from ID stack regarding fallback support + id_stack = _span_id_stack_ctx.get() + if id_stack: + new_id_stack = list(id_stack) + span_id = new_id_stack.pop() + _span_id_stack_ctx.set(new_id_stack) + else: + # Should not happen if stacks are in sync, but robust fallback: + if span.get_span_context().is_valid: + span_id = format(span.get_span_context().span_id, "016x") + else: + span_id = "unknown-id" + + duration_ms = None + # Try getting start time from OTel span first, then fallback to manual tracking + if hasattr(span, "start_time") and span.start_time: + duration_ms = int((time.time_ns() - span.start_time) / 1_000_000) + else: + start_times = _span_start_time_ctx.get() + if start_times and span_id in start_times: + start_ns = start_times[span_id] + duration_ms = int((time.time_ns() - start_ns) / 1_000_000) + + span.end() + context.detach(token) + + first_tokens = _span_first_token_times_ctx.get() + if first_tokens: + # Copy to modify + new_first_tokens = first_tokens.copy() + new_first_tokens.pop(span_id, None) + _span_first_token_times_ctx.set(new_first_tokens) + + span_map = _span_map_ctx.get() + if span_map: + new_span_map = span_map.copy() + new_span_map.pop(span_id, None) + _span_map_ctx.set(new_span_map) + + start_times = _span_start_time_ctx.get() + if start_times: + new_start_times = start_times.copy() + new_start_times.pop(span_id, None) + _span_start_time_ctx.set(new_start_times) - duration_ms = int((time.time() - start_time) * 1000) if start_time else None return span_id, duration_ms @staticmethod def get_current_span_and_parent() -> tuple[Optional[str], Optional[str]]: - stack = _span_stack_ctx.get() - if not stack: - return None, None - return stack[-1], (stack[-2] if len(stack) > 1 else None) + """Gets current span_id and parent span_id from OTEL context or fallback stack.""" + # Use internal ID stack for robust resolution (handling both OTel and fallback IDs) + id_stack = _span_id_stack_ctx.get() + if id_stack: + span_id = id_stack[-1] + parent_id = None + if len(id_stack) > 1: + parent_id = id_stack[-2] + return span_id, parent_id + + return None, None @staticmethod def get_current_span_id() -> Optional[str]: - stack = _span_stack_ctx.get() - return stack[-1] if stack else None + """Gets current span_id from OTEL context or fallback stack.""" + id_stack = _span_id_stack_ctx.get() + if id_stack: + return id_stack[-1] + return None @staticmethod def get_root_agent_name() -> Optional[str]: @@ -476,25 +619,40 @@ def get_root_agent_name() -> Optional[str]: @staticmethod def get_start_time(span_id: str) -> Optional[float]: - times = _span_times_ctx.get() - return times.get(span_id) if times else None + """Gets start time of a span by ID.""" + # Try OTel Object first + span_map = _span_map_ctx.get() + if span_map: + span = span_map.get(span_id) + if ( + span + and span.get_span_context().is_valid + and hasattr(span, "start_time") + ): + return span.start_time / 1_000_000_000.0 + + # Fallback to manual start time + start_times = _span_start_time_ctx.get() + if start_times and span_id in start_times: + return start_times[span_id] / 1_000_000_000.0 + + return None @staticmethod def record_first_token(span_id: str) -> bool: - """Records the current time as first token time if not already recorded. + """Records the current time as first token time if not already recorded.""" + first_tokens = _span_first_token_times_ctx.get() - Returns: - True if this was the first token (newly recorded), False otherwise. - """ - first_tokens = dict(_span_first_token_times_ctx.get() or {}) if span_id not in first_tokens: - first_tokens[span_id] = time.time() - _span_first_token_times_ctx.set(first_tokens) + new_first_tokens = first_tokens.copy() + new_first_tokens[span_id] = time.time() + _span_first_token_times_ctx.set(new_first_tokens) return True return False @staticmethod def get_first_token_time(span_id: str) -> Optional[float]: + """Gets the recorded first token time.""" first_tokens = _span_first_token_times_ctx.get() return first_tokens.get(span_id) if first_tokens else None @@ -1800,7 +1958,7 @@ async def before_agent_callback( callback_context: The callback context. """ TraceManager.init_trace(callback_context) - TraceManager.push_span(callback_context) + TraceManager.push_span(callback_context, "agent") await self._log_event( "AGENT_STARTING", callback_context, @@ -1870,7 +2028,7 @@ async def before_model_callback( # Merge any additional kwargs into attributes attributes.update(kwargs) - TraceManager.push_span(callback_context) + TraceManager.push_span(callback_context, "llm") await self._log_event( "LLM_REQUEST", callback_context, @@ -2025,7 +2183,7 @@ async def before_tool_callback( tool_args, self.config.max_content_length ) content_dict = {"tool": tool.name, "args": args_truncated} - TraceManager.push_span(tool_context) + TraceManager.push_span(tool_context, "tool") await self._log_event( "TOOL_STARTING", tool_context, diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index ed19ee7256..5f2474c0a1 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -37,6 +37,7 @@ from google.cloud import bigquery from google.cloud import exceptions as cloud_exceptions from google.genai import types +from opentelemetry import trace import pyarrow as pa import pytest @@ -502,28 +503,34 @@ async def test_concurrent_span_management( bigquery_agent_analytics_plugin.TraceManager.init_trace(callback_context) async def branch_1(): - bigquery_agent_analytics_plugin.TraceManager.push_span( - callback_context, span_id="span-1" + s_id = bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, span_name="span-1" ) await asyncio.sleep(0.02) - s_id = bigquery_agent_analytics_plugin.TraceManager.get_current_span_id() + current_s_id = ( + bigquery_agent_analytics_plugin.TraceManager.get_current_span_id() + ) + assert s_id == current_s_id bigquery_agent_analytics_plugin.TraceManager.pop_span() return s_id async def branch_2(): - bigquery_agent_analytics_plugin.TraceManager.push_span( - callback_context, span_id="span-2" + s_id = bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, span_name="span-2" ) await asyncio.sleep(0.02) - s_id = bigquery_agent_analytics_plugin.TraceManager.get_current_span_id() + current_s_id = ( + bigquery_agent_analytics_plugin.TraceManager.get_current_span_id() + ) + assert s_id == current_s_id bigquery_agent_analytics_plugin.TraceManager.pop_span() return s_id # Run concurrently results = await asyncio.gather(branch_1(), branch_2()) # If they shared the same list/dict, they would interfere. - assert "span-1" in results - assert "span-2" in results + assert results[0] is not None + assert results[1] is not None assert results[0] != results[1] @pytest.mark.asyncio @@ -1953,3 +1960,104 @@ class LocalIncident: content_json = json.loads(log_entry["content"]) assert content_json["result"]["id"] == "inc-123" assert content_json["result"]["kpi_missed"][0]["kpi"] == "latency" + + @pytest.mark.asyncio + async def test_otel_integration( + self, + callback_context, + ): + """Verifies OpenTelemetry integration in TraceManager.""" + # Mock the tracer and span + mock_tracer = mock.Mock() + mock_span = mock.Mock() + mock_context = mock.Mock() + + # Setup mock IDs (128-bit trace_id, 64-bit span_id) + trace_id_int = 0x12345678123456781234567812345678 + span_id_int = 0x1234567812345678 + + mock_context.trace_id = trace_id_int + mock_context.span_id = span_id_int + mock_context.is_valid = True + + mock_span.get_span_context.return_value = mock_context + mock_span.start_time = 1234567890000000000 # Mock start time in ns + mock_tracer.start_span.return_value = mock_span + + # Patch the global tracer in the plugin module + with mock.patch( + "google.adk.plugins.bigquery_agent_analytics_plugin.tracer", mock_tracer + ): + # Test push_span + span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + + mock_tracer.start_span.assert_called_with("test_span") + assert span_id == format(span_id_int, "016x") + + # Test get_trace_id + # We need to mock trace.get_current_span() to return our mock span + # because push_span calls trace.attach(), which affects the global context + with mock.patch( + "opentelemetry.trace.get_current_span", return_value=mock_span + ): + trace_id = bigquery_agent_analytics_plugin.TraceManager.get_trace_id( + callback_context + ) + assert trace_id == format(trace_id_int, "032x") + + # Test pop_span + # pop_span calls span.end() + bigquery_agent_analytics_plugin.TraceManager.pop_span() + mock_span.end.assert_called_once() + + @pytest.mark.asyncio + async def test_otel_integration_real_provider(self, callback_context): + """Verifies TraceManager with a real OpenTelemetry TracerProvider.""" + # Setup OTEL with in-memory exporter + # pylint: disable=g-import-not-at-top + from opentelemetry.sdk import trace as trace_sdk + from opentelemetry.sdk.trace import export as trace_export + from opentelemetry.sdk.trace.export import in_memory_span_exporter + + # pylint: enable=g-import-not-at-top + + provider = trace_sdk.TracerProvider() + exporter = in_memory_span_exporter.InMemorySpanExporter() + processor = trace_export.SimpleSpanProcessor(exporter) + provider.add_span_processor(processor) + tracer = provider.get_tracer("test_tracer") + + # Patch the global tracer in the plugin module + with mock.patch( + "google.adk.plugins.bigquery_agent_analytics_plugin.tracer", tracer + ): + # 1. Start a span + span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + + # Verify a span was started but not ended + current_spans = exporter.get_finished_spans() + assert not current_spans + + # Verify we can retrieve the trace ID + trace_id = bigquery_agent_analytics_plugin.TraceManager.get_trace_id( + callback_context + ) + assert trace_id is not None + + # 2. End the span + popped_span_id, _ = ( + bigquery_agent_analytics_plugin.TraceManager.pop_span() + ) + + assert popped_span_id == span_id + + # Verify span is now finished and exported + finished_spans = exporter.get_finished_spans() + assert len(finished_spans) == 1 + assert finished_spans[0].name == "test_span" + assert format(finished_spans[0].context.span_id, "016x") == span_id + assert format(finished_spans[0].context.trace_id, "032x") == trace_id