Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions wool/protobuf/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ message Task {
string function = 11;
int32 line_no = 12;
string tag = 13;
optional string namespace = 14;
}

message Result {
Expand Down
2 changes: 1 addition & 1 deletion wool/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = [
]
dependencies = [
"cloudpickle",
"grpcio>=1.76.0",
"grpcio>=1.78.0",
"portalocker",
"protobuf",
"shortuuid",
Expand Down
2 changes: 2 additions & 0 deletions wool/src/wool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from wool.runtime.loadbalancer.base import NoWorkersAvailable
from wool.runtime.loadbalancer.roundrobin import RoundRobinLoadBalancer
from wool.runtime.resourcepool import ResourcePool
from wool.runtime.routine.task import WORKER
from wool.runtime.routine.task import Task
from wool.runtime.routine.task import TaskEvent
from wool.runtime.routine.task import TaskEventHandler
Expand Down Expand Up @@ -70,6 +71,7 @@
"NoWorkersAvailable",
"RoundRobinLoadBalancer",
# Routines
"WORKER",
"Task",
"TaskEvent",
"TaskEventHandler",
Expand Down
175 changes: 154 additions & 21 deletions wool/src/wool/runtime/routine/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import asyncio
import logging
import traceback
import types
from collections.abc import AsyncIterator
from collections.abc import Callable
from contextlib import asynccontextmanager
from contextlib import contextmanager
from contextvars import Context
from contextvars import ContextVar
Expand Down Expand Up @@ -31,6 +34,35 @@
import wool
from wool.runtime import protobuf as pb
from wool.runtime.event import Event
from wool.runtime.resourcepool import ResourcePool

# Sentinel for worker-level globals
WORKER: str = "__worker__"

# Default TTL for namespace cleanup (5 minutes)
NAMESPACE_TTL: float = 300.0


def _create_namespace_globals(namespace: str) -> _IsolatedGlobals:
"""Factory function to create _IsolatedGlobals for a namespace.

Creates an _IsolatedGlobals that falls back to an empty dict for reads.
The actual callable's __globals__ aren't available at creation time,
so reads of module-level names won't work. Named namespaces are primarily
useful for sharing mutable state between tasks, not for accessing
module-level imports (those should be accessed via the callable's closure
or passed as arguments).
"""
return _IsolatedGlobals({})


# Registry stores shared _IsolatedGlobals instances, keyed by namespace name.
# All tasks using the same namespace share the same _IsolatedGlobals instance,
# so globals set by one task are visible to others.
_namespace_registry: ResourcePool[_IsolatedGlobals] = ResourcePool(
factory=_create_namespace_globals,
ttl=NAMESPACE_TTL,
)

Args = Tuple
Kwargs = Dict
Expand All @@ -40,6 +72,50 @@
W = TypeVar("W", bound=Routine)


class _IsolatedGlobals(dict):
"""A dict subclass that provides overlay semantics for function globals.

Writes go to this dict directly, while reads fall through to the
original globals if not found. This enables namespace isolation for
task execution.

For named namespaces, the same _IsolatedGlobals instance is shared
across all tasks using that namespace, so writes are visible to all.

.. note::
This type must be a ``dict`` subclass, not a ``MutableMapping``.
Python's ``STORE_GLOBAL`` bytecode uses ``PyDict_SetItem`` at the
C level, which requires an actual ``dict`` instance. Using a
``MutableMapping`` (like ``ChainMap``) as a function's
``__globals__`` would bypass ``__setitem__`` and fail.

:param original_globals:
The original function's __globals__ dict to fall back to for reads.
"""

def __init__(self, original_globals: dict):
super().__init__()
self._original = original_globals

def __getitem__(self, key):
# First check local overlay
try:
return super().__getitem__(key)
except KeyError:
pass
# Then check original globals
return self._original[key]

def __contains__(self, key):
return super().__contains__(key) or key in self._original

def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default


_do_dispatch: ContextVar[bool] = ContextVar("_do_dispatch", default=True)


Expand Down Expand Up @@ -118,6 +194,17 @@ class Task(Generic[W]):
Line number where the task was defined.
:param tag:
Optional descriptive tag for the task.
:param namespace:
Controls global namespace isolation for task execution:

- ``None`` (default): Ephemeral isolated globals. Each invocation
gets a fresh isolated globals dict, preventing global state
leakage between task invocations.
- ``"name"``: Named namespace. Tasks with the same namespace string
share globals, enabling patterns like ``@lru_cache`` across
related tasks.
- ``wool.WORKER``: Worker-level globals. The task runs in the
shared worker namespace with no isolation.
"""

id: UUID
Expand All @@ -132,6 +219,7 @@ class Task(Generic[W]):
function: str | None = None
line_no: int | None = None
tag: str | None = None
namespace: str | None = None

def __post_init__(self, **kwargs):
"""
Expand Down Expand Up @@ -216,6 +304,7 @@ def from_protobuf(cls, task: pb.task.Task) -> Task:
function=task.function if task.function else None,
line_no=task.line_no if task.line_no else None,
tag=task.tag if task.tag else None,
namespace=task.namespace if task.namespace else None,
)

def to_protobuf(self) -> pb.task.Task:
Expand All @@ -232,6 +321,7 @@ def to_protobuf(self) -> pb.task.Task:
function=self.function if self.function else "",
line_no=self.line_no if self.line_no else 0,
tag=self.tag if self.tag else "",
namespace=self.namespace if self.namespace else "",
)

def dispatch(self) -> W:
Expand All @@ -242,14 +332,57 @@ def dispatch(self) -> W:
else:
raise ValueError("Expected routine to be coroutine or async generator")

@asynccontextmanager
async def _prepare_callable(self) -> AsyncIterator[Callable[..., W]]:
"""Prepare the callable with appropriate globals based on namespace.

Yields the callable to execute, managing namespace lifecycle for
shared namespaces.

:yields:
The callable configured with the appropriate globals dict.
"""
if self.namespace is None:
# Ephemeral: fresh isolated globals each invocation
callable_globals = _IsolatedGlobals(self.callable.__globals__)
yield types.FunctionType(
self.callable.__code__,
callable_globals,
self.callable.__name__,
self.callable.__defaults__,
self.callable.__closure__,
)
elif self.namespace == WORKER:
# Worker globals: use original directly
yield self.callable
else:
# Shared namespace: use shared globals from pool.
# All tasks with the same namespace share the same
# _IsolatedGlobals instance, so STORE_GLOBAL writes
# are visible across tasks.
async with _namespace_registry.get(self.namespace) as namespace:
# Merge this callable's globals into the shared namespace.
# Keys already in namespace take precedence, allowing
# imports from different modules to accumulate.
for key, value in self.callable.__globals__.items():
if key not in namespace:
namespace[key] = value
yield types.FunctionType(
self.callable.__code__,
namespace,
self.callable.__name__,
self.callable.__defaults__,
self.callable.__closure__,
)

async def _run(self):
"""
Execute the task's callable with its arguments in proxy context.

:returns:
The result of executing the callable.
:raises RuntimeError:
If no proxy pool is available for task execution.
If no proxy pool available for task execution.
"""
proxy_pool = wool.__proxy_pool__.get()
if not proxy_pool:
Expand All @@ -260,8 +393,9 @@ async def _run(self):
try:
with self:
with do_dispatch(False):
await asyncio.sleep(0)
return await self.callable(*self.args, **self.kwargs)
await asyncio.sleep(0) # Release the event loop
async with self._prepare_callable() as callable_to_run:
return await callable_to_run(*self.args, **self.kwargs)
finally:
wool.__proxy__.reset(token)

Expand All @@ -279,24 +413,23 @@ async def _stream(self):
raise RuntimeError("No proxy pool available for task execution")
async with proxy_pool.get(self.proxy) as proxy:
await asyncio.sleep(0)
gen = self.callable(*self.args, **self.kwargs)
try:
while True:
# Set the proxy in context variable for nested task dispatch
token = wool.__proxy__.set(proxy)
try:
with self:
with do_dispatch(False):
try:
result = await anext(gen)
except StopAsyncIteration:
break
finally:
wool.__proxy__.reset(token)

yield result
finally:
await gen.aclose()
async with self._prepare_callable() as callable_to_run:
gen = callable_to_run(*self.args, **self.kwargs)
try:
while True:
token = wool.__proxy__.set(proxy)
try:
with self:
with do_dispatch(False):
try:
result = await anext(gen)
except StopAsyncIteration:
break
finally:
wool.__proxy__.reset(token)
yield result
finally:
await gen.aclose()

def _finish(self, _):
TaskEvent("task-completed", task=self).emit()
Expand Down
Loading
Loading