From 800c28142316dcbf5c1337ae14870a4857361381 Mon Sep 17 00:00:00 2001 From: Conrad Date: Fri, 20 Feb 2026 21:24:30 -0500 Subject: [PATCH 1/3] build: Add grpc-interceptor dependency Required by the new task dispatch interceptor scheme to bridge Wool interceptors to the gRPC async server interceptor interface. --- wool/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/wool/pyproject.toml b/wool/pyproject.toml index 341a508..dca3909 100644 --- a/wool/pyproject.toml +++ b/wool/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ ] dependencies = [ "cloudpickle", + "grpc-interceptor", "grpcio>=1.76.0", "portalocker", "protobuf", From 75c82a115729c739b09e88b8bb149cabd496231e Mon Sep 17 00:00:00 2001 From: Conrad Date: Fri, 20 Feb 2026 21:25:02 -0500 Subject: [PATCH 2/3] feat: Add customizable task dispatch interceptor Introduce a two-phase interceptor system for task and stream manipulation at the gRPC layer. Interceptors enable extensible processing pipelines for cross-cutting concerns like logging, authentication, and metrics without modifying core worker logic. InterceptorLike protocol defines the async generator contract: pre-dispatch task modification followed by response stream wrapping. The @interceptor decorator provides automatic global registration. WoolInterceptor bridges the Wool interceptor interface to gRPC's AsyncServerInterceptor. LocalWorker and WorkerProcess accept an optional interceptors list, falling back to globally registered interceptors when not specified. --- wool/src/wool/__init__.py | 7 +- wool/src/wool/runtime/routine/interceptor.py | 297 +++++++++++++++++++ wool/src/wool/runtime/worker/local.py | 13 + wool/src/wool/runtime/worker/process.py | 14 +- 4 files changed, 327 insertions(+), 4 deletions(-) create mode 100644 wool/src/wool/runtime/routine/interceptor.py diff --git a/wool/src/wool/__init__.py b/wool/src/wool/__init__.py index 1e0fab5..a4b73ff 100644 --- a/wool/src/wool/__init__.py +++ b/wool/src/wool/__init__.py @@ -22,6 +22,8 @@ from wool.runtime.loadbalancer.base import NoWorkersAvailable from wool.runtime.loadbalancer.roundrobin import RoundRobinLoadBalancer from wool.runtime.resourcepool import ResourcePool +from wool.runtime.routine.interceptor import get_registered_interceptors +from wool.runtime.routine.interceptor import interceptor from wool.runtime.routine.task import Task from wool.runtime.routine.task import TaskEvent from wool.runtime.routine.task import TaskEventHandler @@ -57,12 +59,10 @@ ) __all__ = [ - # Connection "RpcError", "TransientRpcError", "UnexpectedResponse", "WorkerConnection", - # Context "RuntimeContext", # Load balancing "LoadBalancerContextLike", @@ -76,6 +76,8 @@ "TaskEventType", "TaskException", "current_task", + "get_registered_interceptors", + "interceptor", "routine", # Workers "LocalWorker", @@ -99,7 +101,6 @@ "LocalDiscovery", "PredicateFunction", "WorkerMetadata", - # Typing "Factory", ] diff --git a/wool/src/wool/runtime/routine/interceptor.py b/wool/src/wool/runtime/routine/interceptor.py new file mode 100644 index 0000000..cbd98d6 --- /dev/null +++ b/wool/src/wool/runtime/routine/interceptor.py @@ -0,0 +1,297 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any +from typing import AsyncGenerator +from typing import AsyncIterator +from typing import Awaitable +from typing import Callable +from typing import Protocol + +import cloudpickle +import grpc.aio +from grpc_interceptor.server import AsyncServerInterceptor + +if TYPE_CHECKING: + from wool.runtime.routine.task import Task + +# Global registry for decorator-registered interceptors +_registered_interceptors: list[InterceptorLike] = [] + + +# public +class InterceptorLike(Protocol): + """Protocol defining the Wool interceptor interface. + + Interceptors are async generators that wrap task execution, allowing + modification of tasks before dispatch and manipulation of response + streams during execution. + + Interceptors execute in three phases: + + 1. **Pre-dispatch**: Code before the first ``yield`` runs before the + task is dispatched. The interceptor can yield a modified + :class:`Task` or ``None`` to use the original task. + + 2. **Stream processing**: The ``yield`` expression receives the response + stream as an :class:`AsyncIterator`. The interceptor returns an async + generator that wraps this stream, allowing events to be modified, + filtered, or injected. + + 3. **Cleanup**: Code after the ``return`` statement (in a ``finally`` + block) runs after stream completion or cancellation. + + **Basic logging interceptor:** + + .. code-block:: python + + async def log_interceptor(task: Task) -> AsyncGenerator: + print(f"Starting task: {task.id}") + + # Yield None to use unmodified task + response_stream = yield None + + # Wrap and yield events from the response stream + try: + async for event in response_stream: + print(f"Event: {event}") + yield event + finally: + print(f"Task complete: {task.id}") + + **Task modification interceptor:** + + .. code-block:: python + + async def rbac_interceptor(task: Task) -> AsyncGenerator: + # Validate permissions before dispatch + if not has_permission(task): + raise PermissionError("Unauthorized") + + # Modify task metadata + modified_task = task.with_metadata(user=current_user()) + response_stream = yield modified_task + + # Pass through all events + async for event in response_stream: + yield event + + **Stream filtering interceptor:** + + .. code-block:: python + + async def filter_interceptor(task: Task) -> AsyncGenerator: + response_stream = yield None + + async for event in response_stream: + # Filter out certain event types + if should_include(event): + yield event + + **Error handling:** + + Exceptions raised by interceptors propagate to the client and cancel + the stream. The gRPC call fails with the interceptor's exception. + Applications must handle errors appropriately: + + .. code-block:: python + + async def safe_interceptor(task: Task) -> AsyncGenerator: + try: + response_stream = yield None + async for event in response_stream: + yield event + except Exception as e: + # Log error, send alert, etc. + logger.error(f"Task failed: {e}") + # Re-raise to propagate to client + raise + + :param task: + The work task being dispatched. + :returns: + An async generator that yields the modified task (or None) and + returns an async iterator wrapping the response stream. + """ + + def __call__( + self, task: Task + ) -> AsyncGenerator[Task | None, AsyncIterator | None]: ... + + +# public +def interceptor(func: InterceptorLike) -> InterceptorLike: + """Register a Wool interceptor globally. + + Decorated interceptors are automatically applied to all workers that + don't specify explicit interceptors. Use this for cross-cutting + concerns like logging, metrics, or distributed tracing. + + **Usage:** + + .. code-block:: python + + from wool.runtime.routine.interceptor import interceptor + + + @interceptor + async def metrics_interceptor(task): + start_time = time.time() + response_stream = yield None + + try: + async for event in response_stream: + yield event + finally: + duration = time.time() - start_time + record_metric("task_duration", duration) + + Workers automatically include registered interceptors: + + .. code-block:: python + + # This worker uses metrics_interceptor automatically + worker = LocalWorker("my-worker") + + To use only explicit interceptors (ignoring registered ones): + + .. code-block:: python + + # Only use explicit interceptors, not registered ones + worker = LocalWorker("my-worker", interceptors=[custom_interceptor]) + + :param func: + The interceptor function to register. + :returns: + The original function, unchanged. + """ + _registered_interceptors.append(func) + return func + + +def get_registered_interceptors() -> list[InterceptorLike]: + """Get all globally registered interceptors. + + :returns: + List of interceptors registered with the :func:`@interceptor + ` decorator. + """ + return _registered_interceptors.copy() + + +class WoolInterceptor(AsyncServerInterceptor): + """Bridges Wool interceptors to gRPC's interceptor interface. + + Converts high-level Wool interceptor semantics (task modification, + stream wrapping) into gRPC's low-level interceptor protocol. Only + applies to ``dispatch`` RPC calls (unary-stream). + + This class is an implementation detail - users should work with + :class:`WoolInterceptor` functions and the :func:`@interceptor + ` decorator. + + :param interceptors: + Wool interceptor functions to apply. + """ + + def __init__(self, interceptors: list[InterceptorLike]): + self._interceptors = interceptors + + async def intercept( + self, + method: Callable[[Any, grpc.aio.ServicerContext], Awaitable[Any]], + request_or_iterator: Any, + context: grpc.aio.ServicerContext, + method_name: str, + ) -> Any: + """Intercept gRPC service calls. + + Only applies interceptors to ``dispatch`` calls. Other RPC methods + (like ``stop``) bypass interception. + + :param method: + The gRPC service method being called. + :param request_or_iterator: + The request object or request iterator. + :param context: + The gRPC servicer context. + :param method_name: + The name of the method being called (e.g., + "/wool.Worker/dispatch"). + :returns: + The response or response iterator from the method. + """ + # Exit early if no interceptors registered + if not self._interceptors: + return await method(request_or_iterator, context) + + # Only intercept dispatch calls + if not method_name.endswith("/dispatch"): + return await method(request_or_iterator, context) + + # Deserialize task from protobuf request + task: Task = cloudpickle.loads(request_or_iterator.task) + + # Apply all interceptors in order, keeping generators alive + modified_task = task + active_generators = [] + for interceptor_func in self._interceptors: + try: + # Start the interceptor generator + gen = interceptor_func(modified_task) + + # Advance to first yield - get potentially modified task + task_modification = await gen.asend(None) + + # Update task if interceptor returned a modified version + if task_modification is not None: + modified_task = task_modification + + # Store generator for stream wrapping phase + active_generators.append(gen) + + except StopAsyncIteration: + # Interceptor didn't yield - treat as passthrough + active_generators.append(None) + except Exception: + # Interceptor raised error - propagate to caller + raise + + # If task was modified, update the protobuf request + if modified_task is not task: + # Create new request with modified task + request_or_iterator.task = cloudpickle.dumps(modified_task) + + # Call the actual dispatch method + response_stream = await method(request_or_iterator, context) + + # Wrap response stream with interceptors (in reverse order) + for gen in reversed(active_generators): + # Skip interceptors that didn't create generators + if gen is None: + continue + + try: + # Send the response stream - generator will start yielding events + try: + first_event = await gen.asend(response_stream) + + # The generator is now yielding events - wrap it + async def _create_wrapped_stream( + generator: AsyncGenerator, + first: Any, + ) -> AsyncGenerator: + yield first + async for event in generator: + yield event + + response_stream = _create_wrapped_stream(gen, first_event) + except StopAsyncIteration: + # Generator finished without yielding - use original stream + pass + + except Exception: + # Stream wrapping failed - propagate to caller + raise + + return response_stream diff --git a/wool/src/wool/runtime/worker/local.py b/wool/src/wool/runtime/worker/local.py index e9ef10b..af69cdf 100644 --- a/wool/src/wool/runtime/worker/local.py +++ b/wool/src/wool/runtime/worker/local.py @@ -9,6 +9,8 @@ import wool from wool.runtime import protobuf as pb from wool.runtime.discovery.base import WorkerMetadata +from wool.runtime.routine.interceptor import InterceptorLike +from wool.runtime.routine.interceptor import get_registered_interceptors from wool.runtime.worker.auth import WorkerCredentials from wool.runtime.worker.base import ChannelCredentialsType from wool.runtime.worker.base import ServerCredentialsType @@ -62,6 +64,11 @@ class LocalWorker(Worker): credentials for mutual TLS. Enables secure worker-to-worker communication. - ``None``: Worker uses insecure connections. + :param interceptors: + Optional list of :class:`WoolInterceptor` functions to apply to + this worker. If ``None``, uses globally registered interceptors + from the :func:`@interceptor ` + decorator. Pass an empty list to disable all interceptors. :param extra: Additional metadata as key-value pairs. """ @@ -78,6 +85,7 @@ def __init__( shutdown_grace_period: float = 60.0, proxy_pool_ttl: float = 60.0, credentials: WorkerCredentials | None = None, + interceptors: list[InterceptorLike] | None = None, **extra: Any, ): super().__init__(*tags, **extra) @@ -90,12 +98,17 @@ def __init__( self._server_credentials = None self._client_credentials = None + # Use provided interceptors or fall back to registered ones + if interceptors is None: + interceptors = get_registered_interceptors() + self._worker_process = WorkerProcess( host=host, port=port, shutdown_grace_period=shutdown_grace_period, proxy_pool_ttl=proxy_pool_ttl, server_credentials=self._server_credentials, + interceptors=interceptors, ) @property diff --git a/wool/src/wool/runtime/worker/process.py b/wool/src/wool/runtime/worker/process.py index 64e54bd..3030d86 100644 --- a/wool/src/wool/runtime/worker/process.py +++ b/wool/src/wool/runtime/worker/process.py @@ -16,6 +16,8 @@ import wool from wool.runtime import protobuf as pb from wool.runtime.resourcepool import ResourcePool +from wool.runtime.routine.interceptor import InterceptorLike +from wool.runtime.routine.interceptor import WoolInterceptor from wool.runtime.worker.base import ServerCredentialsType from wool.runtime.worker.base import resolve_server_credentials from wool.runtime.worker.service import WorkerService @@ -45,6 +47,8 @@ class WorkerProcess(Process): Proxy pool TTL in seconds. :param server_credentials: Optional gRPC server credentials for TLS/mTLS. + :param interceptors: + List of Wool interceptor functions to apply to task dispatch. :param args: Additional args for :class:`multiprocessing.Process`. :param kwargs: @@ -57,6 +61,7 @@ class WorkerProcess(Process): _shutdown_grace_period: float _proxy_pool_ttl: float _credentials: ServerCredentialsType + _interceptors: list[InterceptorLike] def __init__( self, @@ -66,6 +71,7 @@ def __init__( shutdown_grace_period: float = 60.0, proxy_pool_ttl: float = 60.0, server_credentials: ServerCredentialsType = None, + interceptors: list[InterceptorLike] | None = None, **kwargs, ): super().__init__(*args, **kwargs) @@ -82,6 +88,7 @@ def __init__( raise ValueError("Proxy pool TTL must be positive") self._proxy_pool_ttl = proxy_pool_ttl self._credentials = server_credentials + self._interceptors = interceptors or [] self._get_port, self._set_port = Pipe(duplex=False) @property @@ -172,7 +179,12 @@ async def _serve(self): requests. It creates a gRPC server, adds the worker service, and starts listening for incoming connections. """ - server = grpc.aio.server() + # Create interceptor bridge if interceptors are registered + interceptors = [] + if self._interceptors: + interceptors.append(WoolInterceptor(self._interceptors)) + + server = grpc.aio.server(interceptors=interceptors) credentials = resolve_server_credentials(self._credentials) address = self._address(self._host, self._port) From 7306c77879ba477c9918c77092a675ed7f5b503f Mon Sep 17 00:00:00 2001 From: Conrad Date: Fri, 20 Feb 2026 21:25:22 -0500 Subject: [PATCH 3/3] test: Add tests for task dispatch interceptor Cover InterceptorLike protocol compliance, @interceptor decorator registration, WoolInterceptor gRPC bridge (dispatch-only filtering, task modification, stream wrapping, multi-interceptor chaining, error propagation), and updated public API surface assertions. --- .../tests/runtime/routine/test_interceptor.py | 1099 +++++++++++++++++ wool/tests/test_public.py | 5 +- 2 files changed, 1101 insertions(+), 3 deletions(-) create mode 100644 wool/tests/runtime/routine/test_interceptor.py diff --git a/wool/tests/runtime/routine/test_interceptor.py b/wool/tests/runtime/routine/test_interceptor.py new file mode 100644 index 0000000..18b75f9 --- /dev/null +++ b/wool/tests/runtime/routine/test_interceptor.py @@ -0,0 +1,1099 @@ +from __future__ import annotations + +import cloudpickle +import pytest +from hypothesis import HealthCheck +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from wool.runtime.routine.interceptor import WoolInterceptor +from wool.runtime.routine.interceptor import get_registered_interceptors +from wool.runtime.routine.interceptor import interceptor + + +@pytest.fixture(autouse=True) +def clear_interceptor_registry(): + """Clear the global interceptor registry before and after each test.""" + from wool.runtime.routine.interceptor import _registered_interceptors + + _registered_interceptors.clear() + yield + _registered_interceptors.clear() + + +@pytest.fixture +def sample_task(mocker): + """Provide a mock Task for testing.""" + task = mocker.MagicMock() + task.id = "test-task-123" + task.callable = mocker.MagicMock() + task.args = () + task.kwargs = {} + return task + + +@pytest.fixture +def mock_grpc_context(mocker): + """Provide a mock gRPC context.""" + context = mocker.MagicMock() + context.cancel = mocker.MagicMock() + return context + + +@pytest.fixture +def mock_request(sample_task, mocker): + """Provide a mock protobuf request with serialized task.""" + request = mocker.MagicMock() + request.task = cloudpickle.dumps(sample_task) + return request + + +async def create_mock_response_stream(*events): + """Create a mock async response stream.""" + for event in events: + yield event + + +def create_passthrough_interceptor(): + """Create a passthrough interceptor that yields None.""" + + async def passthrough(task): + response_stream = yield None + async for event in response_stream: + yield event + + return passthrough + + +def create_task_modifying_interceptor(modification_fn): + """Create an interceptor that modifies the task.""" + + async def modifier(task): + modified_task = modification_fn(task) + response_stream = yield modified_task + async for event in response_stream: + yield event + + return modifier + + +def create_stream_wrapping_interceptor(wrapper_fn): + """Create an interceptor that wraps the response stream.""" + + async def wrapper(task): + response_stream = yield None + async for event in response_stream: + yield wrapper_fn(event) + + return wrapper + + +def create_failing_interceptor(exception, fail_stage="pre-dispatch"): + """Create an interceptor that raises an exception.""" + + async def failing(task): + if fail_stage == "pre-dispatch": + raise exception + response_stream = yield None + if fail_stage == "stream-wrapping": + raise exception + async for event in response_stream: + yield event + + return failing + + +def create_order_tracking_interceptor(interceptor_id, order_list, phase="both"): + """Create an interceptor that tracks execution order.""" + + async def order_tracker(task): + if phase in ("both", "pre-dispatch"): + order_list.append(("pre", interceptor_id)) + response_stream = yield None + if phase in ("both", "stream-wrapping"): + order_list.append(("stream", interceptor_id)) + async for event in response_stream: + yield event + + return order_tracker + + +async def collect_stream_events(stream): + """Collect all events from an async iterator.""" + events = [] + async for event in stream: + events.append(event) + return events + + +@st.composite +def valid_passthrough_interceptor_list(draw): + """Generate a list of 0-5 passthrough interceptors.""" + count = draw(st.integers(min_value=0, max_value=5)) + return [create_passthrough_interceptor() for _ in range(count)] + + +@st.composite +def event_stream_strategy(draw): + """Generate lists of 0-100 hashable events.""" + return draw( + st.lists( + st.one_of(st.text(), st.integers(), st.tuples(st.text(), st.integers())), + min_size=0, + max_size=100, + ) + ) + + +class TestInterceptorDecorator: + """Tests for the @interceptor decorator.""" + + def test_registers_function_in_global_registry(self): + """Interceptor decorator registers function in global registry. + + Given: + An interceptor function + When: + Decorated with @interceptor + Then: + The function is added to the global registry + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + decorated = interceptor(my_interceptor) + + assert my_interceptor in get_registered_interceptors() + assert decorated is my_interceptor + + def test_returns_original_function_unchanged(self): + """Interceptor decorator returns original function unchanged. + + Given: + An interceptor function + When: + Decorated with @interceptor + Then: + The original function is returned unchanged + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + original_id = id(my_interceptor) + decorated = interceptor(my_interceptor) + + assert id(decorated) == original_id + assert decorated is my_interceptor + + def test_registers_multiple_functions_in_order(self): + """Interceptor decorator registers multiple functions in order. + + Given: + Multiple interceptor functions + When: + Each decorated with @interceptor + Then: + All functions are added to the registry in order + """ + + async def interceptor1(task): + response_stream = yield None + async for event in response_stream: + yield event + + async def interceptor2(task): + response_stream = yield None + async for event in response_stream: + yield event + + async def interceptor3(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(interceptor1) + interceptor(interceptor2) + interceptor(interceptor3) + + registered = get_registered_interceptors() + assert len(registered) == 3 + assert registered[0] is interceptor1 + assert registered[1] is interceptor2 + assert registered[2] is interceptor3 + + def test_allows_duplicate_registration(self): + """Interceptor decorator allows duplicate registration. + + Given: + The same interceptor function + When: + Decorated with @interceptor twice + Then: + The function appears twice in the registry + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(my_interceptor) + interceptor(my_interceptor) + + registered = get_registered_interceptors() + assert len(registered) == 2 + assert registered[0] is my_interceptor + assert registered[1] is my_interceptor + + +class TestGetRegisteredInterceptors: + """Tests for get_registered_interceptors().""" + + def test_returns_empty_list_when_no_interceptors(self): + """Get registered interceptors returns empty list when none registered. + + Given: + No registered interceptors + When: + get_registered_interceptors() is called + Then: + An empty list is returned + """ + result = get_registered_interceptors() + + assert result == [] + assert isinstance(result, list) + + def test_returns_single_registered_interceptor(self): + """Get registered interceptors returns single registered interceptor. + + Given: + One registered interceptor + When: + get_registered_interceptors() is called + Then: + A list containing the interceptor is returned + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(my_interceptor) + + result = get_registered_interceptors() + assert len(result) == 1 + assert result[0] is my_interceptor + + def test_returns_multiple_interceptors_in_order(self): + """Get registered interceptors returns all in registration order. + + Given: + Multiple registered interceptors + When: + get_registered_interceptors() is called + Then: + A list with all interceptors in registration order is returned + """ + + async def interceptor1(task): + response_stream = yield None + async for event in response_stream: + yield event + + async def interceptor2(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(interceptor1) + interceptor(interceptor2) + + result = get_registered_interceptors() + assert len(result) == 2 + assert result[0] is interceptor1 + assert result[1] is interceptor2 + + def test_returns_copy_not_reference(self): + """Get registered interceptors returns copy not reference. + + Given: + Registered interceptors exist + When: + get_registered_interceptors() is called and list is modified + Then: + The original registry is not affected + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(my_interceptor) + + result1 = get_registered_interceptors() + result1.append("fake") # type: ignore + + result2 = get_registered_interceptors() + assert len(result2) == 1 + assert result2[0] is my_interceptor + + +class TestWoolInterceptor: + """Tests for the WoolInterceptorBridge class.""" + + # ------------------------------------------------------------------------ + # Instantiation Tests + # ------------------------------------------------------------------------ + + def test_stores_interceptors_on_instantiation(self): + """WoolInterceptorBridge stores interceptors on instantiation. + + Given: + A list of interceptor functions + When: + WoolInterceptorBridge is instantiated + Then: + The bridge stores the interceptors + """ + interceptor1 = create_passthrough_interceptor() + interceptor2 = create_passthrough_interceptor() + + bridge = WoolInterceptor([interceptor1, interceptor2]) + + assert bridge._interceptors == [interceptor1, interceptor2] + + def test_instantiates_with_empty_list(self): + """WoolInterceptorBridge instantiates with empty list. + + Given: + An empty interceptor list + When: + WoolInterceptorBridge is instantiated + Then: + The bridge is created successfully + """ + bridge = WoolInterceptor([]) + + assert bridge._interceptors == [] + + # ------------------------------------------------------------------------ + # Early Exit and Bypass Tests + # ------------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_exits_early_with_no_interceptors(self, mock_grpc_context, mocker): + """WoolInterceptorBridge exits early with no interceptors. + + Given: + A bridge with no interceptors + When: + intercept() is called for a dispatch method + Then: + The method is called directly without processing + """ + bridge = WoolInterceptor([]) + + mock_method = mocker.AsyncMock(return_value="direct_result") + mock_request = mocker.MagicMock() + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + assert result == "direct_result" + mock_method.assert_called_once_with(mock_request, mock_grpc_context) + mock_grpc_context.cancel.assert_not_called() + + @pytest.mark.asyncio + async def test_bypasses_non_dispatch_methods(self, mock_grpc_context, mocker): + """WoolInterceptorBridge bypasses non-dispatch methods. + + Given: + A bridge with interceptors + When: + intercept() is called for a non-dispatch method + Then: + The method is called directly without interceptors + """ + + async def failing_interceptor(task): + raise RuntimeError("Should not be called") + + bridge = WoolInterceptor([failing_interceptor]) + mock_method = mocker.AsyncMock(return_value="stop_result") + mock_request = mocker.MagicMock() + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/stop" + ) + + assert result == "stop_result" + mock_method.assert_called_once_with(mock_request, mock_grpc_context) + + # ------------------------------------------------------------------------ + # Task Modification Tests + # ------------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_passthrough_interceptor_yields_none( + self, sample_task, mock_request, mock_grpc_context + ): + """Passthrough interceptor yields None dispatches task unmodified. + + Given: + A bridge with a passthrough interceptor that yields None + When: + intercept() is called for a dispatch method + Then: + The task is dispatched unmodified + """ + passthrough = create_passthrough_interceptor() + bridge = WoolInterceptor([passthrough]) + + async def mock_method(req, ctx): + # Verify task wasn't modified + task = cloudpickle.loads(req.task) + assert task.id == sample_task.id + return create_mock_response_stream("event1", "event2") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["event1", "event2"] + + @pytest.mark.asyncio + async def test_task_modification_via_yield( + self, sample_task, mock_request, mock_grpc_context, mocker + ): + """Task modification via yield serializes modified task. + + Given: + A bridge with an interceptor that yields a modified task + When: + intercept() is called for a dispatch method + Then: + The modified task is serialized and dispatched + """ + + def modify_task(task): + modified = mocker.MagicMock() + modified.id = "modified-task-456" + modified.callable = task.callable + modified.args = task.args + modified.kwargs = task.kwargs + return modified + + modifier = create_task_modifying_interceptor(modify_task) + bridge = WoolInterceptor([modifier]) + + async def mock_method(req, ctx): + task = cloudpickle.loads(req.task) + assert task.id == "modified-task-456" + return create_mock_response_stream("event1") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["event1"] + + @pytest.mark.asyncio + async def test_multiple_passthrough_interceptors( + self, sample_task, mock_request, mock_grpc_context + ): + """Multiple passthrough interceptors process task in order. + + Given: + A bridge with multiple interceptors that yield None + When: + intercept() is called for a dispatch method + Then: + All interceptors process the task in order + """ + order = [] + + def create_tracking_passthrough(idx): + async def tracker(task): + order.append(f"pre-{idx}") + response_stream = yield None + order.append(f"stream-{idx}") + async for event in response_stream: + yield event + + return tracker + + bridge = WoolInterceptor( + [ + create_tracking_passthrough(1), + create_tracking_passthrough(2), + create_tracking_passthrough(3), + ] + ) + + async def mock_method(req, ctx): + return create_mock_response_stream("event") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + await collect_stream_events(result) + + # Pre-dispatch: forward order, stream-wrapping: reverse order + assert order == [ + "pre-1", + "pre-2", + "pre-3", + "stream-3", + "stream-2", + "stream-1", + ] + + @pytest.mark.asyncio + async def test_chained_task_modification( + self, sample_task, mock_request, mock_grpc_context, mocker + ): + """Chained task modification passes through modifications. + + Given: + A bridge with multiple task-modifying interceptors + When: + intercept() is called for a dispatch method + Then: + Each interceptor receives task modified by previous ones + """ + modifications = [] + + def create_modifier(suffix): + def modify(task): + modified = mocker.MagicMock() + modified.id = task.id + f"-{suffix}" + modified.callable = task.callable + modified.args = task.args + modified.kwargs = task.kwargs + modifications.append(modified.id) + return modified + + return create_task_modifying_interceptor(modify) + + bridge = WoolInterceptor([create_modifier("A"), create_modifier("B")]) + + async def mock_method(req, ctx): + task = cloudpickle.loads(req.task) + assert task.id == "test-task-123-A-B" + return create_mock_response_stream("event") + + await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + assert modifications == ["test-task-123-A", "test-task-123-A-B"] + + # ------------------------------------------------------------------------ + # Stream Wrapping Tests + # ------------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_stream_wrapping(self, sample_task, mock_request, mock_grpc_context): + """Stream wrapping returns events from wrapped stream. + + Given: + A bridge with an interceptor that returns a wrapped stream + When: + intercept() is called for a dispatch method + Then: + Events from the wrapped stream are returned + """ + wrapper = create_stream_wrapping_interceptor(lambda event: f"wrapped-{event}") + bridge = WoolInterceptor([wrapper]) + + async def mock_method(req, ctx): + return create_mock_response_stream("event1", "event2") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["wrapped-event1", "wrapped-event2"] + + @pytest.mark.asyncio + async def test_stream_filtering(self, sample_task, mock_request, mock_grpc_context): + """Stream filtering returns only filtered events. + + Given: + A bridge with an interceptor that filters stream events + When: + intercept() is called for a dispatch method + Then: + Only filtered events are returned to the client + """ + + async def filtering_interceptor(task): + response_stream = yield None + async for event in response_stream: + if "keep" in event: + yield event + + bridge = WoolInterceptor([filtering_interceptor]) + + async def mock_method(req, ctx): + return create_mock_response_stream("keep1", "drop", "keep2", "drop") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["keep1", "keep2"] + + @pytest.mark.asyncio + async def test_stream_event_injection( + self, sample_task, mock_request, mock_grpc_context + ): + """Stream event injection returns original and injected events. + + Given: + A bridge with an interceptor that injects additional events + When: + intercept() is called for a dispatch method + Then: + Both original and injected events are returned + """ + + async def injecting_interceptor(task): + response_stream = yield None + yield "injected-start" + async for event in response_stream: + yield event + yield "injected-end" + + bridge = WoolInterceptor([injecting_interceptor]) + + async def mock_method(req, ctx): + return create_mock_response_stream("event1", "event2") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["injected-start", "event1", "event2", "injected-end"] + + @pytest.mark.asyncio + async def test_multiple_stream_wrappers_reverse_order( + self, sample_task, mock_request, mock_grpc_context + ): + """Multiple stream wrappers wrap stream in reverse order. + + Given: + A bridge with multiple stream-wrapping interceptors + When: + intercept() is called for a dispatch method + Then: + Interceptors wrap the stream in reverse order + """ + wrapper1 = create_stream_wrapping_interceptor(lambda e: f"[{e}]") + wrapper2 = create_stream_wrapping_interceptor(lambda e: f"<{e}>") + wrapper3 = create_stream_wrapping_interceptor(lambda e: f"{{{e}}}") + + bridge = WoolInterceptor([wrapper1, wrapper2, wrapper3]) + + async def mock_method(req, ctx): + return create_mock_response_stream("x") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + # Reverse order: wrapper3, wrapper2, wrapper1 + assert events == ["[<{x}>]"] + + # ------------------------------------------------------------------------ + # Error Handling Tests + # ------------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_pre_dispatch_exception_propagation( + self, sample_task, mock_request, mock_grpc_context + ): + """Pre-dispatch exception propagates to caller. + + Given: + A bridge with interceptor that raises exception before yielding + When: + intercept() is called for a dispatch method + Then: + The exception propagates to the caller + """ + failing = create_failing_interceptor(ValueError("test error"), "pre-dispatch") + bridge = WoolInterceptor([failing]) + + async def mock_method(req, ctx): + return create_mock_response_stream("should-not-reach") + + with pytest.raises(ValueError, match="test error"): + await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + @pytest.mark.asyncio + async def test_pre_dispatch_stop_async_iteration( + self, sample_task, mock_request, mock_grpc_context + ): + """Pre-dispatch StopAsyncIteration treats as passthrough. + + Given: + A bridge with interceptor that raises StopAsyncIteration + When: + intercept() is called for a dispatch method + Then: + The interceptor is treated as a passthrough + """ + + async def stop_iteration_interceptor(task): + # Make this an async generator by using yield + if False: + yield + # This will cause StopAsyncIteration when asend(None) is called + + bridge = WoolInterceptor([stop_iteration_interceptor]) + + async def mock_method(req, ctx): + return create_mock_response_stream("event") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["event"] + + @pytest.mark.asyncio + async def test_stream_wrapping_exception_propagation( + self, sample_task, mock_request, mock_grpc_context + ): + """Stream wrapping exception propagates to caller. + + Given: + A bridge with interceptor raising exception during wrapping + When: + intercept() is called for a dispatch method + Then: + The exception propagates to the caller + """ + failing = create_failing_interceptor( + RuntimeError("stream error"), "stream-wrapping" + ) + bridge = WoolInterceptor([failing]) + + async def mock_method(req, ctx): + return create_mock_response_stream("event") + + with pytest.raises(RuntimeError, match="stream error"): + await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + @pytest.mark.asyncio + async def test_stream_wrapping_stop_async_iteration( + self, sample_task, mock_request, mock_grpc_context + ): + """Stream wrapping StopAsyncIteration uses original stream. + + Given: + A bridge with interceptor raising StopAsyncIteration wrapping + When: + intercept() is called for a dispatch method + Then: + The original stream is used + """ + + async def stop_iteration_wrapper(task): + yield None + # Just return without yielding events - causes StopAsyncIteration + return + + bridge = WoolInterceptor([stop_iteration_wrapper]) + + async def mock_method(req, ctx): + return create_mock_response_stream("event1", "event2") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["event1", "event2"] + + @pytest.mark.asyncio + async def test_dispatch_method_exception_propagation( + self, sample_task, mock_request, mock_grpc_context + ): + """Dispatch method exception propagates to caller. + + Given: + A bridge with an interceptor + When: + The underlying dispatch method raises an exception + Then: + The exception propagates to the caller + """ + passthrough = create_passthrough_interceptor() + bridge = WoolInterceptor([passthrough]) + + async def failing_method(req, ctx): + raise RuntimeError("dispatch failed") + + with pytest.raises(RuntimeError, match="dispatch failed"): + await bridge.intercept( + failing_method, + mock_request, + mock_grpc_context, + "/wool.Worker/dispatch", + ) + + # ------------------------------------------------------------------------ + # Full Lifecycle Test + # ------------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_full_interceptor_lifecycle( + self, sample_task, mock_request, mock_grpc_context, mocker + ): + """Full interceptor lifecycle applies modifications then wrapping. + + Given: + A bridge with interceptors that modify tasks and wrap streams + When: + intercept() is called for a dispatch method + Then: + Task modifications apply before dispatch, stream wrapping after + """ + lifecycle_events = [] + + async def full_lifecycle_interceptor(task): + lifecycle_events.append("pre-dispatch") + + # Modify task + modified = mocker.MagicMock() + modified.id = "lifecycle-modified" + modified.callable = task.callable + modified.args = task.args + modified.kwargs = task.kwargs + + response_stream = yield modified + + lifecycle_events.append("stream-wrapping") + + async for event in response_stream: + yield f"wrapped-{event}" + + bridge = WoolInterceptor([full_lifecycle_interceptor]) + + async def mock_method(req, ctx): + task = cloudpickle.loads(req.task) + lifecycle_events.append(f"dispatch-{task.id}") + return create_mock_response_stream("event") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + + assert lifecycle_events == [ + "pre-dispatch", + "dispatch-lifecycle-modified", + "stream-wrapping", + ] + assert events == ["wrapped-event"] + + # ------------------------------------------------------------------------ + # Property-Based Tests + # ------------------------------------------------------------------------ + + @pytest.mark.asyncio + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(interceptor_list=valid_passthrough_interceptor_list()) + async def test_passthrough_idempotency_property( + self, sample_task, mock_request, mock_grpc_context, interceptor_list + ): + """Property-based test: Passthrough interceptor idempotency. + + Given: + Any list of valid passthrough interceptors + When: + intercept() is called for dispatch + Then: + All events from original stream returned unchanged + """ + bridge = WoolInterceptor(interceptor_list) + + original_events = ["event1", "event2", "event3"] + + async def mock_method(req, ctx): + return create_mock_response_stream(*original_events) + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + output_events = await collect_stream_events(result) + assert output_events == original_events + + @pytest.mark.asyncio + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given( + events=event_stream_strategy(), + passthrough_count=st.integers(min_value=0, max_value=5), + ) + async def test_event_count_preservation_property( + self, sample_task, mock_request, mock_grpc_context, events, passthrough_count + ): + """Property-based test: Event count preservation. + + Given: + Any event stream (0-100 events) and 0-5 passthrough interceptors + When: + Events flow through the interceptor chain + Then: + Output event count exactly equals input event count + """ + interceptors = [ + create_passthrough_interceptor() for _ in range(passthrough_count) + ] + bridge = WoolInterceptor(interceptors) + + async def mock_method(req, ctx): + return create_mock_response_stream(*events) + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + output_events = await collect_stream_events(result) + assert len(output_events) == len(events) + + @pytest.mark.asyncio + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(interceptor_count=st.integers(min_value=1, max_value=10)) + async def test_interceptor_ordering_determinism_property( + self, sample_task, mock_request, mock_grpc_context, interceptor_count + ): + """Property-based test: Interceptor ordering determinism. + + Given: + Any Task and 1-10 order-tracking interceptors + When: + intercept() called for dispatch multiple times with same interceptors + Then: + Pre-dispatch forward order, stream-wrapping reverse order + """ + # Create order-tracking interceptors + order1 = [] + order2 = [] + + interceptors = [ + create_order_tracking_interceptor(i, order1) + for i in range(interceptor_count) + ] + + bridge = WoolInterceptor(interceptors) + + async def mock_method(req, ctx): + return create_mock_response_stream("event") + + # First execution + result1 = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + await collect_stream_events(result1) + + # Second execution with new order tracking + interceptors2 = [ + create_order_tracking_interceptor(i, order2) + for i in range(interceptor_count) + ] + bridge2 = WoolInterceptor(interceptors2) + + result2 = await bridge2.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + await collect_stream_events(result2) + + # Verify deterministic ordering + assert order1 == order2 + + # Verify forward order for pre-dispatch + pre_dispatch_order = [item[1] for item in order1 if item[0] == "pre"] + assert pre_dispatch_order == list(range(interceptor_count)) + + # Verify reverse order for stream-wrapping + stream_order = [item[1] for item in order1 if item[0] == "stream"] + assert stream_order == list(reversed(range(interceptor_count))) + + @pytest.mark.asyncio + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given( + exception_type=st.sampled_from( + [ValueError, RuntimeError, PermissionError, TypeError, OSError] + ), + interceptor_position=st.integers(min_value=0, max_value=4), + fail_stage=st.sampled_from(["pre-dispatch", "stream-wrapping"]), + ) + async def test_error_propagation_universality_property( + self, + sample_task, + mock_request, + mock_grpc_context, + exception_type, + interceptor_position, + fail_stage, + ): + """Property-based test: Error propagation universality. + + Given: + Any exception type at any interceptor position in any phase + When: + intercept() is called for dispatch + Then: + Exception propagates to caller + """ + # Create interceptor chain with one failing interceptor + interceptors = [] + for i in range(5): + if i == interceptor_position: + interceptors.append( + create_failing_interceptor(exception_type("test error"), fail_stage) + ) + else: + interceptors.append(create_passthrough_interceptor()) + + bridge = WoolInterceptor(interceptors) + + async def mock_method(req, ctx): + return create_mock_response_stream("event") + + # Verify exception propagates + with pytest.raises(exception_type, match="test error"): + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + # For stream-wrapping errors, we need to consume the stream + if fail_stage == "stream-wrapping": + await collect_stream_events(result) diff --git a/wool/tests/test_public.py b/wool/tests/test_public.py index 3050c52..48daa7b 100644 --- a/wool/tests/test_public.py +++ b/wool/tests/test_public.py @@ -31,12 +31,10 @@ def test_public_api_completeness(): """ # Arrange expected_public_api = [ - # Connection "RpcError", "TransientRpcError", "UnexpectedResponse", "WorkerConnection", - # Context "RuntimeContext", # Load balancing "LoadBalancerContextLike", @@ -50,6 +48,8 @@ def test_public_api_completeness(): "TaskEventType", "TaskException", "current_task", + "get_registered_interceptors", + "interceptor", "routine", # Workers "LocalWorker", @@ -73,7 +73,6 @@ def test_public_api_completeness(): "LocalDiscovery", "PredicateFunction", "WorkerMetadata", - # Typing "Factory", ]