Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,59 @@ on:
- release

jobs:
lint:
name: Lint / pyright
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Install uv and prepare python
uses: astral-sh/setup-uv@v5
with:
python-version: '3.13'

- name: Install packages
run: |
.github/scripts/install-python-packages.sh
uv pip install -e './wool[dev]'

- name: Run pyright
run: |
cd wool
uv run pyright

run-tests:
name: Namespace ${{ matrix.namespace }} / Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
strategy:
matrix:
namespace:
- 'wool'
python-version:
python-version:
- '3.11'
- '3.12'
- '3.13'
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Install uv and prepare python
uses: astral-sh/setup-uv@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install packages
run: .github/scripts/install-python-packages.sh
- name: Run tests
env:
NAMESPACE: ${{ matrix.namespace }}
run: |
.github/scripts/install-python-packages.sh
uv pip install -e './${{ env.NAMESPACE }}[dev]'
uv pip freeze

- name: Run tests
env:
NAMESPACE: ${{ matrix.namespace }}
run: |
cd ${{ env.NAMESPACE }}
uv run pytest
15 changes: 15 additions & 0 deletions wool/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dev = [
"cryptography",
"debugpy",
"hypothesis",
"pyright",
"pytest",
"pytest-asyncio",
"pytest-cov",
Expand Down Expand Up @@ -96,3 +97,17 @@ docstring-code-format = true
combine-as-imports = false
force-single-line = true
known-first-party = ["wool"]

[tool.pyright]
venvPath = "."
venv = ".venv"
include = ["src"]
exclude = [
"**/__pycache__",
"build",
"dist",
"tests/*",
"src/wool/runtime/protobuf/*pb2*",
]
reportMissingImports = true
reportMissingTypeStubs = false
4 changes: 2 additions & 2 deletions wool/src/wool/runtime/loadbalancer/roundrobin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from asyncio import Lock
from typing import TYPE_CHECKING
from typing import AsyncIterator
from typing import AsyncGenerator
from typing import Final

from wool.runtime.worker.connection import TransientRpcError
Expand Down Expand Up @@ -43,7 +43,7 @@ async def dispatch(
*,
context: LoadBalancerContextLike,
timeout: float | None = None,
) -> AsyncIterator:
) -> AsyncGenerator:
"""Dispatch a task to the next available worker.

:param task:
Expand Down
8 changes: 6 additions & 2 deletions wool/src/wool/runtime/routine/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from inspect import isasyncgenfunction
from inspect import iscoroutinefunction
from types import TracebackType
from typing import Any
from typing import AsyncGenerator
from typing import ContextManager
from typing import Coroutine
Expand All @@ -23,6 +24,7 @@
from typing import Tuple
from typing import TypeAlias
from typing import TypeVar
from typing import cast
from typing import overload
from uuid import UUID

Expand Down Expand Up @@ -236,9 +238,9 @@ def to_protobuf(self) -> pb.task.Task:

def dispatch(self) -> W:
if isasyncgenfunction(self.callable):
return self._stream()
return cast(W, self._stream())
elif iscoroutinefunction(self.callable):
return self._run()
return cast(W, self._run())
else:
raise ValueError("Expected routine to be coroutine or async generator")

Expand All @@ -251,6 +253,7 @@ async def _run(self):
:raises RuntimeError:
If no proxy pool is available for task execution.
"""
assert iscoroutinefunction(self.callable), "Expected coroutine function"
proxy_pool = wool.__proxy_pool__.get()
if not proxy_pool:
raise RuntimeError("No proxy pool available for task execution")
Expand All @@ -274,6 +277,7 @@ async def _stream(self):
:raises RuntimeError:
If no proxy pool is available for task execution.
"""
assert isasyncgenfunction(self.callable), "Expected async generator function"
proxy_pool = wool.__proxy_pool__.get()
if not proxy_pool:
raise RuntimeError("No proxy pool available for task execution")
Expand Down
10 changes: 7 additions & 3 deletions wool/src/wool/runtime/worker/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator
from typing import AsyncIterator
from typing import Final
from typing import Generic
from typing import TypeAlias
from typing import TypeVar
from typing import cast

import cloudpickle
import grpc.aio
Expand Down Expand Up @@ -288,7 +290,7 @@ async def dispatch(
task: Task,
*,
timeout: float | None = None,
) -> AsyncIterator[pb.task.Result]:
) -> AsyncGenerator[pb.task.Result, None]:
"""Dispatch a task to the remote worker for execution.

Sends the task to the worker via gRPC, waits for acknowledgment,
Expand Down Expand Up @@ -338,7 +340,7 @@ async def dispatch(
raise

await _channel_pool.release(self._key)
return gen
return cast(AsyncGenerator[pb.task.Result, None], gen)

async def close(self):
"""Close the connection and release all pooled resources.
Expand Down Expand Up @@ -375,7 +377,9 @@ async def _dispatch(self, ch, task, timeout):
raise
return call

async def _execute(self, call):
async def _execute(
self, call: _DispatchCall
) -> AsyncGenerator[pb.task.Result | None, None]:
ch = await _channel_pool.acquire(self._key)
try:
yield # Priming yield — signals dispatch() that ref is held
Expand Down
28 changes: 23 additions & 5 deletions wool/src/wool/runtime/worker/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ def __init__(
*tags: str,
size: int = 0,
worker: WorkerFactory = LocalWorker,
discovery: DiscoveryLike | Factory[DiscoveryLike] | None = None,
discovery: None = None,
loadbalancer: (
LoadBalancerLike | Factory[LoadBalancerLike]
) = RoundRobinLoadBalancer,
credentials: WorkerCredentials | None | UndefinedType = Undefined,
):
"""
Create an ephemeral pool of workers, spawning the specified quantity of workers
using the specified worker factory.
Create an ephemeral pool of workers, spawning the specified
quantity of workers using the specified worker factory.
"""
...

Expand All @@ -183,8 +183,26 @@ def __init__(
credentials: WorkerCredentials | None | UndefinedType = Undefined,
):
"""
Connect to an existing pool of workers discovered by the specified discovery
protocol.
Connect to an existing pool of workers discovered by the
specified discovery protocol.
"""
...

@overload
def __init__(
self,
*tags: str,
size: int = 0,
worker: WorkerFactory = LocalWorker,
discovery: DiscoveryLike | Factory[DiscoveryLike],
loadbalancer: (
LoadBalancerLike | Factory[LoadBalancerLike]
) = RoundRobinLoadBalancer,
credentials: WorkerCredentials | None | UndefinedType = Undefined,
):
"""
Create a hybrid pool that spawns local workers and discovers
remote workers through the specified discovery protocol.
"""
...

Expand Down
Loading