4 changes: 3 additions & 1 deletion .github/workflows/run-tests.yaml
@@ -44,7 +44,9 @@ jobs:
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
-
+        with:
+          fetch-depth: 0
+          fetch-tags: true
       - name: Install uv and prepare python
         uses: astral-sh/setup-uv@v5
         with:
18 changes: 13 additions & 5 deletions wool/protobuf/task.proto
@@ -3,11 +3,12 @@ syntax = "proto3";
 package wool.runtime.protobuf.task;

 message Task {
-  string id = 1;
-  bytes callable = 2;
-  bytes args = 3;
-  bytes kwargs = 4;
-  string caller = 5;
+  string version = 1;
+  string id = 2;
+  bytes callable = 3;
+  bytes args = 4;
+  bytes kwargs = 5;
+  string caller = 6;
   bytes proxy = 7;
   string proxy_id = 8;
   int32 timeout = 9;
@@ -17,6 +18,13 @@ message Task {
   string tag = 13;
 }

+// Minimal envelope for pre-deserialization version extraction.
+// Used by the version interceptor to parse field 1 from any
+// Task wire format, including future incompatible versions.
+message TaskVersionEnvelope {
+  string version = 1;
+}
+
 message Result {
   bytes dump = 1;
 }
1 change: 1 addition & 0 deletions wool/protobuf/worker.proto
@@ -20,6 +20,7 @@ message Response {

 message Ack {
   // Acknowledgment that the task was received and processing started
+  string version = 1;
 }

 message Nack {
84 changes: 84 additions & 0 deletions wool/src/wool/runtime/protobuf/README.md
@@ -0,0 +1,84 @@
# Wire protocol

Wool uses a binary wire protocol built on Protocol Buffers and gRPC
for all communication between clients and workers.

## Dispatch sequence

The `Worker.dispatch` RPC uses a server-streaming pattern. The client
sends a single `Task` message and receives a stream of `Response`
messages:

```
Client                     Worker
|                               |
|── Task ──────────────────────>|
|                               |
|<──────── Response(Ack) ───────|  (or Nack on rejection)
|<──────── Response(Result) ────|  (one or more results)
|<──────── Response(Exception) ─|  (on failure)
|                               |
```

Review comment from the PR author: This could be a simple Mermaid sequence diagram.

### Response types

1. **Ack** — The worker accepted the task and started processing.
Carries the worker's `version` string for observability.
2. **Nack** — The worker rejected the task. The `reason` field
describes why (e.g., major version mismatch, unparseable version).
No further responses follow a Nack.
3. **Result** — A cloudpickle-serialized return value. Coroutine
tasks yield exactly one result; async generator tasks yield one
per iteration.
4. **Exception** — A cloudpickle-serialized exception from the
remote execution. Terminates the stream.
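
The ordering rules above can be sketched from the client's point of view. This is a minimal illustration, not Wool's actual client code: the `Ack`, `Nack`, `Result`, and `Failure` dataclasses are hypothetical stand-ins for the protobuf `Response` variants, and `consume` and `TaskRejected` are names invented for the example.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, Union


# Hypothetical stand-ins for the protobuf Response variants.
@dataclass
class Ack:
    version: str


@dataclass
class Nack:
    reason: str


@dataclass
class Result:
    value: object


@dataclass
class Failure:
    error: Exception


Response = Union[Ack, Nack, Result, Failure]


class TaskRejected(RuntimeError):
    """Raised when the first response is a Nack."""


def consume(stream: Iterable[Response]) -> Iterator[object]:
    """Yield result values from a dispatch stream, enforcing the ordering."""
    responses = iter(stream)
    first = next(responses, None)
    if isinstance(first, Nack):
        # A Nack ends the exchange; nothing else follows it.
        raise TaskRejected(first.reason)
    if not isinstance(first, Ack):
        raise RuntimeError(f"expected Ack, got {first!r}")
    for response in responses:
        if isinstance(response, Result):
            yield response.value
        elif isinstance(response, Failure):
            # An exception response terminates the stream.
            raise response.error
        else:
            raise RuntimeError(f"unexpected response after Ack: {response!r}")


values = list(consume([Ack("1.0.0"), Result(1), Result(2)]))
print(values)
```

A coroutine task would produce a single `Result` after the `Ack`, while an async generator task would produce one per iteration, as in the usage line above.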

## Serialization

Wool uses a hybrid serialization approach:

- **Protobuf envelope** — Structured metadata fields (`id`,
`version`, `caller`, `timeout`, etc.) are native protobuf fields
for efficient parsing and forward compatibility.
- **cloudpickle payloads** — The `callable`, `args`, `kwargs`, and
`proxy` fields are serialized with cloudpickle and stored as
`bytes` fields. This allows arbitrary Python objects to be
transmitted without schema changes.
- **Results and exceptions** — `Result.dump` and `Exception.dump`
are cloudpickle-serialized bytes.
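
The split between structured metadata and opaque payloads can be sketched as follows. This is only an illustration under stated assumptions: `TaskEnvelope`, `pack`, and `run` are hypothetical names, a dataclass stands in for the generated protobuf message, and the standard-library `pickle` stands in for cloudpickle (which additionally handles objects such as lambdas and locally defined functions).

```python
import pickle
from dataclasses import dataclass


@dataclass
class TaskEnvelope:
    """Hypothetical stand-in for the protobuf Task message."""

    # Structured metadata: readable without unpickling anything.
    version: str
    id: str
    timeout: int
    # Opaque payloads: pickled Python objects stored as bytes.
    callable: bytes
    args: bytes
    kwargs: bytes


def pack(version, task_id, func, args, kwargs, timeout=0):
    """Serialize the payloads and wrap them in the metadata envelope."""
    return TaskEnvelope(
        version=version,
        id=task_id,
        timeout=timeout,
        callable=pickle.dumps(func),
        args=pickle.dumps(args),
        kwargs=pickle.dumps(kwargs),
    )


def run(envelope):
    """Deserialize the payloads and invoke the callable."""
    func = pickle.loads(envelope.callable)
    args = pickle.loads(envelope.args)
    kwargs = pickle.loads(envelope.kwargs)
    return func(*args, **kwargs)


envelope = pack("1.2.3", "task-1", divmod, (7, 2), {})
print(envelope.version, envelope.id)  # metadata is available before unpickling
print(run(envelope))
```

The point of the design is that a receiver can inspect `version` or `timeout` and decide what to do before it touches any pickled payload.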

## Version compatibility

Wool enforces major-version compatibility at two layers.

### Discovery-time filtering

`WorkerProxy` applies a version filter during worker discovery.
Workers whose major version differs from the client's are excluded
from the load balancer and never receive tasks.
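
A minimal sketch of this filtering, assuming worker metadata exposes a `version` string: `WorkerInfo`, `major`, and `compatible_workers` are hypothetical names for the example, and `major` mirrors the `parse_major_version` helper added to `proxy.py` in this change.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class WorkerInfo:
    """Hypothetical stand-in for discovered worker metadata."""

    address: str
    version: str


def major(version: str) -> Optional[int]:
    """Return the major version as an int, or None if it cannot be parsed."""
    try:
        return int(version.split(".")[0])
    except ValueError:
        return None


def compatible_workers(local_version: str, workers: List[WorkerInfo]) -> List[WorkerInfo]:
    """Keep only workers whose major version matches the local one."""
    local_major = major(local_version)
    if local_major is None:
        # An unparseable local version is compatible with nothing.
        return []
    return [w for w in workers if major(w.version) == local_major]


pool = [
    WorkerInfo("10.0.0.1:50051", "1.4.2"),
    WorkerInfo("10.0.0.2:50051", "2.0.0"),
    WorkerInfo("10.0.0.3:50051", "dev"),
]
print([w.address for w in compatible_workers("1.0.0", pool)])
```

Workers with unparseable versions are dropped rather than given the benefit of the doubt, matching the behavior described for the discovery filter.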

### Dispatch-time interception

`VersionInterceptor` is a gRPC server interceptor that extracts the
version field from raw request bytes *before* full deserialization.
This uses `TaskVersionEnvelope` — a minimal protobuf message
containing only `string version = 1` — which can parse field 1 from
any `Task` wire format, including future incompatible versions.

Requests with empty, missing, or unparseable version fields are
rejected with a `Nack` response. If the client's major version
differs from the worker's, the interceptor yields a `Nack` without
attempting full deserialization. This prevents deserialization errors
when the wire format has changed across major versions.

## Schema evolution rules

- **Additive-only within a major version.** New fields may be
appended with new field numbers. Existing field numbers and types
must not change within the same major version.
- **Major version = wire compatibility boundary.** A major version
bump permits breaking changes to the protobuf schema (field
renumbering, type changes, removal).
- **Field 1 is always `version`.** The `Task` message reserves
field 1 for the version string. This invariant enables
pre-deserialization version extraction via `TaskVersionEnvelope`.
3 changes: 2 additions & 1 deletion wool/src/wool/runtime/protobuf/task.py
@@ -2,10 +2,11 @@
     from wool.runtime.protobuf.task_pb2 import Exception
     from wool.runtime.protobuf.task_pb2 import Result
     from wool.runtime.protobuf.task_pb2 import Task
+    from wool.runtime.protobuf.task_pb2 import TaskVersionEnvelope
     from wool.runtime.protobuf.task_pb2 import Worker as Worker
 except ImportError as e:
     from wool.runtime.protobuf.exception import ProtobufImportError

     raise ProtobufImportError(e) from e

-__all__ = ["Exception", "Result", "Task", "Worker"]
+__all__ = ["Exception", "Result", "Task", "TaskVersionEnvelope", "Worker"]
1 change: 1 addition & 0 deletions wool/src/wool/runtime/routine/task.py
@@ -222,6 +222,7 @@ def from_protobuf(cls, task: pb.task.Task) -> Task:

     def to_protobuf(self) -> pb.task.Task:
         return pb.task.Task(
+            version=wool.__version__,
             id=str(self.id),
             callable=cloudpickle.dumps(self.callable),
             args=cloudpickle.dumps(self.args),
4 changes: 4 additions & 0 deletions wool/src/wool/runtime/worker/connection.py
@@ -361,6 +361,10 @@ async def _dispatch(self, ch, task, timeout):
         call: _DispatchCall = ch.stub.dispatch(task.to_protobuf())
         try:
             response = await anext(aiter(call))
+            if response.HasField("nack"):
+                raise RpcError(
+                    details=f"Task rejected by worker: {response.nack.reason}"
+                )
             if not response.HasField("ack"):
                 raise UnexpectedResponse(
                     f"Expected 'ack' response, "
80 changes: 80 additions & 0 deletions wool/src/wool/runtime/worker/interceptor.py
@@ -0,0 +1,80 @@
from __future__ import annotations

import grpc
import grpc.aio

import wool
from wool.runtime import protobuf as pb
from wool.runtime.worker.proxy import parse_major_version


class VersionInterceptor(grpc.aio.ServerInterceptor):
    """gRPC server interceptor for wire protocol version checking.

    Intercepts the ``dispatch`` RPC to extract the client version from
    field 1 of the raw request bytes using
    :class:`~wool.runtime.protobuf.task.TaskVersionEnvelope`. If the
    client major version differs from the local worker major version,
    the RPC is short-circuited with a
    :class:`~wool.runtime.protobuf.worker.Nack` response.

    Requests with empty, missing, or unparseable version fields are
    rejected.
    """

    async def intercept_service(self, continuation, handler_call_details):
        handler = await continuation(handler_call_details)
        if handler is None or not handler_call_details.method.endswith("/dispatch"):
            return handler

        original_handler = handler.unary_stream
        original_deserializer = handler.request_deserializer
        assert original_handler is not None
        assert original_deserializer is not None

        async def version_checked_handler(request_bytes, context):
            envelope = pb.task.TaskVersionEnvelope()
            try:
                envelope.ParseFromString(request_bytes)
            except Exception:
                yield pb.worker.Response(
                    nack=pb.worker.Nack(reason="Failed to parse version envelope")
                )
                return

            client_major = parse_major_version(envelope.version)
            local_major = parse_major_version(wool.__version__)

            if client_major is None or local_major is None:
                yield pb.worker.Response(
                    nack=pb.worker.Nack(
                        reason=(
                            f"Unparseable version: "
                            f"client={envelope.version!r}, "
                            f"worker={wool.__version__!r}"
                        )
                    )
                )
                return

            if client_major != local_major:
                yield pb.worker.Response(
                    nack=pb.worker.Nack(
                        reason=(
                            f"Major version mismatch: "
                            f"client={envelope.version}, "
                            f"worker={wool.__version__}"
                        )
                    )
                )
                return

            request = original_deserializer(request_bytes)
            async for response in original_handler(request, context):  # pyright: ignore[reportGeneralTypeIssues]
                yield response

        return grpc.unary_stream_rpc_method_handler(
            version_checked_handler,
            request_deserializer=None,
            response_serializer=handler.response_serializer,
        )
3 changes: 2 additions & 1 deletion wool/src/wool/runtime/worker/process.py
@@ -18,6 +18,7 @@
 from wool.runtime.resourcepool import ResourcePool
 from wool.runtime.worker.base import ServerCredentialsType
 from wool.runtime.worker.base import resolve_server_credentials
+from wool.runtime.worker.interceptor import VersionInterceptor
 from wool.runtime.worker.service import WorkerService

 if TYPE_CHECKING:
@@ -172,7 +173,7 @@ async def _serve(self):
         requests. It creates a gRPC server, adds the worker service, and
         starts listening for incoming connections.
         """
-        server = grpc.aio.server()
+        server = grpc.aio.server(interceptors=[VersionInterceptor()])
         credentials = resolve_server_credentials(self._credentials)
         address = self._address(self._host, self._port)

47 changes: 43 additions & 4 deletions wool/src/wool/runtime/worker/proxy.py
@@ -33,6 +33,21 @@
 T = TypeVar("T")


+def parse_major_version(version: str) -> int | None:
+    """Extract the major version number from a version string.
+
+    :param version:
+        A version string (e.g. ``"1.2.3"``).
+    :returns:
+        The major version as an integer, or ``None`` if
+        unparseable.
+    """
+    try:
+        return int(version.split(".")[0])
+    except (ValueError, IndexError):
+        return None
+
+
 class ReducibleAsyncIterator(Generic[T]):
     """An async iterator that can be pickled via __reduce__.

@@ -208,21 +223,24 @@ def __init__(

         # Create security filter based on resolved credentials
         security_filter = self._create_security_filter(self._credentials)
+        version_filter = self._create_version_filter()
+
+        def compatible(w):
+            return security_filter(w) and version_filter(w)

         match (pool_uri, discovery, workers):
             case (pool_uri, None, None) if pool_uri is not None:
-                # Combine tag filter with security filter
+                # Combine tag and compatibility filters
                 def combined_filter(w):
-                    return bool({pool_uri, *tags} & w.tags) and security_filter(w)
+                    return bool({pool_uri, *tags} & w.tags) and compatible(w)

                 self._discovery = LocalDiscovery(pool_uri).subscribe(
                     filter=combined_filter
                 )
             case (None, discovery, None) if discovery is not None:
                 self._discovery = discovery
             case (None, None, workers) if workers is not None:
-                # Filter workers by security compatibility
-                compatible_workers = [w for w in workers if security_filter(w)]
+                compatible_workers = [w for w in workers if compatible(w)]
                 self._discovery = ReducibleAsyncIterator(
                     [
                         DiscoveryEvent("worker-added", metadata=w)
@@ -416,6 +434,27 @@ def _create_security_filter(
             # Proxy has no credentials: only accept insecure workers
             return lambda metadata: not metadata.secure

+    @staticmethod
+    def _create_version_filter() -> Callable[[WorkerMetadata], bool]:
+        """Create discovery filter based on major version compatibility.
+
+        Workers must share the same major version as the local proxy.
+        Workers with unparseable versions are rejected.
+
+        :returns:
+            Predicate function for filtering workers by version
+            compatibility.
+        """
+        local_major = parse_major_version(wool.__version__)
+
+        def version_filter(metadata: WorkerMetadata) -> bool:
+            worker_major = parse_major_version(metadata.version)
+            if local_major is None or worker_major is None:
+                return False
+            return local_major == worker_major
+
+        return version_filter
+
     async def _await_workers(self):
         while not self._loadbalancer_context or not self._loadbalancer_context.workers:
             await asyncio.sleep(0)
2 changes: 1 addition & 1 deletion wool/src/wool/runtime/worker/service.py
@@ -143,7 +143,7 @@ async def dispatch(
             )

         with self._tracker(Task.from_protobuf(request)) as task:
-            yield pb.worker.Response(ack=pb.worker.Ack())
+            yield pb.worker.Response(ack=pb.worker.Ack(version=wool.__version__))
             try:
                 if isasyncgen(task):
                     async for result in task: