diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 000000000..9b388533a
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,7 @@
+{
+  "python.testing.pytestArgs": [
+    "tests"
+  ],
+  "python.testing.unittestEnabled": false,
+  "python.testing.pytestEnabled": true
+}
\ No newline at end of file
diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py
index 380b666dc..7b86ad2cc 100644
--- a/temporalio/contrib/opentelemetry.py
+++ b/temporalio/contrib/opentelemetry.py
@@ -2,6 +2,10 @@
 
 from __future__ import annotations
 
+import concurrent.futures
+import dataclasses
+import functools
+import inspect
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import (
@@ -292,9 +296,55 @@ async def execute_activity(
             },
             kind=opentelemetry.trace.SpanKind.SERVER,
         ):
+            # Propagate trace context into synchronous activities running
+            # outside a ThreadPoolExecutor (e.g. in a ProcessPoolExecutor),
+            # where the OpenTelemetry context doesn't carry over automatically
+            is_async = inspect.iscoroutinefunction(
+                input.fn
+            ) or inspect.iscoroutinefunction(
+                input.fn.__call__  # type: ignore
+            )
+            is_threadpool_executor = isinstance(
+                input.executor, concurrent.futures.ThreadPoolExecutor
+            )
+            if not (is_async or is_threadpool_executor):
+                carrier: _CarrierDict = {}
+                default_text_map_propagator.inject(carrier)
+                input.fn = ActivityFnWithTraceContext(input.fn, carrier)
             return await super().execute_activity(input)
 
 
+@dataclasses.dataclass
+class ActivityFnWithTraceContext:
+    """Wraps an activity function to inject trace context from a carrier.
+
+    This wrapper is intended for sync activities executed in a process pool
+    executor, to ensure tracing features like child spans, trace events, and
+    log correlation work properly in the user's activity implementation.
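+
+    A minimal sketch of the round trip (`my_activity` is illustrative; the
+    tracing interceptor applies this wrapper automatically)::
+
+        carrier: _CarrierDict = {}
+        default_text_map_propagator.inject(carrier)  # parent process
+        wrapped = ActivityFnWithTraceContext(my_activity, carrier)
+        wrapped()  # child process: re-attaches the parent trace context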
+ """ + functools.wraps(self.fn)(self) + + def __call__(self, *args: Any, **kwargs: Any): # noqa: D102 + trace_context = default_text_map_propagator.extract(self.carrier) + token = opentelemetry.context.attach(trace_context) + try: + return self.fn(*args, **kwargs) + finally: + opentelemetry.context.detach(token) + + class _InputWithHeaders(Protocol): headers: Mapping[str, temporalio.api.common.v1.Payload] diff --git a/tests/contrib/opentelemetry/helpers/_init__.py b/tests/contrib/opentelemetry/helpers/_init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/contrib/opentelemetry/helpers/reflection_interceptor.py b/tests/contrib/opentelemetry/helpers/reflection_interceptor.py new file mode 100644 index 000000000..ec599b36b --- /dev/null +++ b/tests/contrib/opentelemetry/helpers/reflection_interceptor.py @@ -0,0 +1,77 @@ +import dataclasses +import logging +import typing + +import temporalio.worker + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class InterceptedActivity: + class_name: str + name: typing.Optional[str] + qualname: typing.Optional[str] + module: typing.Optional[str] + annotations: typing.Dict[str, typing.Any] + docstring: typing.Optional[str] + + +class ReflectionInterceptor(temporalio.worker.Interceptor): + """Interceptor to check we haven't broken reflection when wrapping the activity.""" + + def __init__(self) -> None: + self._intercepted_activities: list[InterceptedActivity] = [] + + def get_intercepted_activities(self) -> typing.List[InterceptedActivity]: + """Get the list of intercepted activities.""" + return self._intercepted_activities + + def intercept_activity( + self, next: temporalio.worker.ActivityInboundInterceptor + ) -> temporalio.worker.ActivityInboundInterceptor: + """Method called for intercepting an activity. + + Args: + next: The underlying inbound interceptor this interceptor should + delegate to. + + Returns: + The new interceptor that will be used to for the activity. + """ + return _ReflectionActivityInboundInterceptor(next, self) + + +class _ReflectionActivityInboundInterceptor( + temporalio.worker.ActivityInboundInterceptor +): + def __init__( + self, + next: temporalio.worker.ActivityInboundInterceptor, + root: ReflectionInterceptor, + ) -> None: + super().__init__(next) + self.root = root + + async def execute_activity( + self, input: temporalio.worker.ExecuteActivityInput + ) -> typing.Any: + """Called to invoke the activity.""" + + try: + self.root._intercepted_activities.append( + InterceptedActivity( + class_name=input.fn.__class__.__name__, + name=getattr(input.fn, "__name__", None), + qualname=getattr(input.fn, "__qualname__", None), + module=getattr(input.fn, "__module__", None), + docstring=getattr(input.fn, "__doc__", None), + annotations=getattr(input.fn, "__annotations__", {}), + ) + ) + except AttributeError: + logger.exception( + "Activity function does not have expected attributes, skipping reflection." 
+import dataclasses
+import logging
+import typing
+
+import temporalio.worker
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass(frozen=True)
+class InterceptedActivity:
+    class_name: str
+    name: typing.Optional[str]
+    qualname: typing.Optional[str]
+    module: typing.Optional[str]
+    annotations: typing.Dict[str, typing.Any]
+    docstring: typing.Optional[str]
+
+
+class ReflectionInterceptor(temporalio.worker.Interceptor):
+    """Interceptor to check we haven't broken reflection when wrapping the activity."""
+
+    def __init__(self) -> None:
+        self._intercepted_activities: typing.List[InterceptedActivity] = []
+
+    def get_intercepted_activities(self) -> typing.List[InterceptedActivity]:
+        """Get the list of intercepted activities."""
+        return self._intercepted_activities
+
+    def intercept_activity(
+        self, next: temporalio.worker.ActivityInboundInterceptor
+    ) -> temporalio.worker.ActivityInboundInterceptor:
+        """Method called for intercepting an activity.
+
+        Args:
+            next: The underlying inbound interceptor this interceptor should
+                delegate to.
+
+        Returns:
+            The new interceptor that will be used for the activity.
+        """
+        return _ReflectionActivityInboundInterceptor(next, self)
+
+
+class _ReflectionActivityInboundInterceptor(
+    temporalio.worker.ActivityInboundInterceptor
+):
+    def __init__(
+        self,
+        next: temporalio.worker.ActivityInboundInterceptor,
+        root: ReflectionInterceptor,
+    ) -> None:
+        super().__init__(next)
+        self.root = root
+
+    async def execute_activity(
+        self, input: temporalio.worker.ExecuteActivityInput
+    ) -> typing.Any:
+        """Called to invoke the activity."""
+        try:
+            self.root._intercepted_activities.append(
+                InterceptedActivity(
+                    class_name=input.fn.__class__.__name__,
+                    name=getattr(input.fn, "__name__", None),
+                    qualname=getattr(input.fn, "__qualname__", None),
+                    module=getattr(input.fn, "__module__", None),
+                    docstring=getattr(input.fn, "__doc__", None),
+                    annotations=getattr(input.fn, "__annotations__", {}),
+                )
+            )
+        except AttributeError:
+            logger.exception(
+                "Activity function does not have expected attributes; skipping reflection."
+            )
+
+        return await self.next.execute_activity(input)
diff --git a/tests/contrib/opentelemetry/helpers/tracing.py b/tests/contrib/opentelemetry/helpers/tracing.py
new file mode 100644
index 000000000..cbf7cddae
--- /dev/null
+++ b/tests/contrib/opentelemetry/helpers/tracing.py
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+import multiprocessing
+import multiprocessing.managers
+import threading
+import typing
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
+
+import opentelemetry.trace
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.sdk.trace.export import (
+    SpanExporter,
+    SpanExportResult,
+)
+
+
+@dataclass(frozen=True)
+class SerialisableSpan:
+    """A serialisable, incomplete representation of a span for testing purposes."""
+
+    @dataclass(frozen=True)
+    class SpanContext:
+        trace_id: int
+        span_id: int
+
+        @classmethod
+        def from_span_context(
+            cls, context: opentelemetry.trace.SpanContext
+        ) -> "SerialisableSpan.SpanContext":
+            return cls(
+                trace_id=context.trace_id,
+                span_id=context.span_id,
+            )
+
+        @classmethod
+        def from_optional_span_context(
+            cls, context: Optional[opentelemetry.trace.SpanContext]
+        ) -> Optional["SerialisableSpan.SpanContext"]:
+            if context is None:
+                return None
+            return cls.from_span_context(context)
+
+    @dataclass(frozen=True)
+    class Link:
+        context: SerialisableSpan.SpanContext
+        attributes: Dict[str, Any]
+
+    name: str
+    context: Optional[SpanContext]
+    parent: Optional[SpanContext]
+    attributes: Dict[str, Any]
+    links: Sequence[Link]
+
+    @classmethod
+    def from_readable_span(cls, span: ReadableSpan) -> "SerialisableSpan":
+        return cls(
+            name=span.name,
+            context=cls.SpanContext.from_optional_span_context(span.context),
+            parent=cls.SpanContext.from_optional_span_context(span.parent),
+            attributes=dict(span.attributes or {}),
+            links=tuple(
+                cls.Link(
+                    context=cls.SpanContext.from_span_context(link.context),
+                    attributes=dict(link.attributes or {}),
+                )
+                for link in span.links
+            ),
+        )
+
+
+def make_span_proxy_list(
+    manager: multiprocessing.managers.SyncManager,
+) -> multiprocessing.managers.ListProxy[SerialisableSpan]:
+    """Create a list proxy to share `SerialisableSpan` across processes."""
+    return manager.list()
+
+
+class _ListProxySpanExporter(SpanExporter):
+    """Implementation of :class:`SpanExporter` that exports spans to a
+    list proxy created by a multiprocessing manager.
+
+    This class is used for testing multiprocessing setups, as it gives the
+    parent process access to the spans finished in child processes.
+
+    In production, you would use `OTLPSpanExporter` or similar to export
+    spans. Tracing is designed to be distributed: the child process can push
+    collected spans directly to a collector or backend, which can reassemble
+    the spans into a single trace.
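+
+    A minimal pool-initializer sketch (the `opentelemetry.sdk` imports are
+    assumed; the tests below use this same pattern)::
+
+        def init_worker(spans_proxy):
+            provider = TracerProvider()
+            provider.add_span_processor(
+                SimpleSpanProcessor(_ListProxySpanExporter(spans_proxy))
+            )
+            opentelemetry.trace.set_tracer_provider(provider)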
+    """
+
+    def __init__(
+        self, finished_spans: multiprocessing.managers.ListProxy[SerialisableSpan]
+    ) -> None:
+        self._finished_spans = finished_spans
+        self._stopped = False
+        self._lock = threading.Lock()
+
+    def export(self, spans: typing.Sequence[ReadableSpan]) -> SpanExportResult:
+        if self._stopped:
+            return SpanExportResult.FAILURE
+        with self._lock:
+            # Note: ReadableSpan is not picklable, so convert to a DTO.
+            # Note: we could use `span.to_json()`, but there isn't a
+            # `from_json` and the serialisation isn't easily reversible,
+            # e.g. the `parent` context is lost and span/trace IDs are
+            # transformed into strings
+            self._finished_spans.extend(
+                [SerialisableSpan.from_readable_span(span) for span in spans]
+            )
+        return SpanExportResult.SUCCESS
+
+    def shutdown(self) -> None:
+        self._stopped = True
+
+    def force_flush(self, timeout_millis: int = 30000) -> bool:
+        return True
+
+
+def dump_spans(
+    spans: Iterable[Union[ReadableSpan, SerialisableSpan]],
+    *,
+    parent_id: Optional[int] = None,
+    with_attributes: bool = True,
+    indent_depth: int = 0,
+) -> List[str]:
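+    """Render spans as an indented tree of one-line summaries.
+
+    Output looks roughly like this (names are illustrative)::
+
+        RunWorkflow:MyWorkflow (attributes: {...})
+          RunActivity:my_activity (attributes: {...})
+            child_span (attributes: {})
+    """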
span_str = f"{' ' * indent_depth}{span.name}" - if with_attributes: - span_str += f" (attributes: {dict(span.attributes or {})})" - # Add links - if span.links: - span_links: List[str] = [] - for link in span.links: - for link_span in spans: - if link_span.context.span_id == link.context.span_id: - span_links.append(link_span.name) - span_str += f" (links: {', '.join(span_links)})" - # Signals can duplicate in rare situations, so we make sure not to - # re-add - if "Signal" in span_str and span_str in ret: - continue - ret.append(span_str) - ret += dump_spans( - spans, - parent_id=span.context.span_id, - with_attributes=with_attributes, - indent_depth=indent_depth + 1, - ) - return ret - - @workflow.defn class SimpleWorkflow: @workflow.run @@ -392,3 +372,96 @@ async def test_opentelemetry_always_create_workflow_spans(client: Client): # * workflow failure and wft failure # * signal with start # * signal failure and wft failure from signal + + +@workflow.defn +class ActivityTracePropagationWorkflow: + @workflow.run + async def run(self) -> str: + retry_policy = RetryPolicy(initial_interval=timedelta(milliseconds=1)) + return await workflow.execute_activity( + sync_activity, + {}, + # TODO: Reduce to 10s - increasing to make debugging easier + start_to_close_timeout=timedelta(minutes=10), + retry_policy=retry_policy, + ) + + +@activity.defn +def sync_activity(param: typing.Any) -> str: + """An activity that uses tracing features.""" + inner_tracer = get_tracer("sync_activity") + with inner_tracer.start_as_current_span( + "child_span", + ): + return "done" + + +async def test_activity_trace_propagation( + client: Client, + env: WorkflowEnvironment, +): + # Create a tracer that has an in-memory exporter + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = get_tracer(__name__, tracer_provider=provider) + + # Create a proxy list using the server process manager which we'll use + # to access finished spans in the process pool + manager = multiprocessing.Manager() + finished_spans_proxy = make_span_proxy_list(manager) + + # Create an interceptor to test we haven't broken reflection + reflection_interceptor = ReflectionInterceptor() + + # Create a worker with a process pool activity executor + async with Worker( + client, + task_queue=f"task_queue_{uuid.uuid4()}", + workflows=[ActivityTracePropagationWorkflow], + activities=[sync_activity], + interceptors=[TracingInterceptor(tracer), reflection_interceptor], + activity_executor=concurrent.futures.ProcessPoolExecutor( + max_workers=1, + initializer=activity_trace_propagation_initializer, + initargs=(finished_spans_proxy,), + ), + shared_state_manager=SharedStateManager.create_from_multiprocessing(manager), + ) as worker: + assert "done" == await client.execute_workflow( + ActivityTracePropagationWorkflow.run, + id=f"workflow_{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # The dumped spans should include child spans created in the child process + spans = exporter.get_finished_spans() + tuple(finished_spans_proxy) + logging.debug("Spans:\n%s", "\n".join(dump_spans(spans, with_attributes=False))) + assert dump_spans(spans, with_attributes=False) == [ + "RunActivity:sync_activity", + " child_span", + ] + + # and the activity should still have the original attributes in downstream interceptors + assert reflection_interceptor.get_intercepted_activities() == [ + InterceptedActivity( + class_name="ActivityFnWithTraceContext", + name="sync_activity", + 
qualname="sync_activity", + module="tests.contrib.test_opentelemetry", + docstring="An activity that uses tracing features.", + annotations={"param": "typing.Any", "return": "str"}, + ) + ] + + +def activity_trace_propagation_initializer( + _finished_spans_proxy: multiprocessing.managers.ListProxy[SerialisableSpan], +) -> None: + """Initializer for the process pool worker to export spans to a shared list.""" + _exporter = _ListProxySpanExporter(_finished_spans_proxy) + _provider = TracerProvider() + _provider.add_span_processor(SimpleSpanProcessor(_exporter)) + opentelemetry.trace.set_tracer_provider(_provider)