diff --git a/wool/pyproject.toml b/wool/pyproject.toml index 341a508..dca3909 100644 --- a/wool/pyproject.toml +++ b/wool/pyproject.toml @@ -19,6 +19,7 @@ classifiers = [ ] dependencies = [ "cloudpickle", + "grpc-interceptor", "grpcio>=1.76.0", "portalocker", "protobuf", diff --git a/wool/src/wool/__init__.py b/wool/src/wool/__init__.py index 1e0fab5..a4b73ff 100644 --- a/wool/src/wool/__init__.py +++ b/wool/src/wool/__init__.py @@ -22,6 +22,8 @@ from wool.runtime.loadbalancer.base import NoWorkersAvailable from wool.runtime.loadbalancer.roundrobin import RoundRobinLoadBalancer from wool.runtime.resourcepool import ResourcePool +from wool.runtime.routine.interceptor import get_registered_interceptors +from wool.runtime.routine.interceptor import interceptor from wool.runtime.routine.task import Task from wool.runtime.routine.task import TaskEvent from wool.runtime.routine.task import TaskEventHandler @@ -57,12 +59,10 @@ ) __all__ = [ - # Connection "RpcError", "TransientRpcError", "UnexpectedResponse", "WorkerConnection", - # Context "RuntimeContext", # Load balancing "LoadBalancerContextLike", @@ -76,6 +76,8 @@ "TaskEventType", "TaskException", "current_task", + "get_registered_interceptors", + "interceptor", "routine", # Workers "LocalWorker", @@ -99,7 +101,6 @@ "LocalDiscovery", "PredicateFunction", "WorkerMetadata", - # Typing "Factory", ] diff --git a/wool/src/wool/runtime/routine/interceptor.py b/wool/src/wool/runtime/routine/interceptor.py new file mode 100644 index 0000000..cbd98d6 --- /dev/null +++ b/wool/src/wool/runtime/routine/interceptor.py @@ -0,0 +1,297 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any +from typing import AsyncGenerator +from typing import AsyncIterator +from typing import Awaitable +from typing import Callable +from typing import Protocol + +import cloudpickle +import grpc.aio +from grpc_interceptor.server import AsyncServerInterceptor + +if TYPE_CHECKING: + from 
wool.runtime.routine.task import Task + +# Global registry for decorator-registered interceptors +_registered_interceptors: list[InterceptorLike] = [] + + +# public +class InterceptorLike(Protocol): + """Protocol defining the Wool interceptor interface. + + Interceptors are async generators that wrap task execution, allowing + modification of tasks before dispatch and manipulation of response + streams during execution. + + Interceptors execute in three phases: + + 1. **Pre-dispatch**: Code before the first ``yield`` runs before the + task is dispatched. The interceptor can yield a modified + :class:`Task` or ``None`` to use the original task. + + 2. **Stream processing**: The ``yield`` expression receives the response + stream as an :class:`AsyncIterator`. The interceptor returns an async + generator that wraps this stream, allowing events to be modified, + filtered, or injected. + + 3. **Cleanup**: Code after the ``return`` statement (in a ``finally`` + block) runs after stream completion or cancellation. + + **Basic logging interceptor:** + + .. code-block:: python + + async def log_interceptor(task: Task) -> AsyncGenerator: + print(f"Starting task: {task.id}") + + # Yield None to use unmodified task + response_stream = yield None + + # Wrap and yield events from the response stream + try: + async for event in response_stream: + print(f"Event: {event}") + yield event + finally: + print(f"Task complete: {task.id}") + + **Task modification interceptor:** + + .. code-block:: python + + async def rbac_interceptor(task: Task) -> AsyncGenerator: + # Validate permissions before dispatch + if not has_permission(task): + raise PermissionError("Unauthorized") + + # Modify task metadata + modified_task = task.with_metadata(user=current_user()) + response_stream = yield modified_task + + # Pass through all events + async for event in response_stream: + yield event + + **Stream filtering interceptor:** + + .. 
code-block:: python + + async def filter_interceptor(task: Task) -> AsyncGenerator: + response_stream = yield None + + async for event in response_stream: + # Filter out certain event types + if should_include(event): + yield event + + **Error handling:** + + Exceptions raised by interceptors propagate to the client and cancel + the stream. The gRPC call fails with the interceptor's exception. + Applications must handle errors appropriately: + + .. code-block:: python + + async def safe_interceptor(task: Task) -> AsyncGenerator: + try: + response_stream = yield None + async for event in response_stream: + yield event + except Exception as e: + # Log error, send alert, etc. + logger.error(f"Task failed: {e}") + # Re-raise to propagate to client + raise + + :param task: + The work task being dispatched. + :returns: + An async generator that yields the modified task (or None) and + returns an async iterator wrapping the response stream. + """ + + def __call__( + self, task: Task + ) -> AsyncGenerator[Task | None, AsyncIterator | None]: ... + + +# public +def interceptor(func: InterceptorLike) -> InterceptorLike: + """Register a Wool interceptor globally. + + Decorated interceptors are automatically applied to all workers that + don't specify explicit interceptors. Use this for cross-cutting + concerns like logging, metrics, or distributed tracing. + + **Usage:** + + .. code-block:: python + + from wool.runtime.routine.interceptor import interceptor + + + @interceptor + async def metrics_interceptor(task): + start_time = time.time() + response_stream = yield None + + try: + async for event in response_stream: + yield event + finally: + duration = time.time() - start_time + record_metric("task_duration", duration) + + Workers automatically include registered interceptors: + + .. code-block:: python + + # This worker uses metrics_interceptor automatically + worker = LocalWorker("my-worker") + + To use only explicit interceptors (ignoring registered ones): + + .. 
code-block:: python + + # Only use explicit interceptors, not registered ones + worker = LocalWorker("my-worker", interceptors=[custom_interceptor]) + + :param func: + The interceptor function to register. + :returns: + The original function, unchanged. + """ + _registered_interceptors.append(func) + return func + + +def get_registered_interceptors() -> list[InterceptorLike]: + """Get all globally registered interceptors. + + :returns: + List of interceptors registered with the :func:`@interceptor + ` decorator. + """ + return _registered_interceptors.copy() + + +class WoolInterceptor(AsyncServerInterceptor): + """Bridges Wool interceptors to gRPC's interceptor interface. + + Converts high-level Wool interceptor semantics (task modification, + stream wrapping) into gRPC's low-level interceptor protocol. Only + applies to ``dispatch`` RPC calls (unary-stream). + + This class is an implementation detail - users should work with + :class:`WoolInterceptor` functions and the :func:`@interceptor + ` decorator. + + :param interceptors: + Wool interceptor functions to apply. + """ + + def __init__(self, interceptors: list[InterceptorLike]): + self._interceptors = interceptors + + async def intercept( + self, + method: Callable[[Any, grpc.aio.ServicerContext], Awaitable[Any]], + request_or_iterator: Any, + context: grpc.aio.ServicerContext, + method_name: str, + ) -> Any: + """Intercept gRPC service calls. + + Only applies interceptors to ``dispatch`` calls. Other RPC methods + (like ``stop``) bypass interception. + + :param method: + The gRPC service method being called. + :param request_or_iterator: + The request object or request iterator. + :param context: + The gRPC servicer context. + :param method_name: + The name of the method being called (e.g., + "/wool.Worker/dispatch"). + :returns: + The response or response iterator from the method. 
+ """ + # Exit early if no interceptors registered + if not self._interceptors: + return await method(request_or_iterator, context) + + # Only intercept dispatch calls + if not method_name.endswith("/dispatch"): + return await method(request_or_iterator, context) + + # Deserialize task from protobuf request + task: Task = cloudpickle.loads(request_or_iterator.task) + + # Apply all interceptors in order, keeping generators alive + modified_task = task + active_generators = [] + for interceptor_func in self._interceptors: + try: + # Start the interceptor generator + gen = interceptor_func(modified_task) + + # Advance to first yield - get potentially modified task + task_modification = await gen.asend(None) + + # Update task if interceptor returned a modified version + if task_modification is not None: + modified_task = task_modification + + # Store generator for stream wrapping phase + active_generators.append(gen) + + except StopAsyncIteration: + # Interceptor didn't yield - treat as passthrough + active_generators.append(None) + except Exception: + # Interceptor raised error - propagate to caller + raise + + # If task was modified, update the protobuf request + if modified_task is not task: + # Create new request with modified task + request_or_iterator.task = cloudpickle.dumps(modified_task) + + # Call the actual dispatch method + response_stream = await method(request_or_iterator, context) + + # Wrap response stream with interceptors (in reverse order) + for gen in reversed(active_generators): + # Skip interceptors that didn't create generators + if gen is None: + continue + + try: + # Send the response stream - generator will start yielding events + try: + first_event = await gen.asend(response_stream) + + # The generator is now yielding events - wrap it + async def _create_wrapped_stream( + generator: AsyncGenerator, + first: Any, + ) -> AsyncGenerator: + yield first + async for event in generator: + yield event + + response_stream = _create_wrapped_stream(gen, 
first_event) + except StopAsyncIteration: + # Generator finished without yielding - use original stream + pass + + except Exception: + # Stream wrapping failed - propagate to caller + raise + + return response_stream diff --git a/wool/src/wool/runtime/worker/local.py b/wool/src/wool/runtime/worker/local.py index e9ef10b..af69cdf 100644 --- a/wool/src/wool/runtime/worker/local.py +++ b/wool/src/wool/runtime/worker/local.py @@ -9,6 +9,8 @@ import wool from wool.runtime import protobuf as pb from wool.runtime.discovery.base import WorkerMetadata +from wool.runtime.routine.interceptor import InterceptorLike +from wool.runtime.routine.interceptor import get_registered_interceptors from wool.runtime.worker.auth import WorkerCredentials from wool.runtime.worker.base import ChannelCredentialsType from wool.runtime.worker.base import ServerCredentialsType @@ -62,6 +64,11 @@ class LocalWorker(Worker): credentials for mutual TLS. Enables secure worker-to-worker communication. - ``None``: Worker uses insecure connections. + :param interceptors: + Optional list of :class:`WoolInterceptor` functions to apply to + this worker. If ``None``, uses globally registered interceptors + from the :func:`@interceptor ` + decorator. Pass an empty list to disable all interceptors. :param extra: Additional metadata as key-value pairs. 
""" @@ -78,6 +85,7 @@ def __init__( shutdown_grace_period: float = 60.0, proxy_pool_ttl: float = 60.0, credentials: WorkerCredentials | None = None, + interceptors: list[InterceptorLike] | None = None, **extra: Any, ): super().__init__(*tags, **extra) @@ -90,12 +98,17 @@ def __init__( self._server_credentials = None self._client_credentials = None + # Use provided interceptors or fall back to registered ones + if interceptors is None: + interceptors = get_registered_interceptors() + self._worker_process = WorkerProcess( host=host, port=port, shutdown_grace_period=shutdown_grace_period, proxy_pool_ttl=proxy_pool_ttl, server_credentials=self._server_credentials, + interceptors=interceptors, ) @property diff --git a/wool/src/wool/runtime/worker/process.py b/wool/src/wool/runtime/worker/process.py index 64e54bd..3030d86 100644 --- a/wool/src/wool/runtime/worker/process.py +++ b/wool/src/wool/runtime/worker/process.py @@ -16,6 +16,8 @@ import wool from wool.runtime import protobuf as pb from wool.runtime.resourcepool import ResourcePool +from wool.runtime.routine.interceptor import InterceptorLike +from wool.runtime.routine.interceptor import WoolInterceptor from wool.runtime.worker.base import ServerCredentialsType from wool.runtime.worker.base import resolve_server_credentials from wool.runtime.worker.service import WorkerService @@ -45,6 +47,8 @@ class WorkerProcess(Process): Proxy pool TTL in seconds. :param server_credentials: Optional gRPC server credentials for TLS/mTLS. + :param interceptors: + List of Wool interceptor functions to apply to task dispatch. :param args: Additional args for :class:`multiprocessing.Process`. 
:param kwargs: @@ -57,6 +61,7 @@ class WorkerProcess(Process): _shutdown_grace_period: float _proxy_pool_ttl: float _credentials: ServerCredentialsType + _interceptors: list[InterceptorLike] def __init__( self, @@ -66,6 +71,7 @@ def __init__( shutdown_grace_period: float = 60.0, proxy_pool_ttl: float = 60.0, server_credentials: ServerCredentialsType = None, + interceptors: list[InterceptorLike] | None = None, **kwargs, ): super().__init__(*args, **kwargs) @@ -82,6 +88,7 @@ def __init__( raise ValueError("Proxy pool TTL must be positive") self._proxy_pool_ttl = proxy_pool_ttl self._credentials = server_credentials + self._interceptors = interceptors or [] self._get_port, self._set_port = Pipe(duplex=False) @property @@ -172,7 +179,12 @@ async def _serve(self): requests. It creates a gRPC server, adds the worker service, and starts listening for incoming connections. """ - server = grpc.aio.server() + # Create interceptor bridge if interceptors are registered + interceptors = [] + if self._interceptors: + interceptors.append(WoolInterceptor(self._interceptors)) + + server = grpc.aio.server(interceptors=interceptors) credentials = resolve_server_credentials(self._credentials) address = self._address(self._host, self._port) diff --git a/wool/tests/runtime/routine/test_interceptor.py b/wool/tests/runtime/routine/test_interceptor.py new file mode 100644 index 0000000..18b75f9 --- /dev/null +++ b/wool/tests/runtime/routine/test_interceptor.py @@ -0,0 +1,1099 @@ +from __future__ import annotations + +import cloudpickle +import pytest +from hypothesis import HealthCheck +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from wool.runtime.routine.interceptor import WoolInterceptor +from wool.runtime.routine.interceptor import get_registered_interceptors +from wool.runtime.routine.interceptor import interceptor + + +@pytest.fixture(autouse=True) +def clear_interceptor_registry(): + """Clear the global interceptor 
registry before and after each test.""" + from wool.runtime.routine.interceptor import _registered_interceptors + + _registered_interceptors.clear() + yield + _registered_interceptors.clear() + + +@pytest.fixture +def sample_task(mocker): + """Provide a mock Task for testing.""" + task = mocker.MagicMock() + task.id = "test-task-123" + task.callable = mocker.MagicMock() + task.args = () + task.kwargs = {} + return task + + +@pytest.fixture +def mock_grpc_context(mocker): + """Provide a mock gRPC context.""" + context = mocker.MagicMock() + context.cancel = mocker.MagicMock() + return context + + +@pytest.fixture +def mock_request(sample_task, mocker): + """Provide a mock protobuf request with serialized task.""" + request = mocker.MagicMock() + request.task = cloudpickle.dumps(sample_task) + return request + + +async def create_mock_response_stream(*events): + """Create a mock async response stream.""" + for event in events: + yield event + + +def create_passthrough_interceptor(): + """Create a passthrough interceptor that yields None.""" + + async def passthrough(task): + response_stream = yield None + async for event in response_stream: + yield event + + return passthrough + + +def create_task_modifying_interceptor(modification_fn): + """Create an interceptor that modifies the task.""" + + async def modifier(task): + modified_task = modification_fn(task) + response_stream = yield modified_task + async for event in response_stream: + yield event + + return modifier + + +def create_stream_wrapping_interceptor(wrapper_fn): + """Create an interceptor that wraps the response stream.""" + + async def wrapper(task): + response_stream = yield None + async for event in response_stream: + yield wrapper_fn(event) + + return wrapper + + +def create_failing_interceptor(exception, fail_stage="pre-dispatch"): + """Create an interceptor that raises an exception.""" + + async def failing(task): + if fail_stage == "pre-dispatch": + raise exception + response_stream = yield None + 
if fail_stage == "stream-wrapping": + raise exception + async for event in response_stream: + yield event + + return failing + + +def create_order_tracking_interceptor(interceptor_id, order_list, phase="both"): + """Create an interceptor that tracks execution order.""" + + async def order_tracker(task): + if phase in ("both", "pre-dispatch"): + order_list.append(("pre", interceptor_id)) + response_stream = yield None + if phase in ("both", "stream-wrapping"): + order_list.append(("stream", interceptor_id)) + async for event in response_stream: + yield event + + return order_tracker + + +async def collect_stream_events(stream): + """Collect all events from an async iterator.""" + events = [] + async for event in stream: + events.append(event) + return events + + +@st.composite +def valid_passthrough_interceptor_list(draw): + """Generate a list of 0-5 passthrough interceptors.""" + count = draw(st.integers(min_value=0, max_value=5)) + return [create_passthrough_interceptor() for _ in range(count)] + + +@st.composite +def event_stream_strategy(draw): + """Generate lists of 0-100 hashable events.""" + return draw( + st.lists( + st.one_of(st.text(), st.integers(), st.tuples(st.text(), st.integers())), + min_size=0, + max_size=100, + ) + ) + + +class TestInterceptorDecorator: + """Tests for the @interceptor decorator.""" + + def test_registers_function_in_global_registry(self): + """Interceptor decorator registers function in global registry. + + Given: + An interceptor function + When: + Decorated with @interceptor + Then: + The function is added to the global registry + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + decorated = interceptor(my_interceptor) + + assert my_interceptor in get_registered_interceptors() + assert decorated is my_interceptor + + def test_returns_original_function_unchanged(self): + """Interceptor decorator returns original function unchanged. 
+ + Given: + An interceptor function + When: + Decorated with @interceptor + Then: + The original function is returned unchanged + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + original_id = id(my_interceptor) + decorated = interceptor(my_interceptor) + + assert id(decorated) == original_id + assert decorated is my_interceptor + + def test_registers_multiple_functions_in_order(self): + """Interceptor decorator registers multiple functions in order. + + Given: + Multiple interceptor functions + When: + Each decorated with @interceptor + Then: + All functions are added to the registry in order + """ + + async def interceptor1(task): + response_stream = yield None + async for event in response_stream: + yield event + + async def interceptor2(task): + response_stream = yield None + async for event in response_stream: + yield event + + async def interceptor3(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(interceptor1) + interceptor(interceptor2) + interceptor(interceptor3) + + registered = get_registered_interceptors() + assert len(registered) == 3 + assert registered[0] is interceptor1 + assert registered[1] is interceptor2 + assert registered[2] is interceptor3 + + def test_allows_duplicate_registration(self): + """Interceptor decorator allows duplicate registration. 
+ + Given: + The same interceptor function + When: + Decorated with @interceptor twice + Then: + The function appears twice in the registry + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(my_interceptor) + interceptor(my_interceptor) + + registered = get_registered_interceptors() + assert len(registered) == 2 + assert registered[0] is my_interceptor + assert registered[1] is my_interceptor + + +class TestGetRegisteredInterceptors: + """Tests for get_registered_interceptors().""" + + def test_returns_empty_list_when_no_interceptors(self): + """Get registered interceptors returns empty list when none registered. + + Given: + No registered interceptors + When: + get_registered_interceptors() is called + Then: + An empty list is returned + """ + result = get_registered_interceptors() + + assert result == [] + assert isinstance(result, list) + + def test_returns_single_registered_interceptor(self): + """Get registered interceptors returns single registered interceptor. + + Given: + One registered interceptor + When: + get_registered_interceptors() is called + Then: + A list containing the interceptor is returned + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(my_interceptor) + + result = get_registered_interceptors() + assert len(result) == 1 + assert result[0] is my_interceptor + + def test_returns_multiple_interceptors_in_order(self): + """Get registered interceptors returns all in registration order. 
+ + Given: + Multiple registered interceptors + When: + get_registered_interceptors() is called + Then: + A list with all interceptors in registration order is returned + """ + + async def interceptor1(task): + response_stream = yield None + async for event in response_stream: + yield event + + async def interceptor2(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(interceptor1) + interceptor(interceptor2) + + result = get_registered_interceptors() + assert len(result) == 2 + assert result[0] is interceptor1 + assert result[1] is interceptor2 + + def test_returns_copy_not_reference(self): + """Get registered interceptors returns copy not reference. + + Given: + Registered interceptors exist + When: + get_registered_interceptors() is called and list is modified + Then: + The original registry is not affected + """ + + async def my_interceptor(task): + response_stream = yield None + async for event in response_stream: + yield event + + interceptor(my_interceptor) + + result1 = get_registered_interceptors() + result1.append("fake") # type: ignore + + result2 = get_registered_interceptors() + assert len(result2) == 1 + assert result2[0] is my_interceptor + + +class TestWoolInterceptor: + """Tests for the WoolInterceptorBridge class.""" + + # ------------------------------------------------------------------------ + # Instantiation Tests + # ------------------------------------------------------------------------ + + def test_stores_interceptors_on_instantiation(self): + """WoolInterceptorBridge stores interceptors on instantiation. 
+ + Given: + A list of interceptor functions + When: + WoolInterceptorBridge is instantiated + Then: + The bridge stores the interceptors + """ + interceptor1 = create_passthrough_interceptor() + interceptor2 = create_passthrough_interceptor() + + bridge = WoolInterceptor([interceptor1, interceptor2]) + + assert bridge._interceptors == [interceptor1, interceptor2] + + def test_instantiates_with_empty_list(self): + """WoolInterceptorBridge instantiates with empty list. + + Given: + An empty interceptor list + When: + WoolInterceptorBridge is instantiated + Then: + The bridge is created successfully + """ + bridge = WoolInterceptor([]) + + assert bridge._interceptors == [] + + # ------------------------------------------------------------------------ + # Early Exit and Bypass Tests + # ------------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_exits_early_with_no_interceptors(self, mock_grpc_context, mocker): + """WoolInterceptorBridge exits early with no interceptors. + + Given: + A bridge with no interceptors + When: + intercept() is called for a dispatch method + Then: + The method is called directly without processing + """ + bridge = WoolInterceptor([]) + + mock_method = mocker.AsyncMock(return_value="direct_result") + mock_request = mocker.MagicMock() + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + assert result == "direct_result" + mock_method.assert_called_once_with(mock_request, mock_grpc_context) + mock_grpc_context.cancel.assert_not_called() + + @pytest.mark.asyncio + async def test_bypasses_non_dispatch_methods(self, mock_grpc_context, mocker): + """WoolInterceptorBridge bypasses non-dispatch methods. 
+ + Given: + A bridge with interceptors + When: + intercept() is called for a non-dispatch method + Then: + The method is called directly without interceptors + """ + + async def failing_interceptor(task): + raise RuntimeError("Should not be called") + + bridge = WoolInterceptor([failing_interceptor]) + mock_method = mocker.AsyncMock(return_value="stop_result") + mock_request = mocker.MagicMock() + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/stop" + ) + + assert result == "stop_result" + mock_method.assert_called_once_with(mock_request, mock_grpc_context) + + # ------------------------------------------------------------------------ + # Task Modification Tests + # ------------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_passthrough_interceptor_yields_none( + self, sample_task, mock_request, mock_grpc_context + ): + """Passthrough interceptor yields None dispatches task unmodified. + + Given: + A bridge with a passthrough interceptor that yields None + When: + intercept() is called for a dispatch method + Then: + The task is dispatched unmodified + """ + passthrough = create_passthrough_interceptor() + bridge = WoolInterceptor([passthrough]) + + async def mock_method(req, ctx): + # Verify task wasn't modified + task = cloudpickle.loads(req.task) + assert task.id == sample_task.id + return create_mock_response_stream("event1", "event2") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["event1", "event2"] + + @pytest.mark.asyncio + async def test_task_modification_via_yield( + self, sample_task, mock_request, mock_grpc_context, mocker + ): + """Task modification via yield serializes modified task. 
+ + Given: + A bridge with an interceptor that yields a modified task + When: + intercept() is called for a dispatch method + Then: + The modified task is serialized and dispatched + """ + + def modify_task(task): + modified = mocker.MagicMock() + modified.id = "modified-task-456" + modified.callable = task.callable + modified.args = task.args + modified.kwargs = task.kwargs + return modified + + modifier = create_task_modifying_interceptor(modify_task) + bridge = WoolInterceptor([modifier]) + + async def mock_method(req, ctx): + task = cloudpickle.loads(req.task) + assert task.id == "modified-task-456" + return create_mock_response_stream("event1") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["event1"] + + @pytest.mark.asyncio + async def test_multiple_passthrough_interceptors( + self, sample_task, mock_request, mock_grpc_context + ): + """Multiple passthrough interceptors process task in order. 
+ + Given: + A bridge with multiple interceptors that yield None + When: + intercept() is called for a dispatch method + Then: + All interceptors process the task in order + """ + order = [] + + def create_tracking_passthrough(idx): + async def tracker(task): + order.append(f"pre-{idx}") + response_stream = yield None + order.append(f"stream-{idx}") + async for event in response_stream: + yield event + + return tracker + + bridge = WoolInterceptor( + [ + create_tracking_passthrough(1), + create_tracking_passthrough(2), + create_tracking_passthrough(3), + ] + ) + + async def mock_method(req, ctx): + return create_mock_response_stream("event") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + await collect_stream_events(result) + + # Pre-dispatch: forward order, stream-wrapping: reverse order + assert order == [ + "pre-1", + "pre-2", + "pre-3", + "stream-3", + "stream-2", + "stream-1", + ] + + @pytest.mark.asyncio + async def test_chained_task_modification( + self, sample_task, mock_request, mock_grpc_context, mocker + ): + """Chained task modification passes through modifications. 
+ + Given: + A bridge with multiple task-modifying interceptors + When: + intercept() is called for a dispatch method + Then: + Each interceptor receives task modified by previous ones + """ + modifications = [] + + def create_modifier(suffix): + def modify(task): + modified = mocker.MagicMock() + modified.id = task.id + f"-{suffix}" + modified.callable = task.callable + modified.args = task.args + modified.kwargs = task.kwargs + modifications.append(modified.id) + return modified + + return create_task_modifying_interceptor(modify) + + bridge = WoolInterceptor([create_modifier("A"), create_modifier("B")]) + + async def mock_method(req, ctx): + task = cloudpickle.loads(req.task) + assert task.id == "test-task-123-A-B" + return create_mock_response_stream("event") + + await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + assert modifications == ["test-task-123-A", "test-task-123-A-B"] + + # ------------------------------------------------------------------------ + # Stream Wrapping Tests + # ------------------------------------------------------------------------ + + @pytest.mark.asyncio + async def test_stream_wrapping(self, sample_task, mock_request, mock_grpc_context): + """Stream wrapping returns events from wrapped stream. 
+ + Given: + A bridge with an interceptor that returns a wrapped stream + When: + intercept() is called for a dispatch method + Then: + Events from the wrapped stream are returned + """ + wrapper = create_stream_wrapping_interceptor(lambda event: f"wrapped-{event}") + bridge = WoolInterceptor([wrapper]) + + async def mock_method(req, ctx): + return create_mock_response_stream("event1", "event2") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["wrapped-event1", "wrapped-event2"] + + @pytest.mark.asyncio + async def test_stream_filtering(self, sample_task, mock_request, mock_grpc_context): + """Stream filtering returns only filtered events. + + Given: + A bridge with an interceptor that filters stream events + When: + intercept() is called for a dispatch method + Then: + Only filtered events are returned to the client + """ + + async def filtering_interceptor(task): + response_stream = yield None + async for event in response_stream: + if "keep" in event: + yield event + + bridge = WoolInterceptor([filtering_interceptor]) + + async def mock_method(req, ctx): + return create_mock_response_stream("keep1", "drop", "keep2", "drop") + + result = await bridge.intercept( + mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch" + ) + + events = await collect_stream_events(result) + assert events == ["keep1", "keep2"] + + @pytest.mark.asyncio + async def test_stream_event_injection( + self, sample_task, mock_request, mock_grpc_context + ): + """Stream event injection returns original and injected events. 

        Given:
            A bridge with an interceptor that injects additional events
        When:
            intercept() is called for a dispatch method
        Then:
            Both original and injected events are returned
        """

        async def injecting_interceptor(task):
            # Inject one event before and one after the original stream.
            response_stream = yield None
            yield "injected-start"
            async for event in response_stream:
                yield event
            yield "injected-end"

        bridge = WoolInterceptor([injecting_interceptor])

        async def mock_method(req, ctx):
            return create_mock_response_stream("event1", "event2")

        result = await bridge.intercept(
            mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
        )

        events = await collect_stream_events(result)
        assert events == ["injected-start", "event1", "event2", "injected-end"]

    @pytest.mark.asyncio
    async def test_multiple_stream_wrappers_reverse_order(
        self, sample_task, mock_request, mock_grpc_context
    ):
        """Multiple stream wrappers wrap stream in reverse order.

        Given:
            A bridge with multiple stream-wrapping interceptors
        When:
            intercept() is called for a dispatch method
        Then:
            Interceptors wrap the stream in reverse order
        """
        wrapper1 = create_stream_wrapping_interceptor(lambda e: f"[{e}]")
        wrapper2 = create_stream_wrapping_interceptor(lambda e: f"<{e}>")
        wrapper3 = create_stream_wrapping_interceptor(lambda e: f"{{{e}}}")

        bridge = WoolInterceptor([wrapper1, wrapper2, wrapper3])

        async def mock_method(req, ctx):
            return create_mock_response_stream("x")

        result = await bridge.intercept(
            mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
        )

        events = await collect_stream_events(result)
        # Reverse wrapping order: wrapper3 is applied innermost ({x}), then
        # wrapper2 (<{x}>), and wrapper1 outermost ([<{x}>]).
        assert events == ["[<{x}>]"]

    # ------------------------------------------------------------------------
    # Error Handling Tests
    # ------------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_pre_dispatch_exception_propagation(
        self, sample_task, mock_request, mock_grpc_context
    ):
        """Pre-dispatch exception propagates to caller.

        Given:
            A bridge with interceptor that raises exception before yielding
        When:
            intercept() is called for a dispatch method
        Then:
            The exception propagates to the caller
        """
        failing = create_failing_interceptor(ValueError("test error"), "pre-dispatch")
        bridge = WoolInterceptor([failing])

        async def mock_method(req, ctx):
            return create_mock_response_stream("should-not-reach")

        with pytest.raises(ValueError, match="test error"):
            await bridge.intercept(
                mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
            )

    @pytest.mark.asyncio
    async def test_pre_dispatch_stop_async_iteration(
        self, sample_task, mock_request, mock_grpc_context
    ):
        """Pre-dispatch StopAsyncIteration is treated as passthrough.

        Given:
            A bridge with interceptor that raises StopAsyncIteration
        When:
            intercept() is called for a dispatch method
        Then:
            The interceptor is treated as a passthrough
        """

        async def stop_iteration_interceptor(task):
            # Make this an async generator by using yield
            if False:
                yield
            # This will cause StopAsyncIteration when asend(None) is called

        bridge = WoolInterceptor([stop_iteration_interceptor])

        async def mock_method(req, ctx):
            return create_mock_response_stream("event")

        result = await bridge.intercept(
            mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
        )

        events = await collect_stream_events(result)
        assert events == ["event"]

    @pytest.mark.asyncio
    async def test_stream_wrapping_exception_propagation(
        self, sample_task, mock_request, mock_grpc_context
    ):
        """Stream wrapping exception propagates to caller.

        Given:
            A bridge with interceptor raising exception during wrapping
        When:
            intercept() is called for a dispatch method
        Then:
            The exception propagates to the caller
        """
        failing = create_failing_interceptor(
            RuntimeError("stream error"), "stream-wrapping"
        )
        bridge = WoolInterceptor([failing])

        async def mock_method(req, ctx):
            return create_mock_response_stream("event")

        with pytest.raises(RuntimeError, match="stream error"):
            await bridge.intercept(
                mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
            )

    @pytest.mark.asyncio
    async def test_stream_wrapping_stop_async_iteration(
        self, sample_task, mock_request, mock_grpc_context
    ):
        """Stream wrapping StopAsyncIteration uses original stream.

        Given:
            A bridge with an interceptor raising StopAsyncIteration during
            stream wrapping
        When:
            intercept() is called for a dispatch method
        Then:
            The original stream is used
        """

        async def stop_iteration_wrapper(task):
            yield None
            # Just return without yielding events - causes StopAsyncIteration
            return

        bridge = WoolInterceptor([stop_iteration_wrapper])

        async def mock_method(req, ctx):
            return create_mock_response_stream("event1", "event2")

        result = await bridge.intercept(
            mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
        )

        events = await collect_stream_events(result)
        assert events == ["event1", "event2"]

    @pytest.mark.asyncio
    async def test_dispatch_method_exception_propagation(
        self, sample_task, mock_request, mock_grpc_context
    ):
        """Dispatch method exception propagates to caller.

        Given:
            A bridge with an interceptor
        When:
            The underlying dispatch method raises an exception
        Then:
            The exception propagates to the caller
        """
        passthrough = create_passthrough_interceptor()
        bridge = WoolInterceptor([passthrough])

        async def failing_method(req, ctx):
            raise RuntimeError("dispatch failed")

        with pytest.raises(RuntimeError, match="dispatch failed"):
            await bridge.intercept(
                failing_method,
                mock_request,
                mock_grpc_context,
                "/wool.Worker/dispatch",
            )

    # ------------------------------------------------------------------------
    # Full Lifecycle Test
    # ------------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_full_interceptor_lifecycle(
        self, sample_task, mock_request, mock_grpc_context, mocker
    ):
        """Full interceptor lifecycle applies modifications then wrapping.

        Given:
            A bridge with interceptors that modify tasks and wrap streams
        When:
            intercept() is called for a dispatch method
        Then:
            Task modifications apply before dispatch, stream wrapping after
        """
        lifecycle_events = []

        async def full_lifecycle_interceptor(task):
            lifecycle_events.append("pre-dispatch")

            # Modify task
            modified = mocker.MagicMock()
            modified.id = "lifecycle-modified"
            modified.callable = task.callable
            modified.args = task.args
            modified.kwargs = task.kwargs

            response_stream = yield modified

            lifecycle_events.append("stream-wrapping")

            async for event in response_stream:
                yield f"wrapped-{event}"

        bridge = WoolInterceptor([full_lifecycle_interceptor])

        async def mock_method(req, ctx):
            # Dispatch sees the modified task, proving pre-dispatch ran first.
            task = cloudpickle.loads(req.task)
            lifecycle_events.append(f"dispatch-{task.id}")
            return create_mock_response_stream("event")

        result = await bridge.intercept(
            mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
        )

        events = await collect_stream_events(result)

        assert lifecycle_events == [
            "pre-dispatch",
            "dispatch-lifecycle-modified",
            "stream-wrapping",
        ]
        assert events == ["wrapped-event"]

    # ------------------------------------------------------------------------
    # Property-Based Tests
    # ------------------------------------------------------------------------

    @pytest.mark.asyncio
    @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
    @given(interceptor_list=valid_passthrough_interceptor_list())
    async def test_passthrough_idempotency_property(
        self, sample_task, mock_request, mock_grpc_context, interceptor_list
    ):
        """Property-based test: Passthrough interceptor idempotency.

        Given:
            Any list of valid passthrough interceptors
        When:
            intercept() is called for dispatch
        Then:
            All events from original stream returned unchanged
        """
        bridge = WoolInterceptor(interceptor_list)

        original_events = ["event1", "event2", "event3"]

        async def mock_method(req, ctx):
            return create_mock_response_stream(*original_events)

        result = await bridge.intercept(
            mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
        )

        output_events = await collect_stream_events(result)
        assert output_events == original_events

    @pytest.mark.asyncio
    @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
    @given(
        events=event_stream_strategy(),
        passthrough_count=st.integers(min_value=0, max_value=5),
    )
    async def test_event_count_preservation_property(
        self, sample_task, mock_request, mock_grpc_context, events, passthrough_count
    ):
        """Property-based test: Event count preservation.

        Given:
            Any event stream (0-100 events) and 0-5 passthrough interceptors
        When:
            Events flow through the interceptor chain
        Then:
            Output event count exactly equals input event count
        """
        interceptors = [
            create_passthrough_interceptor() for _ in range(passthrough_count)
        ]
        bridge = WoolInterceptor(interceptors)

        async def mock_method(req, ctx):
            return create_mock_response_stream(*events)

        result = await bridge.intercept(
            mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
        )

        output_events = await collect_stream_events(result)
        assert len(output_events) == len(events)

    @pytest.mark.asyncio
    @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
    @given(interceptor_count=st.integers(min_value=1, max_value=10))
    async def test_interceptor_ordering_determinism_property(
        self, sample_task, mock_request, mock_grpc_context, interceptor_count
    ):
        """Property-based test: Interceptor ordering determinism.

        Given:
            Any Task and 1-10 order-tracking interceptors
        When:
            intercept() called for dispatch multiple times with same interceptors
        Then:
            Pre-dispatch forward order, stream-wrapping reverse order
        """
        # Create order-tracking interceptors
        order1 = []
        order2 = []

        interceptors = [
            create_order_tracking_interceptor(i, order1)
            for i in range(interceptor_count)
        ]

        bridge = WoolInterceptor(interceptors)

        async def mock_method(req, ctx):
            return create_mock_response_stream("event")

        # First execution
        result1 = await bridge.intercept(
            mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
        )
        await collect_stream_events(result1)

        # Second execution with new order tracking
        interceptors2 = [
            create_order_tracking_interceptor(i, order2)
            for i in range(interceptor_count)
        ]
        bridge2 = WoolInterceptor(interceptors2)

        result2 = await bridge2.intercept(
            mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
        )
        await collect_stream_events(result2)

        # Verify deterministic ordering
        assert order1 == order2

        # Verify forward order for pre-dispatch
        pre_dispatch_order = [item[1] for item in order1 if item[0] == "pre"]
        assert pre_dispatch_order == list(range(interceptor_count))

        # Verify reverse order for stream-wrapping
        stream_order = [item[1] for item in order1 if item[0] == "stream"]
        assert stream_order == list(reversed(range(interceptor_count)))

    @pytest.mark.asyncio
    @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
    @given(
        exception_type=st.sampled_from(
            [ValueError, RuntimeError, PermissionError, TypeError, OSError]
        ),
        interceptor_position=st.integers(min_value=0, max_value=4),
        fail_stage=st.sampled_from(["pre-dispatch", "stream-wrapping"]),
    )
    async def test_error_propagation_universality_property(
        self,
        sample_task,
        mock_request,
        mock_grpc_context,
        exception_type,
        interceptor_position,
        fail_stage,
    ):
        """Property-based test: Error propagation universality.

        Given:
            Any exception type at any interceptor position in any phase
        When:
            intercept() is called for dispatch
        Then:
            Exception propagates to caller
        """
        # Create interceptor chain with one failing interceptor
        interceptors = []
        for i in range(5):
            if i == interceptor_position:
                interceptors.append(
                    create_failing_interceptor(exception_type("test error"), fail_stage)
                )
            else:
                interceptors.append(create_passthrough_interceptor())

        bridge = WoolInterceptor(interceptors)

        async def mock_method(req, ctx):
            return create_mock_response_stream("event")

        # Verify exception propagates
        with pytest.raises(exception_type, match="test error"):
            result = await bridge.intercept(
                mock_method, mock_request, mock_grpc_context, "/wool.Worker/dispatch"
            )
            # For stream-wrapping errors, we need to consume the stream
            if fail_stage == "stream-wrapping":
                await collect_stream_events(result)
diff --git a/wool/tests/test_public.py b/wool/tests/test_public.py
index 3050c52..48daa7b 100644
--- a/wool/tests/test_public.py
+++ b/wool/tests/test_public.py
@@ -31,12 +31,10 @@ def test_public_api_completeness():
     """
     # Arrange
     expected_public_api = [
-        # Connection
         "RpcError",
         "TransientRpcError",
         "UnexpectedResponse",
         "WorkerConnection",
-        # Context
         "RuntimeContext",
         # Load balancing
         "LoadBalancerContextLike",
@@ -50,6 +48,8 @@ def test_public_api_completeness():
         "TaskEventType",
         "TaskException",
         "current_task",
+        "get_registered_interceptors",
+        "interceptor",
         "routine",
         # Workers
         "LocalWorker",
@@ -73,7 +73,6 @@ def test_public_api_completeness():
         "LocalDiscovery",
         "PredicateFunction",
         "WorkerMetadata",
-        # Typing
         "Factory",
     ]