From 63f8a85a636e7dcc3ae8b3496e3af756dcdbbb48 Mon Sep 17 00:00:00 2001
From: nadzhou
Date: Fri, 5 Dec 2025 18:01:16 -0800
Subject: [PATCH 1/7] Update __init__.py, condition.py and scheduler.py

---
 distributed/__init__.py  |   1 +
 distributed/condition.py | 202 ++++++++++
 distributed/scheduler.py | 826 ++++++++++-----------------------------
 3 files changed, 405 insertions(+), 624 deletions(-)
 create mode 100644 distributed/condition.py

diff --git a/distributed/__init__.py b/distributed/__init__.py
index 3f075b977c..091c14e0eb 100644
--- a/distributed/__init__.py
+++ b/distributed/__init__.py
@@ -146,3 +146,4 @@
     "widgets",
     "worker_client",
 ]
+from distributed.condition import Condition
diff --git a/distributed/condition.py b/distributed/condition.py
new file mode 100644
index 0000000000..ad31630815
--- /dev/null
+++ b/distributed/condition.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import uuid
+from collections import defaultdict
+from contextlib import suppress
+
+from distributed.utils import SyncMethodMixin, log_errors
+from distributed.worker import get_client
+
+logger = logging.getLogger(__name__)
+
+
+class ConditionExtension:
+    """Scheduler extension for managing distributed Conditions"""
+
+    def __init__(self, scheduler):
+        self.scheduler = scheduler
+        # {condition_name: asyncio.Condition}
+        self._conditions = {}
+        # {condition_name: set of waiter_ids}
+        self._waiters = defaultdict(set)
+
+        self.scheduler.handlers.update(
+            {
+                "condition_wait": self.wait,
+                "condition_notify": self.notify,
+                "condition_notify_all": self.notify_all,
+                "condition_acquire": self.acquire,
+                "condition_release": self.release,
+            }
+        )
+
+    def _get_condition(self, name):
+        if name not in self._conditions:
+            self._conditions[name] = asyncio.Condition()
+        return self._conditions[name]
+
+    @log_errors
+    async def acquire(self, name=None, id=None):
+        """Acquire the underlying lock"""
+        condition = self._get_condition(name)
+        await condition.acquire()
+        return True
+
+    @log_errors
+    async def release(self, name=None, id=None):
+        """Release the underlying lock"""
+        if name not in self._conditions:
+            return False
+        condition = self._conditions[name]
+        condition.release()
+        return True
+
+    @log_errors
+    async def wait(self, name=None, id=None, timeout=None):
+        """Wait on condition"""
+        condition = self._get_condition(name)
+        self._waiters[name].add(id)
+
+        try:
+            if timeout is not None:
+                await asyncio.wait_for(condition.wait(), timeout=timeout)
+            else:
+                await condition.wait()
+            return True
+        except asyncio.TimeoutError:
+            return False
+        finally:
+            self._waiters[name].discard(id)
+            # Cleanup if no waiters
+            if not self._waiters[name]:
+                with suppress(KeyError):
+                    del self._waiters[name]
+                with suppress(KeyError):
+                    del self._conditions[name]
+
+    @log_errors
+    def notify(self, name=None, n=1):
+        """Notify n waiters"""
+        if name not in self._conditions:
+            return 0
+        condition = self._conditions[name]
+        condition.notify(n=n)
+        return min(n, len(self._waiters.get(name, [])))
+
+    @log_errors
+    def notify_all(self, name=None):
+        """Notify all waiters"""
+        if name not in self._conditions:
+            return 0
+        condition = self._conditions[name]
+        count = len(self._waiters.get(name, []))
+        condition.notify_all()
+        return count
+
+
+class Condition(SyncMethodMixin):
+    """Distributed Condition Variable
+
+    Mimics the asyncio.Condition API. Allows coordination between
+    distributed workers using a wait/notify pattern.
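+
+    The underlying lock and waiter bookkeeping live on the scheduler
+    (see ``ConditionExtension``); ``acquire``, ``release``, ``wait`` and
+    ``notify`` are forwarded to it over RPC, so Condition objects created
+    with the same name anywhere on the cluster share one condition variable.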
+
+    Examples
+    --------
+    >>> from distributed import Condition
+    >>> condition = Condition('my-condition')
+    >>> async with condition:
+    ...     await condition.wait()  # Wait for notification
+
+    >>> # In another worker/client
+    >>> condition = Condition('my-condition')
+    >>> async with condition:
+    ...     await condition.notify()  # Wake one waiter
+    """
+
+    def __init__(self, name=None, scheduler_rpc=None, loop=None):
+        self._scheduler = scheduler_rpc
+        self._loop = loop
+        self.name = name or f"condition-{uuid.uuid4().hex}"
+        self.id = uuid.uuid4().hex
+        self._locked = False
+
+    def _get_scheduler_rpc(self):
+        if self._scheduler:
+            return self._scheduler
+        try:
+            client = get_client()
+            return client.scheduler
+        except ValueError:
+            from distributed.worker import get_worker
+
+            worker = get_worker()
+            return worker.scheduler
+
+    async def acquire(self):
+        """Acquire the underlying lock"""
+        scheduler = self._get_scheduler_rpc()
+        result = await scheduler.condition_acquire(name=self.name, id=self.id)
+        self._locked = result
+        return result
+
+    async def release(self):
+        """Release the underlying lock"""
+        if not self._locked:
+            raise RuntimeError("Cannot release un-acquired lock")
+        scheduler = self._get_scheduler_rpc()
+        await scheduler.condition_release(name=self.name, id=self.id)
+        self._locked = False
+
+    async def wait(self, timeout=None):
+        """Wait until notified
+
+        Must be called while the lock is held. Releases the lock and waits
+        for notify(), then reacquires the lock before returning.
+        """
+        if not self._locked:
+            raise RuntimeError("Cannot wait on un-acquired condition")
+
+        scheduler = self._get_scheduler_rpc()
+        result = await scheduler.condition_wait(name=self.name, id=self.id, timeout=timeout)
+        return result
+
+    async def notify(self, n=1):
+        """Wake up one or more waiters"""
+        if not self._locked:
+            raise RuntimeError("Cannot notify on un-acquired condition")
+        scheduler = self._get_scheduler_rpc()
+        return await scheduler.condition_notify(name=self.name, n=n)
+
+    async def notify_all(self):
+        """Wake up all waiters"""
+        if not self._locked:
+            raise RuntimeError("Cannot notify on un-acquired condition")
+        scheduler = self._get_scheduler_rpc()
+        return await scheduler.condition_notify_all(name=self.name)
+
+    def locked(self):
+        """Return True if the lock is held"""
+        return self._locked
+
+    async def __aenter__(self):
+        await self.acquire()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.release()
+
+    def __enter__(self):
+        return self.sync(self.__aenter__)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.sync(self.__aexit__, exc_type, exc_val, exc_tb)
+
+    def __repr__(self):
+        return f"<Condition: {self.name} (locked: {self._locked})>"
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index ea5775aea6..e9b4bec324 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -144,6 +144,7 @@
     scatter_to_workers,
 )
 from distributed.variable import VariableExtension
+from distributed.condition import ConditionExtension
 
 if TYPE_CHECKING:
     from typing import TypeAlias, TypeVar
@@ -181,9 +182,7 @@
 
 logger = logging.getLogger(__name__)
 LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
-DEFAULT_DATA_SIZE = parse_bytes(
-    dask.config.get("distributed.scheduler.default-data-size")
-)
+DEFAULT_DATA_SIZE = parse_bytes(dask.config.get("distributed.scheduler.default-data-size"))
 STIMULUS_ID_UNSET = "<stimulus_id unset>"
 
 DEFAULT_EXTENSIONS = {
@@ -194,6 +193,7 @@
     "variables": VariableExtension,
     "semaphores": SemaphoreExtension,
     "events": EventExtension,
+    "conditions": ConditionExtension,
     "amm":
ActiveMemoryManagerExtension, "memory_sampler": MemorySamplerExtension, "shuffle": ShuffleSchedulerPlugin, @@ -407,8 +407,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: return { k: getattr(self, k) for k in dir(self) - if not k.startswith("_") - and k not in {"sum", "managed_in_memory", "managed_spilled"} + if not k.startswith("_") and k not in {"sum", "managed_in_memory", "managed_spilled"} } @@ -580,9 +579,7 @@ def __hash__(self) -> int: return self._hash def __eq__(self, other: object) -> bool: - return self is other or ( - isinstance(other, WorkerState) and other.server_id == self.server_id - ) + return self is other or (isinstance(other, WorkerState) and other.server_id == self.server_id) @property def has_what(self) -> Set[TaskState]: @@ -833,9 +830,7 @@ def _dec_needs_replica(self, ts: TaskState) -> None: nbytes = ts.get_nbytes() # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift self._network_occ -= min(nbytes, self._network_occ) - self.scheduler._network_occ_global -= min( - nbytes, self.scheduler._network_occ_global - ) + self.scheduler._network_occ_global -= min(nbytes, self.scheduler._network_occ_global) def add_replica(self, ts: TaskState) -> None: """The worker acquired a replica of task""" @@ -848,18 +843,14 @@ def add_replica(self, ts: TaskState) -> None: del self.needs_what[ts] # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift self._network_occ -= min(nbytes, self._network_occ) - self.scheduler._network_occ_global -= min( - nbytes, self.scheduler._network_occ_global - ) + self.scheduler._network_occ_global -= min(nbytes, self.scheduler._network_occ_global) ts.who_has.add(self) self.nbytes += nbytes self._has_what[ts] = None @property def occupancy(self) -> float: - return self._occupancy_cache or self.scheduler._calc_occupancy( - self.task_prefix_count, self._network_occ - ) + return self._occupancy_cache or self.scheduler._calc_occupancy(self.task_prefix_count, self._network_occ) @dataclasses.dataclass @@ -921,9 +912,7 @@ def __repr__(self) -> str: return ( f"" ) @@ -981,10 +970,7 @@ def all_durations(self) -> defaultdict[str, float]: """Cumulative duration of all completed actions of tasks belonging to this collection, by action""" return defaultdict( float, - { - action: duration_us / 1e6 - for action, duration_us in self._all_durations_us.items() - }, + {action: duration_us / 1e6 for action, duration_us in self._all_durations_us.items()}, ) @property @@ -1089,13 +1075,7 @@ def active_states(self) -> dict[TaskStateState, int]: def __repr__(self) -> str: return ( - "<" - + self.name - + ": " - + ", ".join( - "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v - ) - + ">" + "<" + self.name + ": " + ", ".join("%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v) + ">" ) @@ -1185,9 +1165,7 @@ def __repr__(self) -> str: "<" + (self.name or "no-group") + ": " - + ", ".join( - "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v - ) + + ", ".join("%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v) + ">" ) @@ -1215,8 +1193,7 @@ def done(self) -> bool: recomputed. 
""" return all( - count == 0 or state in {"memory", "erred", "released", "forgotten"} - for state, count in self.states.items() + count == 0 or state in {"memory", "erred", "released", "forgotten"} for state, count in self.states.items() ) @@ -1775,15 +1752,9 @@ def __init__( self.resources = resources self.saturated = set() self.tasks = tasks - self.replicated_tasks = { - ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1 - } - self.computations = deque( - maxlen=dask.config.get("distributed.diagnostics.computations.max-history") - ) - self.erred_tasks = deque( - maxlen=dask.config.get("distributed.diagnostics.erred-tasks.max-history") - ) + self.replicated_tasks = {ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1} + self.computations = deque(maxlen=dask.config.get("distributed.diagnostics.computations.max-history")) + self.erred_tasks = deque(maxlen=dask.config.get("distributed.diagnostics.erred-tasks.max-history")) self.task_groups = {} self.task_prefixes = {} self.task_metadata = {} @@ -1796,61 +1767,38 @@ def __init__( self.workers = workers self._task_prefix_count_global = defaultdict(int) self._network_occ_global = 0 - self.running = { - ws for ws in self.workers.values() if ws.status == Status.running - } + self.running = {ws for ws in self.workers.values() if ws.status == Status.running} self.plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} - self.transition_log = deque( - maxlen=dask.config.get("distributed.admin.low-level-log-length") - ) + self.transition_log = deque(maxlen=dask.config.get("distributed.admin.low-level-log-length")) self.transition_counter = 0 self._idle_transition_counter = 0 self.transition_counter_max = transition_counter_max # Variables from dask.config, cached by __init__ for performance - self.UNKNOWN_TASK_DURATION = parse_timedelta( - dask.config.get("distributed.scheduler.unknown-task-duration") - ) + self.UNKNOWN_TASK_DURATION = parse_timedelta(dask.config.get("distributed.scheduler.unknown-task-duration")) self.MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( dask.config.get("distributed.worker.memory.recent-to-old-time") ) - self.MEMORY_REBALANCE_MEASURE = dask.config.get( - "distributed.worker.memory.rebalance.measure" - ) - self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get( - "distributed.worker.memory.rebalance.sender-min" - ) - self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( - "distributed.worker.memory.rebalance.recipient-max" - ) + self.MEMORY_REBALANCE_MEASURE = dask.config.get("distributed.worker.memory.rebalance.measure") + self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get("distributed.worker.memory.rebalance.sender-min") + self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get("distributed.worker.memory.rebalance.recipient-max") self.MEMORY_REBALANCE_HALF_GAP = ( - dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") - / 2.0 + dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") / 2.0 ) - self.WORKER_SATURATION = dask.config.get( - "distributed.scheduler.worker-saturation" - ) + self.WORKER_SATURATION = dask.config.get("distributed.scheduler.worker-saturation") if self.WORKER_SATURATION == "inf": # Special case necessary because there's no way to parse a float infinity # from a DASK_* environment variable self.WORKER_SATURATION = math.inf - if ( - not isinstance(self.WORKER_SATURATION, (int, float)) - or self.WORKER_SATURATION <= 0 - ): + if not isinstance(self.WORKER_SATURATION, (int, float)) or self.WORKER_SATURATION <= 0: raise ValueError( 
# pragma: nocover - "`distributed.scheduler.worker-saturation` must be a float > 0; got " - + repr(self.WORKER_SATURATION) + "`distributed.scheduler.worker-saturation` must be a float > 0; got " + repr(self.WORKER_SATURATION) ) - self.rootish_tg_threshold = dask.config.get( - "distributed.scheduler.rootish-taskgroup" - ) - self.rootish_tg_dependencies_threshold = dask.config.get( - "distributed.scheduler.rootish-taskgroup-dependencies" - ) + self.rootish_tg_threshold = dask.config.get("distributed.scheduler.rootish-taskgroup") + self.rootish_tg_dependencies_threshold = dask.config.get("distributed.scheduler.rootish-taskgroup-dependencies") @abstractmethod def log_event(self, topic: str | Collection[str], msg: Any) -> None: ... @@ -1984,9 +1932,7 @@ def _calc_occupancy( # State Transitions # ##################### - def _transition( - self, key: Key, finish: TaskStateState, stimulus_id: str, **kwargs: Any - ) -> RecsMsgs: + def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwargs: Any) -> RecsMsgs: """Transition a key from its current state to the finish state Examples @@ -2032,15 +1978,11 @@ def _transition( func = self._TRANSITIONS_TABLE.get((start, finish)) if func is not None: - recommendations, client_msgs, worker_msgs = func( - self, key, stimulus_id, **kwargs - ) + recommendations, client_msgs, worker_msgs = func(self, key, stimulus_id, **kwargs) elif "released" not in (start, finish): assert not kwargs, (kwargs, start, finish) - a_recs, a_cmsgs, a_wmsgs = self._transition( - key, "released", stimulus_id - ) + a_recs, a_cmsgs, a_wmsgs = self._transition(key, "released", stimulus_id) v = a_recs.get(key, finish) # The inner rec has higher priority? Is that always desired? @@ -2070,16 +2012,10 @@ def _transition( stimulus_id = STIMULUS_ID_UNSET actual_finish = ts._state - self.transition_log.append( - Transition( - key, start, actual_finish, recommendations, stimulus_id, time() - ) - ) + self.transition_log.append(Transition(key, start, actual_finish, recommendations, stimulus_id, time())) if self.validate: if stimulus_id == STIMULUS_ID_UNSET: - raise RuntimeError( - "stimulus_id not set during Scheduler transition" - ) + raise RuntimeError("stimulus_id not set during Scheduler transition") logger.debug( "Transitioned %r %s->%s (actual: %s). Consequence: %s", key, @@ -2096,9 +2032,7 @@ def _transition( self.tasks[ts.key] = ts for plugin in list(self.plugins.values()): try: - plugin.transition( - key, start, actual_finish, stimulus_id=stimulus_id, **kwargs - ) + plugin.transition(key, start, actual_finish, stimulus_id=stimulus_id, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts.state == "forgotten": @@ -2282,9 +2216,7 @@ def _transition_queued_erred( traceback_text=traceback_text, ) - def decide_worker_rootish_queuing_disabled( - self, ts: TaskState - ) -> WorkerState | None: + def decide_worker_rootish_queuing_disabled(self, ts: TaskState) -> WorkerState | None: """Pick a worker for a runnable root-ish task, without queuing. 
This attempts to schedule sibling tasks on the same worker, reducing future data @@ -2315,25 +2247,16 @@ def decide_worker_rootish_queuing_disabled( tg = ts.group lws = tg.last_worker - if ( - lws - and tg.last_worker_tasks_left - and lws.status == Status.running - and self.workers.get(lws.address) is lws - ): + if lws and tg.last_worker_tasks_left and lws.status == Status.running and self.workers.get(lws.address) is lws: ws = lws else: # Last-used worker is full, unknown, retiring, or paused; # pick a new worker for the next few tasks ws = min(pool, key=partial(self.worker_objective, ts)) - tg.last_worker_tasks_left = math.floor( - (len(tg) / self.total_nthreads) * ws.nthreads - ) + tg.last_worker_tasks_left = math.floor((len(tg) / self.total_nthreads) * ws.nthreads) # Record `last_worker`, or clear it on the final task - tg.last_worker = ( - ws if tg.states["released"] + tg.states["waiting"] > 1 else None - ) + tg.last_worker = ws if tg.states["released"] + tg.states["waiting"] > 1 else None tg.last_worker_tasks_left -= 1 if self.validate and ws is not None: @@ -2574,9 +2497,7 @@ def _transition_processing_memory( recommendations: Recs = {} client_msgs: Msgs = {} - self._add_to_memory( - ts, ws, recommendations, client_msgs, type=type, typename=typename - ) + self._add_to_memory(ts, ws, recommendations, client_msgs, type=type, typename=typename) if self.validate: assert not ts.processing_on @@ -2596,9 +2517,7 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: ws.actors.discard(ts) if ts.who_wants: ts.exception_blame = ts - ts.exception = Serialized( - *serialize(RuntimeError("Worker holding Actor was lost")) - ) + ts.exception = Serialized(*serialize(RuntimeError("Worker holding Actor was lost"))) return {ts.key: "erred"}, {}, {} # don't try to recreate recommendations: Recs = {} @@ -2625,9 +2544,7 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: recommendations[key] = "forgotten" elif ts.has_lost_dependencies: recommendations[key] = "forgotten" - elif (ts.who_wants or ts.waiters) and not any( - dts.state == "erred" for dts in ts.dependencies - ): + elif (ts.who_wants or ts.waiters) and not any(dts.state == "erred" for dts in ts.dependencies): recommendations[key] = "waiting" for dts in ts.waiters or (): @@ -3007,9 +2924,7 @@ def _transition_memory_erred(self, key: Key, stimulus_id: str) -> RecsMsgs: if not dts.who_has: dts.exception_blame = ts recommendations[dts.key] = "erred" - exception = Serialized( - *serialize(RuntimeError("Worker holding Actor was lost")) - ) + exception = Serialized(*serialize(RuntimeError("Worker holding Actor was lost"))) report_msg = { "op": "task-erred", "key": key, @@ -3122,14 +3037,9 @@ def _transition_released_forgotten(self, key: Key, stimulus_id: str) -> RecsMsgs ("released", "erred"): _transition_released_erred, } - def story( - self, *keys_or_tasks_or_stimuli: Key | TaskState | str - ) -> list[Transition]: + def story(self, *keys_or_tasks_or_stimuli: Key | TaskState | str) -> list[Transition]: """Get all transitions that touch one of the input keys or stimulus_id's""" - keys_or_stimuli = { - key.key if isinstance(key, TaskState) else key - for key in keys_or_tasks_or_stimuli - } + keys_or_stimuli = {key.key if isinstance(key, TaskState) else key for key in keys_or_tasks_or_stimuli} return scheduler_story(keys_or_stimuli, self.transition_log) ############################## @@ -3205,14 +3115,9 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0) -> None: else: 
self.idle_task_count.discard(ws) - def is_unoccupied( - self, ws: WorkerState, occupancy: float, nprocessing: int - ) -> bool: + def is_unoccupied(self, ws: WorkerState, occupancy: float, nprocessing: int) -> bool: nthreads = ws.nthreads - return ( - nprocessing < nthreads - or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 - ) + return nprocessing < nthreads or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: """ @@ -3388,9 +3293,7 @@ def _validate_ready(self, ts: TaskState) -> None: assert ts not in self.queued assert all(dts.who_has for dts in ts.dependencies) - def _add_to_processing( - self, ts: TaskState, ws: WorkerState, stimulus_id: str - ) -> RecsMsgs: + def _add_to_processing(self, ts: TaskState, ws: WorkerState, stimulus_id: str) -> RecsMsgs: """Set a task as processing on a worker and return the worker messages to send""" if self.validate: self._validate_ready(ts) @@ -3408,11 +3311,7 @@ def _add_to_processing( ws.actors.add(ts) ndep_bytes = sum(dts.nbytes for dts in ts.dependencies) - if ( - ws.memory_limit - and ndep_bytes > ws.memory_limit - and dask.config.get("distributed.worker.memory.terminate") - ): + if ws.memory_limit and ndep_bytes > ws.memory_limit and dask.config.get("distributed.worker.memory.terminate"): # Note # ---- # This is a crude safety system, only meant to prevent order-of-magnitude @@ -3620,10 +3519,7 @@ def _task_to_msg(self, ts: TaskState) -> dict[str, Any]: "run_id": ts.run_id, "priority": ts.priority, "stimulus_id": f"compute-task-{time()}", - "who_has": { - dts.key: tuple(ws.address for ws in (dts.who_has or ())) - for dts in ts.dependencies - }, + "who_has": {dts.key: tuple(ws.address for ws in (dts.who_has or ())) for dts in ts.dependencies}, "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies}, "run_spec": ToPickle(ts.run_spec), "resource_restrictions": ts.resource_restrictions, @@ -3795,16 +3691,10 @@ def __init__( self.services = {} self.scheduler_file = scheduler_file - self.worker_ttl = parse_timedelta( - worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") - ) - self.idle_timeout = parse_timedelta( - idle_timeout or dask.config.get("distributed.scheduler.idle-timeout") - ) + self.worker_ttl = parse_timedelta(worker_ttl or dask.config.get("distributed.scheduler.worker-ttl")) + self.idle_timeout = parse_timedelta(idle_timeout or dask.config.get("distributed.scheduler.idle-timeout")) self.idle_since = time() - self.no_workers_timeout = parse_timedelta( - dask.config.get("distributed.scheduler.no-workers-timeout") - ) + self.no_workers_timeout = parse_timedelta(dask.config.get("distributed.scheduler.no-workers-timeout")) self._no_workers_since = None self.time_started = self.idle_since # compatibility for dask-gateway @@ -3852,24 +3742,17 @@ def __init__( except ImportError: show_dashboard = False http_server_modules.append("distributed.http.scheduler.missing_bokeh") - routes = get_handlers( - server=self, modules=http_server_modules, prefix=http_prefix - ) + routes = get_handlers(server=self, modules=http_server_modules, prefix=http_prefix) self.start_http_server(routes, dashboard_address, default_port=8787) self.jupyter = jupyter if show_dashboard: - distributed.dashboard.scheduler.connect( - self.http_application, self.http_server, self, prefix=http_prefix - ) + distributed.dashboard.scheduler.connect(self.http_application, self.http_server, self, prefix=http_prefix) scheduler = self if self.jupyter: try: from 
jupyter_server.serverapp import ServerApp except ImportError: - raise ImportError( - "In order to use the Dask jupyter option you " - "need to have jupyterlab installed" - ) + raise ImportError("In order to use the Dask jupyter option you need to have jupyterlab installed") from traitlets.config import Config """HTTP handler to shut down the Jupyter server. @@ -3917,9 +3800,7 @@ async def post(self) -> None: argv=[], ) self._jupyter_server_application = j - shutdown_app = tornado.web.Application( - [(r"/jupyter/api/shutdown", ShutdownHandler)] - ) + shutdown_app = tornado.web.Application([(r"/jupyter/api/shutdown", ShutdownHandler)]) shutdown_app.settings = j.web_app.settings self.http_application.add_application(shutdown_app) self.http_application.add_application(j.web_app) @@ -4139,8 +4020,7 @@ def identity(self, n_workers: int = -1) -> dict[str, Any]: "total_threads": self.total_nthreads, "total_memory": self.total_memory, "workers": { - worker.address: worker.identity() - for worker in itertools.islice(self.workers.values(), n_workers) + worker.address: worker.identity() for worker in itertools.islice(self.workers.values(), n_workers) }, } return d @@ -4198,10 +4078,7 @@ async def get_cluster_state( workers_future.cancel() # Convert any RPC errors to strings - worker_states = { - k: repr(v) if isinstance(v, Exception) else v - for k, v in worker_states.items() - } + worker_states = {k: repr(v) if isinstance(v, Exception) else v for k, v in worker_states.items()} return { "scheduler": scheduler_state, @@ -4217,9 +4094,7 @@ async def dump_cluster_state_to_url( **storage_options: dict[str, Any], ) -> None: "Write a cluster state dump to an fsspec-compatible URL." - await cluster_dump.write_state( - partial(self.get_cluster_state, exclude), url, format, **storage_options - ) + await cluster_dump.write_state(partial(self.get_cluster_state, exclude), url, format, **storage_options) def get_worker_service_addr( self, worker: str, service_name: str, protocol: bool = False @@ -4287,9 +4162,7 @@ async def start_unsafe(self) -> Self: # formatting dashboard link can fail if distributed.dashboard.link # refers to non-existent env vars. 
except KeyError as e: - logger.warning( - f"Failed to format dashboard link, unknown value: {e}" - ) + logger.warning(f"Failed to format dashboard link, unknown value: {e}") link = f":{server.port}" else: link = f"{listen_ip}:{server.port}" @@ -4315,9 +4188,7 @@ def del_scheduler_file() -> None: await self.listen("tcp://localhost:0") os.environ["DASK_SCHEDULER_ADDRESS"] = self.listeners[-1].contact_address - await asyncio.gather( - *[plugin.start(self) for plugin in list(self.plugins.values())] - ) + await asyncio.gather(*[plugin.start(self) for plugin in list(self.plugins.values())]) self.start_periodic_callbacks() @@ -4349,15 +4220,11 @@ async def log_errors(func: Callable) -> None: except Exception: logger.exception("Plugin call failed during scheduler.close") - await asyncio.gather( - *[log_errors(plugin.before_close) for plugin in list(self.plugins.values())] - ) + await asyncio.gather(*[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]) await self.preloads.teardown() - await asyncio.gather( - *[log_errors(plugin.close) for plugin in list(self.plugins.values())] - ) + await asyncio.gather(*[log_errors(plugin.close) for plugin in list(self.plugins.values())]) for pc in self.periodic_callbacks.values(): pc.stop() @@ -4430,25 +4297,21 @@ def heartbeat_worker( dh["last-seen"] = local_now frac = 1 / len(self.workers) - self.bandwidth = ( - self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac - ) + self.bandwidth = self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac for other, (bw, count) in metrics["bandwidth"]["workers"].items(): if (address, other) not in self.bandwidth_workers: self.bandwidth_workers[address, other] = bw / count else: alpha = (1 - frac) ** count - self.bandwidth_workers[address, other] = self.bandwidth_workers[ - address, other - ] * alpha + bw * (1 - alpha) + self.bandwidth_workers[address, other] = self.bandwidth_workers[address, other] * alpha + bw * ( + 1 - alpha + ) for typ, (bw, count) in metrics["bandwidth"]["types"].items(): if typ not in self.bandwidth_types: self.bandwidth_types[typ] = bw / count else: alpha = (1 - frac) ** count - self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( - 1 - alpha - ) + self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * (1 - alpha) ws.last_seen = local_now if executing is not None: @@ -4479,9 +4342,7 @@ def heartbeat_worker( # ws._nbytes is updated at a different time and sizeof() may not be accurate, # so size may be (temporarily) negative; floor it to zero. - size = max( - 0, metrics["memory"] - ws.nbytes + metrics["spilled_bytes"]["memory"] - ) + size = max(0, metrics["memory"] - ws.nbytes + metrics["spilled_bytes"]["memory"]) ws._memory_unmanaged_history.append((local_now, size)) if not memory_unmanaged_old: @@ -4626,9 +4487,7 @@ async def add_worker( logger.exception(exc, exc_info=exc) if ws.status == Status.running: - self.transitions( - self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id - ) + self.transitions(self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) logger.info("Register worker addr: %s name: %s", ws.address, ws.name) @@ -4812,9 +4671,7 @@ def _create_taskstate_from_graph( # _generate_taskstates is not the only thing that calls new_task(). A # TaskState may have also been created by client_desires_keys or scatter, # and only later gained a run_spec. 
- span_annotations = spans_ext.observe_tasks( - touched_tasks, span_metadata=span_metadata, code=code - ) + span_annotations = spans_ext.observe_tasks(touched_tasks, span_metadata=span_metadata, code=code) # In case of TaskGroup collision, spans may have changed # FIXME: Is this used anywhere besides tests? if span_annotations: @@ -4920,9 +4777,7 @@ async def update_graph( }, client=client, ) - self.client_releases_keys( - keys=keys, client=client, stimulus_id=stimulus_id - ) + self.client_releases_keys(keys=keys, client=client, stimulus_id=stimulus_id) evt_msg = { "action": "update-graph", "stimulus_id": stimulus_id, @@ -4955,8 +4810,7 @@ async def update_graph( "start_timestamp_seconds": start, "materialization_duration_seconds": materialization_done - start, "ordering_duration_seconds": materialization_done - ordering_done, - "state_initialization_duration_seconds": ordering_done - - task_state_created, + "state_initialization_duration_seconds": ordering_done - task_state_created, "duration_seconds": task_state_created - start, } ) @@ -5209,9 +5063,7 @@ def _set_priorities( ) if self.validate and istask(ts.run_spec): - assert isinstance(ts.priority, tuple) and all( - isinstance(el, (int, float)) for el in ts.priority - ) + assert isinstance(ts.priority, tuple) and all(isinstance(el, (int, float)) for el in ts.priority) def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """Respond to an event which may have opened spots on worker threadpools @@ -5230,10 +5082,7 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """ if not self.queued: return - slots_available = sum( - _task_slots_available(ws, self.WORKER_SATURATION) - for ws in self.idle_task_count - ) + slots_available = sum(_task_slots_available(ws, self.WORKER_SATURATION) for ws in self.idle_task_count) if slots_available == 0: return @@ -5255,9 +5104,7 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: assert qts.state == "processing" assert not self.queued or self.queued.peek() != qts - def stimulus_task_finished( - self, key: Key, worker: str, stimulus_id: str, run_id: int, **kwargs: Any - ) -> RecsMsgs: + def stimulus_task_finished(self, key: Key, worker: str, stimulus_id: str, run_id: int, **kwargs: Any) -> RecsMsgs: """Mark that a task has finished execution on a particular worker""" logger.debug("Stimulus task finished %s[%d] %s", key, run_id, worker) @@ -5268,8 +5115,7 @@ def stimulus_task_finished( ts = self.tasks.get(key) if ts is None or ts.state in ("released", "queued", "no-worker"): logger.debug( - "Received already computed task, worker: %s, state: %s" - ", key: %s, who_has: %s", + "Received already computed task, worker: %s, state: %s, key: %s, who_has: %s", worker, ts.state if ts else "forgotten", key, @@ -5284,7 +5130,7 @@ def stimulus_task_finished( ] elif ts.state == "erred": logger.debug( - "Received already erred task, worker: %s" ", key: %s", + "Received already erred task, worker: %s, key: %s", worker, key, ) @@ -5361,9 +5207,7 @@ def stimulus_task_erred( **kwargs, ) - def stimulus_retry( - self, keys: Collection[Key], client: str | None = None - ) -> tuple[Key, ...]: + def stimulus_retry(self, keys: Collection[Key], client: str | None = None) -> tuple[Key, ...]: logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -5441,14 +5285,10 @@ async def remove_worker( ws = self.workers[address] - logger.info( - f"Remove worker addr: {ws.address} name: 
{ws.name} ({stimulus_id=})" - ) + logger.info(f"Remove worker addr: {ws.address} name: {ws.name} ({stimulus_id=})") if close: with suppress(AttributeError, CommClosedError): - self.stream_comms[address].send( - {"op": "close", "reason": "scheduler-remove-worker"} - ) + self.stream_comms[address].send({"op": "close", "reason": "scheduler-remove-worker"}) self.remove_resources(address) @@ -5502,8 +5342,7 @@ async def remove_worker( ) recommendations.update(r) logger.error( - "Task %s marked as failed because %d workers died" - " while trying to run it", + "Task %s marked as failed because %d workers died while trying to run it", ts.key, ts.suspicious, ) @@ -5556,9 +5395,7 @@ async def remove_worker( for plugin in list(self.plugins.values()): try: try: - result = plugin.remove_worker( - scheduler=self, worker=address, stimulus_id=stimulus_id - ) + result = plugin.remove_worker(scheduler=self, worker=address, stimulus_id=stimulus_id) except TypeError: parameters = inspect.signature(plugin.remove_worker).parameters if "stimulus_id" not in parameters and not any( @@ -5590,13 +5427,9 @@ async def remove_worker_from_events() -> None: if address not in self.workers: self._broker.truncate(address) - cleanup_delay = parse_timedelta( - dask.config.get("distributed.scheduler.events-cleanup-delay") - ) + cleanup_delay = parse_timedelta(dask.config.get("distributed.scheduler.events-cleanup-delay")) - self._ongoing_background_tasks.call_later( - cleanup_delay, remove_worker_from_events - ) + self._ongoing_background_tasks.call_later(cleanup_delay, remove_worker_from_events) logger.debug("Removed worker %s", ws) for w in self.workers: @@ -5611,9 +5444,7 @@ async def remove_worker_from_events() -> None: return "OK" - def stimulus_cancel( - self, keys: Collection[Key], client: str, force: bool, reason: str, msg: str - ) -> None: + def stimulus_cancel(self, keys: Collection[Key], client: str, force: bool, reason: str, msg: str) -> None: """Stop execution on a list of keys""" logger.info("Client %s requests to cancel %d keys", client, len(keys)) self.log_event(client, {"action": "cancel", "count": len(keys), "force": force}) @@ -5675,9 +5506,7 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None: if ts.state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) - def client_releases_keys( - self, keys: Collection[Key], client: str, stimulus_id: str | None = None - ) -> None: + def client_releases_keys(self, keys: Collection[Key], client: str, stimulus_id: str | None = None) -> None: """Remove keys from client desired list""" stimulus_id = stimulus_id or f"client-releases-keys-{time()}" if not isinstance(keys, list): @@ -5736,9 +5565,7 @@ def validate_queued(self, key: Key) -> None: assert not ts.waiting_on assert not ts.who_has assert not ts.processing_on - assert not ( - ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions - ) + assert not (ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions) for dts in ts.dependencies: assert dts.who_has assert ts in (dts.waiters or ()) @@ -5764,9 +5591,7 @@ def validate_memory(self, key: Key) -> None: assert ts not in self.unrunnable assert ts not in self.queued for dts in ts.dependents: - assert (dts in (ts.waiters or ())) == ( - dts.state in ("waiting", "queued", "processing", "no-worker") - ) + assert (dts in (ts.waiters or ())) == (dts.state in ("waiting", "queued", "processing", "no-worker")) assert ts not in (dts.waiting_on or ()) def validate_no_worker(self, key: Key) -> None: @@ -5797,9 
+5622,7 @@ def validate_key(self, key: Key, ts: TaskState | None = None) -> None: try: func = getattr(self, "validate_" + ts.state.replace("-", "_")) except AttributeError: - logger.error( - "self.validate_%s not found", ts.state.replace("-", "_") - ) + logger.error("self.validate_%s not found", ts.state.replace("-", "_")) else: func(key) except Exception as e: @@ -5865,9 +5688,9 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert task_prefix_counts.keys() == self._task_prefix_count_global.keys() for name, global_count in self._task_prefix_count_global.items(): - assert ( - task_prefix_counts[name] == global_count - ), f"{name}: {task_prefix_counts[name]} (wss), {global_count} (global)" + assert task_prefix_counts[name] == global_count, ( + f"{name}: {task_prefix_counts[name]} (wss), {global_count} (global)" + ) for ws in self.running: assert ws.status == Status.running @@ -5890,10 +5713,7 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert cs.client_key == c a = {w: ws.nbytes for w, ws in self.workers.items()} - b = { - w: sum(ts.get_nbytes() for ts in ws.has_what) - for w, ws in self.workers.items() - } + b = {w: sum(ts.get_nbytes() for ts in ws.has_what) for w, ws in self.workers.items()} assert a == b, (a, b) if self.transition_counter_max: @@ -5903,9 +5723,7 @@ def validate_state(self, allow_overlap: bool = False) -> None: # Manage Messages # ################### - def report( - self, msg: dict, ts: TaskState | None = None, client: str | None = None - ) -> None: + def report(self, msg: dict, ts: TaskState | None = None, client: str | None = None) -> None: """ Publish updates to all listening Queues and Comms @@ -5927,9 +5745,7 @@ def report( # Notify clients interested in key (including `client`) # Note that, if report() was called by update_graph(), `client` won't be in # ts.who_wants yet. - client_keys = [ - cs.client_key for cs in ts.who_wants or () if cs.client_key != client - ] + client_keys = [cs.client_key for cs in ts.who_wants or () if cs.client_key != client] if client is not None: client_keys.append(client) @@ -5942,13 +5758,9 @@ def report( # logger.debug("Scheduler sends message to client %s: %s", k, msg) except CommClosedError: if self.status == Status.running: - logger.critical( - "Closed comm %r while trying to write %s", c, msg, exc_info=True - ) + logger.critical("Closed comm %r while trying to write %s", c, msg, exc_info=True) - async def add_client( - self, comm: Comm, client: str, versions: dict[str, Any] - ) -> None: + async def add_client(self, comm: Comm, client: str, versions: dict[str, Any]) -> None: """Add client to network We listen to all future messages from this Comm. 
@@ -6026,13 +5838,9 @@ async def remove_client_from_events() -> None: if client not in self.clients: self._broker.truncate(client) - cleanup_delay = parse_timedelta( - dask.config.get("distributed.scheduler.events-cleanup-delay") - ) + cleanup_delay = parse_timedelta(dask.config.get("distributed.scheduler.events-cleanup-delay")) if not self._ongoing_background_tasks.closed: - self._ongoing_background_tasks.call_later( - cleanup_delay, remove_client_from_events - ) + self._ongoing_background_tasks.call_later(cleanup_delay, remove_client_from_events) def send_task_to_worker(self, worker: str, ts: TaskState) -> None: """Send a single computational task to a worker""" @@ -6050,17 +5858,13 @@ def send_task_to_worker(self, worker: str, ts: TaskState) -> None: def handle_uncaught_error(self, **msg: Any) -> None: logger.exception(clean_exception(**msg)[1]) - def handle_task_finished( - self, key: Key, worker: str, stimulus_id: str, **msg: Any - ) -> None: + def handle_task_finished(self, key: Key, worker: str, stimulus_id: str, **msg: Any) -> None: if worker not in self.workers: return if self.validate: self.validate_key(key) - r: tuple = self.stimulus_task_finished( - key=key, worker=worker, stimulus_id=stimulus_id, **msg - ) + r: tuple = self.stimulus_task_finished(key=key, worker=worker, stimulus_id=stimulus_id, **msg) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) @@ -6099,9 +5903,7 @@ def handle_long_running( duration accounting as if the task has stopped. """ if worker not in self.workers: - logger.debug( - "Received long-running signal from unknown worker %s. Ignoring.", worker - ) + logger.debug("Received long-running signal from unknown worker %s. Ignoring.", worker) return if key not in self.tasks: @@ -6139,9 +5941,7 @@ def handle_long_running( self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) - def handle_worker_status_change( - self, status: str | Status, worker: str | WorkerState, stimulus_id: str - ) -> None: + def handle_worker_status_change(self, status: str | Status, worker: str | WorkerState, stimulus_id: str) -> None: ws = self.workers.get(worker) if isinstance(worker, str) else worker if not ws: return @@ -6164,9 +5964,7 @@ def handle_worker_status_change( if ws.status == Status.running: self.running.add(ws) self.check_idle_saturated(ws) - self.transitions( - self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id - ) + self.transitions(self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) else: self.running.discard(ws) @@ -6175,9 +5973,7 @@ def handle_worker_status_change( self.saturated.discard(ws) self._refresh_no_workers_since() - def handle_request_refresh_who_has( - self, keys: Iterable[Key], worker: str, stimulus_id: str - ) -> None: + def handle_request_refresh_who_has(self, keys: Iterable[Key], worker: str, stimulus_id: str) -> None: """Request from a Worker to refresh the who_has for some keys. Not to be confused with scheduler.who_has, which is a dedicated comm RPC request from a Client. 
@@ -6226,9 +6022,7 @@ async def handle_worker(self, comm: Comm, worker: str) -> None: finally: if worker in self.stream_comms: worker_comm.abort() - await self.remove_worker( - worker, stimulus_id=f"handle-worker-cleanup-{time()}" - ) + await self.remove_worker(worker, stimulus_id=f"handle-worker-cleanup-{time()}") def add_plugin( self, @@ -6289,9 +6083,7 @@ def remove_plugin(self, name: str | None = None) -> None: try: del self.plugins[name] except KeyError: - raise ValueError( - f"Could not find plugin {name!r} among the current scheduler plugins" - ) + raise ValueError(f"Could not find plugin {name!r} among the current scheduler plugins") async def register_scheduler_plugin( self, @@ -6353,9 +6145,7 @@ def client_send(self, client: str, msg: dict) -> None: c.send(msg) except CommClosedError: if self.status == Status.running: - logger.critical( - "Closed comm %r while trying to write %s", c, msg, exc_info=True - ) + logger.critical("Closed comm %r while trying to write %s", c, msg, exc_info=True) def send_all(self, client_msgs: Msgs, worker_msgs: Msgs) -> None: """Send messages to client and workers""" @@ -6433,14 +6223,10 @@ async def scatter( n = len(workers) if broadcast is True else broadcast await self.replicate(keys=keys, workers=workers, n=n) - self.log_event( - [client, "all"], {"action": "scatter", "client": client, "count": len(data)} - ) + self.log_event([client, "all"], {"action": "scatter", "client": client, "count": len(data)}) return keys - async def gather( - self, keys: Collection[Key], serializers: list[str] | None = None - ) -> dict[Key, object]: + async def gather(self, keys: Collection[Key], serializers: list[str] | None = None) -> dict[Key, object]: """Collect data from workers to the scheduler""" data = {} missing_keys = list(keys) @@ -6461,9 +6247,7 @@ async def gather( missing_keys, new_failed_keys, new_missing_workers, - ) = await gather_from_workers( - who_has, rpc=self.rpc, serializers=serializers - ) + ) = await gather_from_workers(who_has, rpc=self.rpc, serializers=serializers) data.update(new_data) failed_keys += new_failed_keys missing_workers.update(new_missing_workers) @@ -6473,10 +6257,7 @@ async def gather( if not failed_keys: return {"status": "OK", "data": data} - failed_states = { - key: self.tasks[key].state if key in self.tasks else "forgotten" - for key in failed_keys - } + failed_states = {key: self.tasks[key].state if key in self.tasks else "forgotten" for key in failed_keys} logger.error("Couldn't gather keys: %s", failed_states) return {"status": "error", "keys": list(failed_keys)} @@ -6586,24 +6367,16 @@ async def restart_workers( workers = list(set(workers).intersection(self.workers)) logger.info(f"Restarting {len(workers)} workers: {workers} ({stimulus_id=}") - nanny_workers = { - addr: self.workers[addr].nanny - for addr in workers - if self.workers[addr].nanny - } + nanny_workers = {addr: self.workers[addr].nanny for addr in workers if self.workers[addr].nanny} # Close non-Nanny workers. We have no way to restart them, so we just let them # go, and assume a deployment system is going to restart them for us. 
no_nanny_workers = [addr for addr in workers if addr not in nanny_workers] if no_nanny_workers: logger.warning( - f"Workers {no_nanny_workers} do not use a nanny and will be terminated " - "without restarting them" + f"Workers {no_nanny_workers} do not use a nanny and will be terminated without restarting them" ) await asyncio.gather( - *( - self.remove_worker(address=addr, stimulus_id=stimulus_id) - for addr in no_nanny_workers - ) + *(self.remove_worker(address=addr, stimulus_id=stimulus_id) for addr in no_nanny_workers) ) out: dict[str, Literal["OK", "removed", "timed out"]] out = {addr: "removed" for addr in no_nanny_workers} @@ -6613,9 +6386,7 @@ async def restart_workers( async with contextlib.AsyncExitStack() as stack: nannies = await asyncio.gather( *( - stack.enter_async_context( - rpc(nanny_address, connection_args=self.connection_args) - ) + stack.enter_async_context(rpc(nanny_address, connection_args=self.connection_args)) for nanny_address in nanny_workers.values() ) ) @@ -6651,16 +6422,8 @@ async def restart_workers( raise resp if bad_nannies: - logger.error( - f"Workers {list(bad_nannies)} did not shut down within {timeout}s; " - "force closing" - ) - await asyncio.gather( - *( - self.remove_worker(addr, stimulus_id=stimulus_id) - for addr in bad_nannies - ) - ) + logger.error(f"Workers {list(bad_nannies)} did not shut down within {timeout}s; force closing") + await asyncio.gather(*(self.remove_worker(addr, stimulus_id=stimulus_id) for addr in bad_nannies)) if on_error == "raise": raise TimeoutError( f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not " @@ -6669,15 +6432,10 @@ async def restart_workers( if client: self.log_event(client, {"action": "restart-workers", "workers": workers}) - self.log_event( - "all", {"action": "restart-workers", "workers": workers, "client": client} - ) + self.log_event("all", {"action": "restart-workers", "workers": workers, "client": client}) if not wait_for_workers: - logger.info( - "Workers restart finished (did not wait for new workers) " - f"({stimulus_id=}" - ) + logger.info(f"Workers restart finished (did not wait for new workers) ({stimulus_id=}") return out # NOTE: if new (unrelated) workers join while we're waiting, we may return @@ -6746,9 +6504,7 @@ async def broadcast( ERROR = object() - reuse_broadcast_comm = dask.config.get( - "distributed.scheduler.reuse-broadcast-comm", False - ) + reuse_broadcast_comm = dask.config.get("distributed.scheduler.reuse-broadcast-comm", False) close = not reuse_broadcast_comm async def send_message(addr: str) -> Any: @@ -6756,9 +6512,7 @@ async def send_message(addr: str) -> Any: comm = await self.rpc.connect(addr) comm.name = "Scheduler Broadcast" try: - resp = await send_recv( - comm, close=close, serializers=serializers, **msg - ) + resp = await send_recv(comm, close=close, serializers=serializers, **msg) finally: self.rpc.reuse(addr, comm) return resp @@ -6774,8 +6528,7 @@ async def send_message(addr: str) -> Any: return ERROR else: raise ValueError( - "on_error must be 'raise', 'return', 'return_pickle', " - f"or 'ignore'; got {on_error!r}" + f"on_error must be 'raise', 'return', 'return_pickle', or 'ignore'; got {on_error!r}" ) results = await All([send_message(address) for address in addresses]) @@ -6791,9 +6544,7 @@ async def proxy( d = await self.broadcast(msg=msg, workers=[worker], serializers=serializers) return d[worker] - async def gather_on_worker( - self, worker_address: str, who_has: dict[Key, list[str]] - ) -> set: + async def gather_on_worker(self, worker_address: str, 
who_has: dict[Key, list[str]]) -> set: """Peer-to-peer copy of keys from multiple workers to a single worker Parameters @@ -6809,15 +6560,12 @@ async def gather_on_worker( set of keys that failed to be copied """ try: - result = await retry_operation( - self.rpc(addr=worker_address).gather, who_has=who_has - ) + result = await retry_operation(self.rpc(addr=worker_address).gather, who_has=who_has) except OSError as e: # This can happen e.g. if the worker is going through controlled shutdown; # it doesn't necessarily mean that it went unexpectedly missing logger.warning( - f"Communication with worker {worker_address} failed during " - f"replication: {e.__class__.__name__}: {e}" + f"Communication with worker {worker_address} failed during replication: {e.__class__.__name__}: {e}" ) return set(who_has) @@ -6832,9 +6580,7 @@ async def gather_on_worker( elif result["status"] == "partial-fail": keys_failed = set(result["keys"]) keys_ok = who_has.keys() - keys_failed - logger.warning( - f"Worker {worker_address} failed to acquire keys: {result['keys']}" - ) + logger.warning(f"Worker {worker_address} failed to acquire keys: {result['keys']}") else: # pragma: nocover raise ValueError(f"Unexpected message from {worker_address}: {result}") @@ -6847,9 +6593,7 @@ async def gather_on_worker( return keys_failed - async def delete_worker_data( - self, worker_address: str, keys: Collection[Key], stimulus_id: str - ) -> None: + async def delete_worker_data(self, worker_address: str, keys: Collection[Key], stimulus_id: str) -> None: """Delete data from a worker and update the corresponding worker/task states Parameters @@ -6869,8 +6613,7 @@ async def delete_worker_data( # This can happen e.g. if the worker is going through controlled shutdown; # it doesn't necessarily mean that it went unexpectedly missing logger.warning( - f"Communication with worker {worker_address} failed during " - f"replication: {e.__class__.__name__}: {e}" + f"Communication with worker {worker_address} failed during replication: {e.__class__.__name__}: {e}" ) return @@ -6976,9 +6719,7 @@ async def rebalance( keys = set(keys) # unless already a set-like if not keys: return {"status": "OK"} - missing_data = [ - k for k in keys if k not in self.tasks or not self.tasks[k].who_has - ] + missing_data = [k for k in keys if k not in self.tasks or not self.tasks[k].who_has] if missing_data: return {"status": "partial-fail", "keys": missing_data} @@ -7053,9 +6794,7 @@ def _rebalance_find_msgs( # unmanaged memory that appeared over the last 30 seconds # (distributed.worker.memory.recent-to-old-time). # This lets us ignore temporary spikes caused by task heap usage. 
- memory_by_worker = [ - (ws, getattr(ws.memory, self.MEMORY_REBALANCE_MEASURE)) for ws in workers - ] + memory_by_worker = [(ws, getattr(ws.memory, self.MEMORY_REBALANCE_MEASURE)) for ws in workers] mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker) for ws, ws_memory in memory_by_worker: @@ -7068,18 +6807,12 @@ def _rebalance_find_msgs( sender_min = 0.0 recipient_max = math.inf - if ( - ws._has_what - and ws_memory >= mean_memory + half_gap - and ws_memory >= sender_min - ): + if ws._has_what and ws_memory >= mean_memory + half_gap and ws_memory >= sender_min: # This may send the worker below sender_min (by design) snd_bytes_max = mean_memory - ws_memory # negative snd_bytes_min = snd_bytes_max + half_gap # negative # See definition of senders above - senders.append( - (snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what)) - ) + senders.append((snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what))) elif ws_memory < mean_memory - half_gap and ws_memory < recipient_max: # This may send the worker above recipient_max (by design) rec_bytes_max = ws_memory - mean_memory # negative @@ -7197,9 +6930,7 @@ async def _rebalance_move_data( FIXME this method is not robust when the cluster is not idle. """ # {recipient address: {key: [sender address, ...]}} - to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict( - lambda: defaultdict(list) - ) + to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict(lambda: defaultdict(list)) for snd_ws, rec_ws, ts in msgs: to_recipients[rec_ws.address][ts.key].append(snd_ws.address) failed_keys_by_recipient = dict( @@ -7221,9 +6952,7 @@ async def _rebalance_move_data( to_senders[snd_ws.address].append(ts.key) # Note: this never raises exceptions - await asyncio.gather( - *(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items()) - ) + await asyncio.gather(*(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items())) for r, v in to_recipients.items(): self.log_event(r, {"action": "rebalance", "who_has": v}) @@ -7302,17 +7031,13 @@ async def replicate( assert ts.who_has is not None del_candidates = tuple(ts.who_has & workers) if len(del_candidates) > n: - for ws in random.sample( - del_candidates, len(del_candidates) - n - ): + for ws in random.sample(del_candidates, len(del_candidates) - n): del_worker_tasks[ws].add(ts) # Note: this never raises exceptions await asyncio.gather( *[ - self.delete_worker_data( - ws.address, [t.key for t in tasks], stimulus_id - ) + self.delete_worker_data(ws.address, [t.key for t in tasks], stimulus_id) for ws, tasks in del_worker_tasks.items() ] ) @@ -7337,9 +7062,7 @@ async def replicate( assert count > 0 for ws in random.sample(tuple(workers - ts.who_has), count): - gathers[ws.address][ts.key] = [ - wws.address for wws in ts.who_has - ] + gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has] await asyncio.gather( *( @@ -7595,8 +7318,7 @@ async def retire_workers( raise TypeError("names and workers are mutually exclusive") if (names is not None or workers is not None) and kwargs: raise TypeError( - "Parameters for workers_to_close() are mutually exclusive with " - f"names and workers: {kwargs}" + f"Parameters for workers_to_close() are mutually exclusive with names and workers: {kwargs}" ) stimulus_id = stimulus_id or f"retire-workers-{time()}" @@ -7616,24 +7338,16 @@ async def retire_workers( stimulus_id, workers, ) - wss = { - self.workers[address] - for address in workers - if address in self.workers - } + 
wss = {self.workers[address] for address in workers if address in self.workers} else: - wss = { - self.workers[address] for address in self.workers_to_close(**kwargs) - } + wss = {self.workers[address] for address in self.workers_to_close(**kwargs)} if not wss: return {} stop_amm = False amm: ActiveMemoryManagerExtension | None = self.extensions.get("amm") if not amm or not amm.running: - amm = ActiveMemoryManagerExtension( - self, policies=set(), register=False, start=True, interval=2.0 - ) + amm = ActiveMemoryManagerExtension(self, policies=set(), register=False, start=True, interval=2.0) stop_amm = True try: @@ -7645,9 +7359,7 @@ async def retire_workers( # Change Worker.status to closing_gracefully. Immediately set # the same on the scheduler to prevent race conditions. prev_status = ws.status - self.handle_worker_status_change( - Status.closing_gracefully, ws, stimulus_id - ) + self.handle_worker_status_change(Status.closing_gracefully, ws, stimulus_id) # FIXME: We should send a message to the nanny first; # eventually workers won't be able to close their own nannies. self.stream_comms[ws.address].send( @@ -7738,14 +7450,10 @@ async def _track_retire_worker( ) return ws.address, "no-recipients", ws.identity() - logger.debug( - f"All unique keys on worker {ws.address!r} have been replicated elsewhere" - ) + logger.debug(f"All unique keys on worker {ws.address!r} have been replicated elsewhere") if remove: - await self.remove_worker( - ws.address, expected=True, close=close, stimulus_id=stimulus_id - ) + await self.remove_worker(ws.address, expected=True, close=close, stimulus_id=stimulus_id) elif close: self.close_worker(ws.address) @@ -7887,9 +7595,7 @@ async def feed( if teardown: teardown(self, state) # type: ignore - def log_worker_event( - self, worker: str, topic: str | Collection[str], msg: Any - ) -> None: + def log_worker_event(self, worker: str, topic: str | Collection[str], msg: Any) -> None: if isinstance(msg, dict) and worker != topic: msg["worker"] = worker self.log_event(topic, msg) @@ -7902,46 +7608,25 @@ def subscribe_worker_status(self, comm: Comm) -> dict[str, Any]: del v["last_seen"] return ident - def get_processing( - self, workers: Iterable[str] | None = None - ) -> dict[str, list[Key]]: + def get_processing(self, workers: Iterable[str] | None = None) -> dict[str, list[Key]]: if workers is not None: workers = set(map(self.coerce_address, workers)) return {w: [ts.key for ts in self.workers[w].processing] for w in workers} else: - return { - w: [ts.key for ts in ws.processing] for w, ws in self.workers.items() - } + return {w: [ts.key for ts in ws.processing] for w, ws in self.workers.items()} def get_who_has(self, keys: Iterable[Key] | None = None) -> dict[Key, list[str]]: if keys is not None: return { - key: ( - [ws.address for ws in self.tasks[key].who_has or ()] - if key in self.tasks - else [] - ) - for key in keys + key: ([ws.address for ws in self.tasks[key].who_has or ()] if key in self.tasks else []) for key in keys } else: - return { - key: [ws.address for ws in ts.who_has or ()] - for key, ts in self.tasks.items() - } + return {key: [ws.address for ws in ts.who_has or ()] for key, ts in self.tasks.items()} - def get_has_what( - self, workers: Iterable[str] | None = None - ) -> dict[str, list[Key]]: + def get_has_what(self, workers: Iterable[str] | None = None) -> dict[str, list[Key]]: if workers is not None: workers = map(self.coerce_address, workers) - return { - w: ( - [ts.key for ts in self.workers[w].has_what] - if w in self.workers - else [] - ) - 
for w in workers - } + return {w: ([ts.key for ts in self.workers[w].has_what] if w in self.workers else []) for w in workers} else: return {w: [ts.key for ts in ws.has_what] for w, ws in self.workers.items()} @@ -7952,13 +7637,9 @@ def get_ncores(self, workers: Iterable[str] | None = None) -> dict[str, int]: else: return {w: ws.nthreads for w, ws in self.workers.items()} - def get_ncores_running( - self, workers: Iterable[str] | None = None - ) -> dict[str, int]: + def get_ncores_running(self, workers: Iterable[str] | None = None) -> dict[str, int]: ncores = self.get_ncores(workers=workers) - return { - w: n for w, n in ncores.items() if self.workers[w].status == Status.running - } + return {w: n for w, n in ncores.items() if self.workers[w].status == Status.running} async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, Any]: workers: dict[str, list[Key] | None] @@ -7985,9 +7666,7 @@ async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, A if not workers: return {} - results = await asyncio.gather( - *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) - ) + results = await asyncio.gather(*(self.rpc(w).call_stack(keys=v) for w, v in workers.items())) response = {w: r for w, r in zip(workers, results) if r} return response @@ -8028,9 +7707,7 @@ async def benchmark_hardware(self) -> dict[str, dict[str, float]]: # implementing logic based on IP addresses would not necessarily help. # Randomize the connections to even out the mean measures. random.shuffle(workers) - futures = [ - self.rpc(a).benchmark_network(address=b) for a, b in partition(2, workers) - ] + futures = [self.rpc(a).benchmark_network(address=b) for a, b in partition(2, workers)] responses = await asyncio.gather(*futures) for d in responses: @@ -8039,17 +7716,12 @@ async def benchmark_hardware(self) -> dict[str, dict[str, float]]: result = {} for mode in out: - result[mode] = { - size: sum(durations) / len(durations) - for size, durations in out[mode].items() - } + result[mode] = {size: sum(durations) / len(durations) for size, durations in out[mode].items()} return result @log_errors - def get_nbytes( - self, keys: Iterable[Key] | None = None, summary: bool = True - ) -> dict[Key, int]: + def get_nbytes(self, keys: Iterable[Key] | None = None, summary: bool = True) -> dict[Key, int]: if keys is not None: result = {k: self.tasks[k].nbytes for k in keys} else: @@ -8135,9 +7807,7 @@ def get_task_prefix_states(self) -> dict[str, dict[str, int]]: return state def get_task_status(self, keys: Iterable[Key]) -> dict[Key, TaskStateState | None]: - return { - key: (self.tasks[key].state if key in self.tasks else None) for key in keys - } + return {key: (self.tasks[key].state if key in self.tasks else None) for key in keys} def get_task_stream( self, @@ -8160,14 +7830,11 @@ def start_task_metadata(self, name: str) -> None: def stop_task_metadata(self, name: str | None = None) -> dict: plugins = [ - p - for p in list(self.plugins.values()) - if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name + p for p in list(self.plugins.values()) if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name ] if len(plugins) != 1: raise ValueError( - "Expected to find exactly one CollectTaskMetaDataPlugin " - f"with name {name} but found {len(plugins)}." + f"Expected to find exactly one CollectTaskMetaDataPlugin with name {name} but found {len(plugins)}." 
) plugin = plugins[0] @@ -8192,14 +7859,10 @@ async def register_worker_plugin( self.worker_plugins[name] = plugin - responses = await self.broadcast( - msg=dict(op="plugin-add", plugin=plugin, name=name) - ) + responses = await self.broadcast(msg=dict(op="plugin-add", plugin=plugin, name=name)) return responses - async def unregister_worker_plugin( - self, comm: None, name: str - ) -> dict[str, ErrorMessage | OKMessage]: + async def unregister_worker_plugin(self, comm: None, name: str) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.worker_plugins.pop(name) @@ -8231,27 +7894,21 @@ async def register_nanny_plugin( async with self._starting_nannies_cond: if self._starting_nannies: logger.info("Waiting for Nannies to start %s", self._starting_nannies) - await self._starting_nannies_cond.wait_for( - lambda: not self._starting_nannies - ) + await self._starting_nannies_cond.wait_for(lambda: not self._starting_nannies) responses = await self.broadcast( msg=dict(op="plugin_add", plugin=plugin, name=name), nanny=True, ) return responses - async def unregister_nanny_plugin( - self, comm: None, name: str - ) -> dict[str, ErrorMessage | OKMessage]: + async def unregister_nanny_plugin(self, comm: None, name: str) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.nanny_plugins.pop(name) except KeyError: raise ValueError(f"The nanny plugin {name} does not exist") - responses = await self.broadcast( - msg=dict(op="plugin_remove", name=name), nanny=True - ) + responses = await self.broadcast(msg=dict(op="plugin_remove", name=name), nanny=True) return responses def transition( @@ -8276,9 +7933,7 @@ def transition( -------- Scheduler.transitions: transitive version of this function """ - recommendations, client_msgs, worker_msgs = self._transition( - key, finish, stimulus_id, **kwargs - ) + recommendations, client_msgs, worker_msgs = self._transition(key, finish, stimulus_id, **kwargs) self.send_all(client_msgs, worker_msgs) return recommendations @@ -8301,9 +7956,7 @@ async def get_story(self, keys_or_stimuli: Iterable[Key | str]) -> list[Transiti """ return self.story(*keys_or_stimuli) - def _reschedule( - self, key: Key, worker: str | None = None, *, stimulus_id: str - ) -> None: + def _reschedule(self, key: Key, worker: str | None = None, *, stimulus_id: str) -> None: """Reschedule a task. This function should only be used when the task has already been released in @@ -8315,8 +7968,7 @@ def _reschedule( ts = self.tasks[key] except KeyError: logger.warning( - f"Attempting to reschedule task {key!r}, which was not " - "found on the scheduler. Aborting reschedule." + f"Attempting to reschedule task {key!r}, which was not found on the scheduler. Aborting reschedule." 
) return if ts.state != "processing": @@ -8331,9 +7983,7 @@ def _reschedule( # Utility functions # ##################### - def add_resources( - self, worker: str, resources: dict | None = None - ) -> Literal["OK"]: + def add_resources(self, worker: str, resources: dict | None = None) -> Literal["OK"]: ws = self.workers[worker] if resources: ws.resources.update(resources) @@ -8415,10 +8065,7 @@ async def get_profile( ) results = await asyncio.gather( - *( - self.rpc(w).profile(start=start, stop=stop, key=key, server=server) - for w in workers - ), + *(self.rpc(w).profile(start=start, stop=stop, key=key, server=server) for w in workers), return_exceptions=True, ) @@ -8438,9 +8085,7 @@ async def get_profile_metadata( stop: float | None = None, profile_cycle_interval: str | float | None = None, ) -> dict[str, Any]: - dt = profile_cycle_interval or dask.config.get( - "distributed.worker.profile.cycle" - ) + dt = profile_cycle_interval or dask.config.get("distributed.worker.profile.cycle") dt = parse_timedelta(dt, default="ms") if workers is None: @@ -8463,9 +8108,7 @@ async def get_profile_metadata( ) ] - keys: dict[Key, list[list]] = { - k: [] for v in results for t, d in v["keys"] for k in d - } + keys: dict[Key, list[list]] = {k: [] for v in results for t, d in v["keys"] for k in d} groups1 = [v["keys"] for v in results] groups2 = list(merge_sorted(*groups1, key=first)) @@ -8482,9 +8125,7 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} - async def performance_report( - self, start: float, last_count: int, code: str = "", mode: str | None = None - ) -> str: + async def performance_report(self, start: float, last_count: int, code: str = "", mode: str | None = None) -> str: stop = time() # Profiles compute_d, scheduler_d, workers_d = await asyncio.gather( @@ -8501,9 +8142,7 @@ def profile_to_figure(state: object) -> object: figure, source = profile.plot_figure(data, sizing_mode="stretch_both") return figure - compute, scheduler, workers = map( - profile_to_figure, (compute_d, scheduler_d, workers_d) - ) + compute, scheduler, workers = map(profile_to_figure, (compute_d, scheduler_d, workers_d)) del compute_d, scheduler_d, workers_d # Task stream @@ -8587,16 +8226,10 @@ def profile_to_figure(state: object) -> object: html = TabPanel(child=html, title="Summary") compute = TabPanel(child=compute, title="Worker Profile (compute)") workers = TabPanel(child=workers, title="Worker Profile (administrative)") - scheduler = TabPanel( - child=scheduler, title="Scheduler Profile (administrative)" - ) + scheduler = TabPanel(child=scheduler, title="Scheduler Profile (administrative)") task_stream = TabPanel(child=task_stream, title="Task Stream") - bandwidth_workers = TabPanel( - child=bandwidth_workers.root, title="Bandwidth (Workers)" - ) - bandwidth_types = TabPanel( - child=bandwidth_types.root, title="Bandwidth (Types)" - ) + bandwidth_workers = TabPanel(child=bandwidth_workers.root, title="Bandwidth (Workers)") + bandwidth_types = TabPanel(child=bandwidth_types.root, title="Bandwidth (Types)") system = TabPanel(child=sysmon.root, title="System") logs = TabPanel(child=logs.root, title="Scheduler Logs") @@ -8620,9 +8253,7 @@ def profile_to_figure(state: object) -> object: with tmpfile(extension=".html") as fn: output_file(filename=fn, title="Dask Performance Report", mode=mode) - template_directory = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates" - ) + template_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "dashboard", 
"templates") template_environment = get_env() template_environment.loader.searchpath.append(template_directory) template = template_environment.get_template("performance_report.html") @@ -8633,12 +8264,8 @@ def profile_to_figure(state: object) -> object: return data - async def get_worker_logs( - self, n: int | None = None, workers: list | None = None, nanny: bool = False - ) -> dict: - results = await self.broadcast( - msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny - ) + async def get_worker_logs(self, n: int | None = None, workers: list | None = None, nanny: bool = False) -> dict: + results = await self.broadcast(msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny) return results def log_event(self, topic: str | Collection[str], msg: Any) -> None: @@ -8675,16 +8302,11 @@ def get_events( ) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]: return self._broker.get_events(topic) - async def get_worker_monitor_info( - self, recent: bool = False, starts: dict | None = None - ) -> dict: + async def get_worker_monitor_info(self, recent: bool = False, starts: dict | None = None) -> dict: if starts is None: starts = {} results = await asyncio.gather( - *( - self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) - for w in self.workers - ) + *(self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) for w in self.workers) ) return dict(zip(self.workers, results)) @@ -8737,11 +8359,7 @@ def check_idle(self) -> float | None: self.idle_since = None return None - if ( - self.queued - or self.unrunnable - or any(ws.processing for ws in self.workers.values()) - ): + if self.queued or self.unrunnable or any(ws.processing for ws in self.workers.values()): self.idle_since = None return None @@ -8750,9 +8368,7 @@ def check_idle(self) -> float | None: return self.idle_since if self.jupyter: - last_activity = ( - self._jupyter_server_application.web_app.last_activity().timestamp() - ) + last_activity = self._jupyter_server_application.web_app.last_activity().timestamp() if last_activity > self.idle_since: self.idle_since = last_activity return self.idle_since @@ -8764,16 +8380,11 @@ def check_idle(self) -> float | None: "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - self._ongoing_background_tasks.call_soon( - self.close, reason="idle-timeout-exceeded" - ) + self._ongoing_background_tasks.call_soon(self.close, reason="idle-timeout-exceeded") return self.idle_since def _check_no_workers(self) -> None: - if ( - self.status in (Status.closing, Status.closed) - or self.no_workers_timeout is None - ): + if self.status in (Status.closing, Status.closed) or self.no_workers_timeout is None: return now = monotonic() @@ -8783,15 +8394,9 @@ def _check_no_workers(self) -> None: self._refresh_no_workers_since(now) - affected = self._check_unrunnable_task_timeouts( - now, recommendations=recommendations, stimulus_id=stimulus_id - ) + affected = self._check_unrunnable_task_timeouts(now, recommendations=recommendations, stimulus_id=stimulus_id) - affected.update( - self._check_queued_task_timeouts( - now, recommendations=recommendations, stimulus_id=stimulus_id - ) - ) + affected.update(self._check_queued_task_timeouts(now, recommendations=recommendations, stimulus_id=stimulus_id)) self.transitions(recommendations, stimulus_id=stimulus_id) if affected: self.log_event( @@ -8799,9 +8404,7 @@ def _check_no_workers(self) -> None: {"action": "no-workers-timeout-exceeded", "keys": affected}, ) - def 
_check_unrunnable_task_timeouts( - self, timestamp: float, recommendations: Recs, stimulus_id: str - ) -> set[Key]: + def _check_unrunnable_task_timeouts(self, timestamp: float, recommendations: Recs, stimulus_id: str) -> set[Key]: assert self.no_workers_timeout unsatisfied = [] no_workers = [] @@ -8810,10 +8413,7 @@ def _check_unrunnable_task_timeouts( # unrunnable is insertion-ordered, which means that unrunnable_since will # be monotonically increasing in this loop. break - if ( - self._no_workers_since is None - or self._no_workers_since >= unrunnable_since - ): + if self._no_workers_since is None or self._no_workers_since >= unrunnable_since: unsatisfied.append(ts) else: no_workers.append(ts) @@ -8839,18 +8439,13 @@ def _check_unrunnable_task_timeouts( ) recommendations.update(r) logger.error( - "Task %s marked as failed because it timed out waiting " - "for its restrictions to become satisfied.", + "Task %s marked as failed because it timed out waiting for its restrictions to become satisfied.", ts.key, ) - self._fail_tasks_after_no_workers_timeout( - no_workers, recommendations, stimulus_id - ) + self._fail_tasks_after_no_workers_timeout(no_workers, recommendations, stimulus_id) return {ts.key for ts in concat([unsatisfied, no_workers])} - def _check_queued_task_timeouts( - self, timestamp: float, recommendations: Recs, stimulus_id: str - ) -> set[Key]: + def _check_queued_task_timeouts(self, timestamp: float, recommendations: Recs, stimulus_id: str) -> set[Key]: assert self.no_workers_timeout if self._no_workers_since is None: @@ -8859,9 +8454,7 @@ def _check_queued_task_timeouts( if timestamp <= self._no_workers_since + self.no_workers_timeout: return set() affected = list(self.queued) - self._fail_tasks_after_no_workers_timeout( - affected, recommendations, stimulus_id - ) + self._fail_tasks_after_no_workers_timeout(affected, recommendations, stimulus_id) return {ts.key for ts in affected} def _fail_tasks_after_no_workers_timeout( @@ -8885,8 +8478,7 @@ def _fail_tasks_after_no_workers_timeout( ) recommendations.update(r) logger.error( - "Task %s marked as failed because it timed out waiting " - "without any running workers.", + "Task %s marked as failed because it timed out waiting without any running workers.", ts.key, ) @@ -8961,9 +8553,7 @@ def adaptive_target(self, target_duration: float | None = None) -> int: to_close = self.workers_to_close() return len(self.workers) - len(to_close) - def request_acquire_replicas( - self, addr: str, keys: Iterable[Key], *, stimulus_id: str - ) -> None: + def request_acquire_replicas(self, addr: str, keys: Iterable[Key], *, stimulus_id: str) -> None: """Asynchronously ask a worker to acquire a replica of the listed keys from other workers. This is a fire-and-forget operation which offers no feedback for success or failure, and is intended for housekeeping and not for computation. @@ -8985,9 +8575,7 @@ def request_acquire_replicas( }, ) - def request_remove_replicas( - self, addr: str, keys: list[Key], *, stimulus_id: str - ) -> None: + def request_remove_replicas(self, addr: str, keys: list[Key], *, stimulus_id: str) -> None: """Asynchronously ask a worker to discard its replica of the listed keys. This must never be used to destroy the last replica of a key. This is a fire-and-forget operation, intended for housekeeping and not for computation. 
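[Reviewer note] The two fire-and-forget replica hooks above are meant for housekeeping by scheduler extensions such as the ActiveMemoryManager. As a hedged illustration only (this helper is ours, not part of the patch; it assumes a live Scheduler object), an extension could drive them like so:

from distributed.metrics import time


def mirror_keys(scheduler, addr, keys):
    """Ask worker `addr` to fetch replicas of `keys` it does not yet hold."""
    stimulus_id = f"mirror-keys-{time()}"
    # Only request keys that actually exist and are in memory somewhere;
    # the RPC offers no feedback on success or failure, by design.
    wanted = [
        key
        for key in keys
        if (ts := scheduler.tasks.get(key)) is not None and ts.who_has
    ]
    if wanted:
        scheduler.request_acquire_replicas(addr, wanted, stimulus_id=stimulus_id)
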
@@ -9171,13 +8759,7 @@ def validate_task_state(ts: TaskState) -> None:
     if ts.run_spec:  # was computed
         assert ts.type
         assert isinstance(ts.type, str)
-        assert not any(
-            [
-                ts in dts.waiting_on
-                for dts in ts.dependents
-                if dts.waiting_on is not None
-            ]
-        )
+        assert not any([ts in dts.waiting_on for dts in ts.dependents if dts.waiting_on is not None])
         for ws in ts.who_has:
             assert ts in ws.has_what, (
                 "not in who_has' has_what",
@@ -9280,9 +8862,7 @@ def heartbeat_interval(n: int) -> float:
 def _task_slots_available(ws: WorkerState, saturation_factor: float) -> int:
     """Number of tasks that can be sent to this worker without oversaturating it"""
     assert not math.isinf(saturation_factor)
-    return max(math.ceil(saturation_factor * ws.nthreads), 1) - (
-        len(ws.processing) - len(ws.long_running)
-    )
+    return max(math.ceil(saturation_factor * ws.nthreads), 1) - (len(ws.processing) - len(ws.long_running))
 
 
 def _worker_full(ws: WorkerState, saturation_factor: float) -> bool:
@@ -9326,9 +8906,7 @@ def __init__(
         resource_restrictions: dict[str, float],
         timeout: float,
     ):
-        super().__init__(
-            task, host_restrictions, worker_restrictions, resource_restrictions, timeout
-        )
+        super().__init__(task, host_restrictions, worker_restrictions, resource_restrictions, timeout)
 
     @property
     def task(self) -> Key:

From 8996969bf2a786b543e69d391074cfa0129cc122 Mon Sep 17 00:00:00 2001
From: nadzhou
Date: Fri, 5 Dec 2025 18:02:06 -0800
Subject: [PATCH 2/7] Update

---
 distributed/tests/test_condition.py | 418 ++++++++++++++++++++++++++++
 1 file changed, 418 insertions(+)
 create mode 100644 distributed/tests/test_condition.py

diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py
new file mode 100644
index 0000000000..ea8e051796
--- /dev/null
+++ b/distributed/tests/test_condition.py
@@ -0,0 +1,418 @@
+import asyncio
+import pytest
+
+from distributed import Condition, Client, wait
+from distributed.utils_test import gen_cluster, inc
+from distributed.metrics import time
+
+
+@gen_cluster(client=True)
+async def test_condition_acquire_release(c, s, a, b):
+    """Test basic lock acquire/release"""
+    condition = Condition("test-lock")
+
+    assert not condition.locked()
+    await condition.acquire()
+    assert condition.locked()
+    await condition.release()
+    assert not condition.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_context_manager(c, s, a, b):
+    """Test context manager interface"""
+    condition = Condition("test-context")
+
+    assert not condition.locked()
+    async with condition:
+        assert condition.locked()
+    assert not condition.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_notify(c, s, a, b):
+    """Test basic wait/notify"""
+    condition = Condition("test-notify")
+    results = []
+
+    async def waiter():
+        async with condition:
+            results.append("waiting")
+            await condition.wait()
+            results.append("notified")
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with condition:
+            results.append("notifying")
+            condition.notify()
+
+    await asyncio.gather(waiter(), notifier())
+    assert results == ["waiting", "notifying", "notified"]
+
+
+@gen_cluster(client=True)
+async def test_condition_notify_all(c, s, a, b):
+    """Test notify_all wakes all waiters"""
+    condition = Condition("test-notify-all")
+    results = []
+
+    async def waiter(i):
+        async with condition:
+            await condition.wait()
+            results.append(i)
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with condition:
+            condition.notify_all()
+
+    await asyncio.gather(
+        waiter(1),
waiter(2), waiter(3), notifier()
+    )
+    assert sorted(results) == [1, 2, 3]
+
+
+@gen_cluster(client=True)
+async def test_condition_notify_n(c, s, a, b):
+    """Test notify with specific count"""
+    condition = Condition("test-notify-n")
+    results = []
+
+    async def waiter(i):
+        async with condition:
+            await condition.wait()
+            results.append(i)
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with condition:
+            condition.notify(n=2)  # Wake only 2 waiters
+        await asyncio.sleep(0.2)
+        async with condition:
+            condition.notify()  # Wake remaining waiter
+
+    await asyncio.gather(
+        waiter(1), waiter(2), waiter(3), notifier()
+    )
+    assert sorted(results) == [1, 2, 3]
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_timeout(c, s, a, b):
+    """Test wait with timeout"""
+    condition = Condition("test-timeout")
+
+    start = time()
+    async with condition:
+        result = await condition.wait(timeout=0.5)
+    elapsed = time() - start
+
+    assert result is False
+    assert 0.4 < elapsed < 0.7
+
+
+@gen_cluster(client=True)
+async def test_condition_wait_timeout_then_notify(c, s, a, b):
+    """Test that timeout doesn't prevent subsequent notifications"""
+    condition = Condition("test-timeout-notify")
+    results = []
+
+    async def waiter():
+        async with condition:
+            result = await condition.wait(timeout=0.2)
+            results.append(f"timeout: {result}")
+
+        async with condition:
+            result = await condition.wait()
+            results.append(f"notified: {result}")
+
+    async def notifier():
+        await asyncio.sleep(0.5)
+        async with condition:
+            condition.notify()
+
+    await asyncio.gather(waiter(), notifier())
+    assert results == ["timeout: False", "notified: True"]
+
+
+@gen_cluster(client=True)
+async def test_condition_error_without_lock(c, s, a, b):
+    """Test errors when calling wait/notify without holding lock"""
+    condition = Condition("test-error")
+
+    with pytest.raises(RuntimeError, match="Cannot wait"):
+        await condition.wait()
+
+    with pytest.raises(RuntimeError, match="Cannot notify"):
+        await condition.notify()
+
+    with pytest.raises(RuntimeError, match="Cannot notify"):
+        await condition.notify_all()
+
+
+@gen_cluster(client=True)
+async def test_condition_error_release_without_acquire(c, s, a, b):
+    """Test error when releasing without acquiring"""
+    condition = Condition("test-release-error")
+
+    with pytest.raises(RuntimeError, match="Cannot release"):
+        await condition.release()
+
+
+@gen_cluster(client=True)
+async def test_condition_producer_consumer(c, s, a, b):
+    """Test classic producer-consumer pattern"""
+    condition = Condition("prod-cons")
+    queue = []
+
+    async def producer():
+        for i in range(5):
+            await asyncio.sleep(0.1)
+            async with condition:
+                queue.append(i)
+                condition.notify()
+
+    async def consumer():
+        results = []
+        for _ in range(5):
+            async with condition:
+                while not queue:
+                    await condition.wait()
+                results.append(queue.pop(0))
+        return results
+
+    prod_task = asyncio.create_task(producer())
+    cons_task = asyncio.create_task(consumer())
+
+    await prod_task
+    results = await cons_task
+
+    assert results == [0, 1, 2, 3, 4]
+
+
+@gen_cluster(client=True)
+async def test_condition_multiple_producers_consumers(c, s, a, b):
+    """Test multiple producers and consumers"""
+    condition = Condition("multi-prod-cons")
+    queue = []
+
+    async def producer(start):
+        for i in range(start, start + 3):
+            await asyncio.sleep(0.05)
+            async with condition:
+                queue.append(i)
+                condition.notify()
+
+    async def consumer():
+        results = []
+        for _ in range(3):
+            async with condition:
+                while not
queue:
+                    await condition.wait()
+                results.append(queue.pop(0))
+        return results
+
+    results = await asyncio.gather(
+        producer(0), producer(10),
+        consumer(), consumer()
+    )
+
+    # Last two results are from consumers
+    consumed = results[2] + results[3]
+    assert sorted(consumed) == [0, 1, 2, 10, 11, 12]
+
+
+@gen_cluster(client=True)
+async def test_condition_from_worker(c, s, a, b):
+    """Test condition accessed from worker tasks"""
+    def wait_on_condition(name):
+        from distributed import Condition
+        from distributed.utils import sync
+
+        async def _wait():
+            condition = Condition(name)
+            async with condition:
+                await condition.wait()
+            return "worker_notified"
+
+        from distributed.worker import get_worker
+        worker = get_worker()
+        # This helper runs in a worker thread, so it is safe to block here
+        # while the coroutine runs on the worker's event loop.
+        return sync(worker.loop, _wait)
+
+    def notify_condition(name):
+        from distributed import Condition
+        from distributed.utils import sync
+        import asyncio
+
+        async def _notify():
+            await asyncio.sleep(0.2)
+            condition = Condition(name)
+            async with condition:
+                condition.notify()
+            return "notified"
+
+        from distributed.worker import get_worker
+        worker = get_worker()
+        return sync(worker.loop, _notify)
+
+    name = "worker-condition"
+    f1 = c.submit(wait_on_condition, name, workers=[a.address])
+    f2 = c.submit(notify_condition, name, workers=[b.address])
+
+    results = await c.gather([f1, f2])
+    assert results == ["worker_notified", "notified"]
+
+
+@gen_cluster(client=True)
+async def test_condition_same_name_different_instances(c, s, a, b):
+    """Test that multiple instances with same name share state"""
+    name = "shared-condition"
+    cond1 = Condition(name)
+    cond2 = Condition(name)
+
+    results = []
+
+    async def waiter():
+        async with cond1:
+            results.append("waiting")
+            await cond1.wait()
+            results.append("notified")
+
+    async def notifier():
+        await asyncio.sleep(0.2)
+        async with cond2:
+            results.append("notifying")
+            cond2.notify()
+
+    await asyncio.gather(waiter(), notifier())
+    assert results == ["waiting", "notifying", "notified"]
+
+
+@gen_cluster(client=True)
+async def test_condition_unique_names_independent(c, s, a, b):
+    """Test conditions with different names are independent"""
+    cond1 = Condition("cond-1")
+    cond2 = Condition("cond-2")
+
+    async with cond1:
+        assert cond1.locked()
+        assert not cond2.locked()
+
+    async with cond2:
+        assert not cond1.locked()
+        assert cond2.locked()
+
+
+@gen_cluster(client=True)
+async def test_condition_cleanup(c, s, a, b):
+    """Test that condition state is cleaned up after use"""
+    condition = Condition("cleanup-test")
+
+    # Check initial state
+    assert "cleanup-test" not in s.extensions["conditions"]._conditions
+    assert "cleanup-test" not in s.extensions["conditions"]._waiters
+
+    # Use condition
+    async with condition:
+        condition.notify()
+
+    # State should be cleaned up
+    await asyncio.sleep(0.1)
+    assert "cleanup-test" not in s.extensions["conditions"]._waiters
+
+
+@gen_cluster(client=True)
+async def test_condition_barrier_pattern(c, s, a, b):
+    """Test barrier synchronization pattern"""
+    condition = Condition("barrier")
+    arrived = []
+    n_workers = 3
+
+    async def worker(i):
+        async with condition:
+            arrived.append(i)
+            if len(arrived) < n_workers:
+                await condition.wait()
+            else:
+                condition.notify_all()
+        return f"worker-{i}-done"
+
+    results = await asyncio.gather(
+        worker(0), worker(1), worker(2)
+    )
+
+    assert sorted(results) == ["worker-0-done", "worker-1-done", "worker-2-done"]
+    assert len(arrived) == 3
+
+
+def test_condition_sync_interface(client):
+    """Test synchronous interface via SyncMethodMixin"""
+    condition
= Condition("sync-test") + results = [] + + def worker(): + with condition: + results.append("locked") + results.append("released") + + worker() + assert results == ["locked", "released"] + + +@gen_cluster(client=True) +async def test_condition_multiple_notify_calls(c, s, a, b): + """Test multiple notify calls in sequence""" + condition = Condition("multi-notify") + results = [] + + async def waiter(i): + async with condition: + await condition.wait() + results.append(i) + + async def notifier(): + await asyncio.sleep(0.2) + async with condition: + condition.notify() + await asyncio.sleep(0.1) + async with condition: + condition.notify() + await asyncio.sleep(0.1) + async with condition: + condition.notify() + + await asyncio.gather( + waiter(1), waiter(2), waiter(3), notifier() + ) + assert sorted(results) == [1, 2, 3] + + +@gen_cluster(client=True) +async def test_condition_predicate_loop(c, s, a, b): + """Test typical predicate-based wait loop pattern""" + condition = Condition("predicate") + state = {"value": 0, "target": 5} + + async def waiter(): + async with condition: + while state["value"] < state["target"]: + await condition.wait() + return state["value"] + + async def updater(): + for i in range(1, 6): + await asyncio.sleep(0.1) + async with condition: + state["value"] = i + condition.notify_all() + + result, _ = await asyncio.gather(waiter(), updater()) + assert result == 5 + + +@gen_cluster(client=True) +async def test_condition_repr(c, s, a, b): + """Test string representation""" + condition = Condition("test-repr") + assert "test-repr" in repr(condition) + assert "Condition" in repr(condition) From d505ca0f434fea7976dca67f5ec7367062ad1d7d Mon Sep 17 00:00:00 2001 From: nadzhou Date: Fri, 5 Dec 2025 18:10:36 -0800 Subject: [PATCH 3/7] Update scheduler.py --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e9b4bec324..590a84a924 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -193,8 +193,8 @@ "variables": VariableExtension, "semaphores": SemaphoreExtension, "events": EventExtension, - "conditions": ConditionExtension, "amm": ActiveMemoryManagerExtension, + "conditions": ConditionExtension, "memory_sampler": MemorySamplerExtension, "shuffle": ShuffleSchedulerPlugin, "spans": SpansSchedulerExtension, From 47a790f819ce48d071472e0688b67ac9234ad494 Mon Sep 17 00:00:00 2001 From: nadzhou Date: Fri, 5 Dec 2025 18:12:41 -0800 Subject: [PATCH 4/7] Update scheduler.py --- distributed/scheduler.py | 824 +++++++++++++++++++++++++++++---------- 1 file changed, 624 insertions(+), 200 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 590a84a924..f2ad5080a8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -182,7 +182,9 @@ logger = logging.getLogger(__name__) LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") -DEFAULT_DATA_SIZE = parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) +DEFAULT_DATA_SIZE = parse_bytes( + dask.config.get("distributed.scheduler.default-data-size") +) STIMULUS_ID_UNSET = "" DEFAULT_EXTENSIONS = { @@ -407,7 +409,8 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: return { k: getattr(self, k) for k in dir(self) - if not k.startswith("_") and k not in {"sum", "managed_in_memory", "managed_spilled"} + if not k.startswith("_") + and k not in {"sum", "managed_in_memory", "managed_spilled"} } @@ -579,7 +582,9 @@ def __hash__(self) -> int: return self._hash 
def __eq__(self, other: object) -> bool: - return self is other or (isinstance(other, WorkerState) and other.server_id == self.server_id) + return self is other or ( + isinstance(other, WorkerState) and other.server_id == self.server_id + ) @property def has_what(self) -> Set[TaskState]: @@ -830,7 +835,9 @@ def _dec_needs_replica(self, ts: TaskState) -> None: nbytes = ts.get_nbytes() # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift self._network_occ -= min(nbytes, self._network_occ) - self.scheduler._network_occ_global -= min(nbytes, self.scheduler._network_occ_global) + self.scheduler._network_occ_global -= min( + nbytes, self.scheduler._network_occ_global + ) def add_replica(self, ts: TaskState) -> None: """The worker acquired a replica of task""" @@ -843,14 +850,18 @@ def add_replica(self, ts: TaskState) -> None: del self.needs_what[ts] # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift self._network_occ -= min(nbytes, self._network_occ) - self.scheduler._network_occ_global -= min(nbytes, self.scheduler._network_occ_global) + self.scheduler._network_occ_global -= min( + nbytes, self.scheduler._network_occ_global + ) ts.who_has.add(self) self.nbytes += nbytes self._has_what[ts] = None @property def occupancy(self) -> float: - return self._occupancy_cache or self.scheduler._calc_occupancy(self.task_prefix_count, self._network_occ) + return self._occupancy_cache or self.scheduler._calc_occupancy( + self.task_prefix_count, self._network_occ + ) @dataclasses.dataclass @@ -912,7 +923,9 @@ def __repr__(self) -> str: return ( f"" ) @@ -970,7 +983,10 @@ def all_durations(self) -> defaultdict[str, float]: """Cumulative duration of all completed actions of tasks belonging to this collection, by action""" return defaultdict( float, - {action: duration_us / 1e6 for action, duration_us in self._all_durations_us.items()}, + { + action: duration_us / 1e6 + for action, duration_us in self._all_durations_us.items() + }, ) @property @@ -1075,7 +1091,13 @@ def active_states(self) -> dict[TaskStateState, int]: def __repr__(self) -> str: return ( - "<" + self.name + ": " + ", ".join("%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v) + ">" + "<" + + self.name + + ": " + + ", ".join( + "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v + ) + + ">" ) @@ -1165,7 +1187,9 @@ def __repr__(self) -> str: "<" + (self.name or "no-group") + ": " - + ", ".join("%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v) + + ", ".join( + "%s: %d" % (k, v) for (k, v) in sorted(self.states.items()) if v + ) + ">" ) @@ -1193,7 +1217,8 @@ def done(self) -> bool: recomputed. 
""" return all( - count == 0 or state in {"memory", "erred", "released", "forgotten"} for state, count in self.states.items() + count == 0 or state in {"memory", "erred", "released", "forgotten"} + for state, count in self.states.items() ) @@ -1752,9 +1777,15 @@ def __init__( self.resources = resources self.saturated = set() self.tasks = tasks - self.replicated_tasks = {ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1} - self.computations = deque(maxlen=dask.config.get("distributed.diagnostics.computations.max-history")) - self.erred_tasks = deque(maxlen=dask.config.get("distributed.diagnostics.erred-tasks.max-history")) + self.replicated_tasks = { + ts for ts in self.tasks.values() if len(ts.who_has or ()) > 1 + } + self.computations = deque( + maxlen=dask.config.get("distributed.diagnostics.computations.max-history") + ) + self.erred_tasks = deque( + maxlen=dask.config.get("distributed.diagnostics.erred-tasks.max-history") + ) self.task_groups = {} self.task_prefixes = {} self.task_metadata = {} @@ -1767,38 +1798,61 @@ def __init__( self.workers = workers self._task_prefix_count_global = defaultdict(int) self._network_occ_global = 0 - self.running = {ws for ws in self.workers.values() if ws.status == Status.running} + self.running = { + ws for ws in self.workers.values() if ws.status == Status.running + } self.plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} - self.transition_log = deque(maxlen=dask.config.get("distributed.admin.low-level-log-length")) + self.transition_log = deque( + maxlen=dask.config.get("distributed.admin.low-level-log-length") + ) self.transition_counter = 0 self._idle_transition_counter = 0 self.transition_counter_max = transition_counter_max # Variables from dask.config, cached by __init__ for performance - self.UNKNOWN_TASK_DURATION = parse_timedelta(dask.config.get("distributed.scheduler.unknown-task-duration")) + self.UNKNOWN_TASK_DURATION = parse_timedelta( + dask.config.get("distributed.scheduler.unknown-task-duration") + ) self.MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( dask.config.get("distributed.worker.memory.recent-to-old-time") ) - self.MEMORY_REBALANCE_MEASURE = dask.config.get("distributed.worker.memory.rebalance.measure") - self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get("distributed.worker.memory.rebalance.sender-min") - self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get("distributed.worker.memory.rebalance.recipient-max") + self.MEMORY_REBALANCE_MEASURE = dask.config.get( + "distributed.worker.memory.rebalance.measure" + ) + self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get( + "distributed.worker.memory.rebalance.sender-min" + ) + self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( + "distributed.worker.memory.rebalance.recipient-max" + ) self.MEMORY_REBALANCE_HALF_GAP = ( - dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") / 2.0 + dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") + / 2.0 ) - self.WORKER_SATURATION = dask.config.get("distributed.scheduler.worker-saturation") + self.WORKER_SATURATION = dask.config.get( + "distributed.scheduler.worker-saturation" + ) if self.WORKER_SATURATION == "inf": # Special case necessary because there's no way to parse a float infinity # from a DASK_* environment variable self.WORKER_SATURATION = math.inf - if not isinstance(self.WORKER_SATURATION, (int, float)) or self.WORKER_SATURATION <= 0: + if ( + not isinstance(self.WORKER_SATURATION, (int, float)) + or self.WORKER_SATURATION <= 0 + ): raise ValueError( 
# pragma: nocover - "`distributed.scheduler.worker-saturation` must be a float > 0; got " + repr(self.WORKER_SATURATION) + "`distributed.scheduler.worker-saturation` must be a float > 0; got " + + repr(self.WORKER_SATURATION) ) - self.rootish_tg_threshold = dask.config.get("distributed.scheduler.rootish-taskgroup") - self.rootish_tg_dependencies_threshold = dask.config.get("distributed.scheduler.rootish-taskgroup-dependencies") + self.rootish_tg_threshold = dask.config.get( + "distributed.scheduler.rootish-taskgroup" + ) + self.rootish_tg_dependencies_threshold = dask.config.get( + "distributed.scheduler.rootish-taskgroup-dependencies" + ) @abstractmethod def log_event(self, topic: str | Collection[str], msg: Any) -> None: ... @@ -1932,7 +1986,9 @@ def _calc_occupancy( # State Transitions # ##################### - def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwargs: Any) -> RecsMsgs: + def _transition( + self, key: Key, finish: TaskStateState, stimulus_id: str, **kwargs: Any + ) -> RecsMsgs: """Transition a key from its current state to the finish state Examples @@ -1978,11 +2034,15 @@ def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwar func = self._TRANSITIONS_TABLE.get((start, finish)) if func is not None: - recommendations, client_msgs, worker_msgs = func(self, key, stimulus_id, **kwargs) + recommendations, client_msgs, worker_msgs = func( + self, key, stimulus_id, **kwargs + ) elif "released" not in (start, finish): assert not kwargs, (kwargs, start, finish) - a_recs, a_cmsgs, a_wmsgs = self._transition(key, "released", stimulus_id) + a_recs, a_cmsgs, a_wmsgs = self._transition( + key, "released", stimulus_id + ) v = a_recs.get(key, finish) # The inner rec has higher priority? Is that always desired? @@ -2012,10 +2072,16 @@ def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwar stimulus_id = STIMULUS_ID_UNSET actual_finish = ts._state - self.transition_log.append(Transition(key, start, actual_finish, recommendations, stimulus_id, time())) + self.transition_log.append( + Transition( + key, start, actual_finish, recommendations, stimulus_id, time() + ) + ) if self.validate: if stimulus_id == STIMULUS_ID_UNSET: - raise RuntimeError("stimulus_id not set during Scheduler transition") + raise RuntimeError( + "stimulus_id not set during Scheduler transition" + ) logger.debug( "Transitioned %r %s->%s (actual: %s). Consequence: %s", key, @@ -2032,7 +2098,9 @@ def _transition(self, key: Key, finish: TaskStateState, stimulus_id: str, **kwar self.tasks[ts.key] = ts for plugin in list(self.plugins.values()): try: - plugin.transition(key, start, actual_finish, stimulus_id=stimulus_id, **kwargs) + plugin.transition( + key, start, actual_finish, stimulus_id=stimulus_id, **kwargs + ) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts.state == "forgotten": @@ -2216,7 +2284,9 @@ def _transition_queued_erred( traceback_text=traceback_text, ) - def decide_worker_rootish_queuing_disabled(self, ts: TaskState) -> WorkerState | None: + def decide_worker_rootish_queuing_disabled( + self, ts: TaskState + ) -> WorkerState | None: """Pick a worker for a runnable root-ish task, without queuing. 
This attempts to schedule sibling tasks on the same worker, reducing future data @@ -2247,16 +2317,25 @@ def decide_worker_rootish_queuing_disabled(self, ts: TaskState) -> WorkerState | tg = ts.group lws = tg.last_worker - if lws and tg.last_worker_tasks_left and lws.status == Status.running and self.workers.get(lws.address) is lws: + if ( + lws + and tg.last_worker_tasks_left + and lws.status == Status.running + and self.workers.get(lws.address) is lws + ): ws = lws else: # Last-used worker is full, unknown, retiring, or paused; # pick a new worker for the next few tasks ws = min(pool, key=partial(self.worker_objective, ts)) - tg.last_worker_tasks_left = math.floor((len(tg) / self.total_nthreads) * ws.nthreads) + tg.last_worker_tasks_left = math.floor( + (len(tg) / self.total_nthreads) * ws.nthreads + ) # Record `last_worker`, or clear it on the final task - tg.last_worker = ws if tg.states["released"] + tg.states["waiting"] > 1 else None + tg.last_worker = ( + ws if tg.states["released"] + tg.states["waiting"] > 1 else None + ) tg.last_worker_tasks_left -= 1 if self.validate and ws is not None: @@ -2497,7 +2576,9 @@ def _transition_processing_memory( recommendations: Recs = {} client_msgs: Msgs = {} - self._add_to_memory(ts, ws, recommendations, client_msgs, type=type, typename=typename) + self._add_to_memory( + ts, ws, recommendations, client_msgs, type=type, typename=typename + ) if self.validate: assert not ts.processing_on @@ -2517,7 +2598,9 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: ws.actors.discard(ts) if ts.who_wants: ts.exception_blame = ts - ts.exception = Serialized(*serialize(RuntimeError("Worker holding Actor was lost"))) + ts.exception = Serialized( + *serialize(RuntimeError("Worker holding Actor was lost")) + ) return {ts.key: "erred"}, {}, {} # don't try to recreate recommendations: Recs = {} @@ -2544,7 +2627,9 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: recommendations[key] = "forgotten" elif ts.has_lost_dependencies: recommendations[key] = "forgotten" - elif (ts.who_wants or ts.waiters) and not any(dts.state == "erred" for dts in ts.dependencies): + elif (ts.who_wants or ts.waiters) and not any( + dts.state == "erred" for dts in ts.dependencies + ): recommendations[key] = "waiting" for dts in ts.waiters or (): @@ -2924,7 +3009,9 @@ def _transition_memory_erred(self, key: Key, stimulus_id: str) -> RecsMsgs: if not dts.who_has: dts.exception_blame = ts recommendations[dts.key] = "erred" - exception = Serialized(*serialize(RuntimeError("Worker holding Actor was lost"))) + exception = Serialized( + *serialize(RuntimeError("Worker holding Actor was lost")) + ) report_msg = { "op": "task-erred", "key": key, @@ -3037,9 +3124,14 @@ def _transition_released_forgotten(self, key: Key, stimulus_id: str) -> RecsMsgs ("released", "erred"): _transition_released_erred, } - def story(self, *keys_or_tasks_or_stimuli: Key | TaskState | str) -> list[Transition]: + def story( + self, *keys_or_tasks_or_stimuli: Key | TaskState | str + ) -> list[Transition]: """Get all transitions that touch one of the input keys or stimulus_id's""" - keys_or_stimuli = {key.key if isinstance(key, TaskState) else key for key in keys_or_tasks_or_stimuli} + keys_or_stimuli = { + key.key if isinstance(key, TaskState) else key + for key in keys_or_tasks_or_stimuli + } return scheduler_story(keys_or_stimuli, self.transition_log) ############################## @@ -3115,9 +3207,14 @@ def check_idle_saturated(self, ws: WorkerState, occ: float 
= -1.0) -> None: else: self.idle_task_count.discard(ws) - def is_unoccupied(self, ws: WorkerState, occupancy: float, nprocessing: int) -> bool: + def is_unoccupied( + self, ws: WorkerState, occupancy: float, nprocessing: int + ) -> bool: nthreads = ws.nthreads - return nprocessing < nthreads or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 + return ( + nprocessing < nthreads + or occupancy < nthreads * (self.total_occupancy / self.total_nthreads) / 2 + ) def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: """ @@ -3293,7 +3390,9 @@ def _validate_ready(self, ts: TaskState) -> None: assert ts not in self.queued assert all(dts.who_has for dts in ts.dependencies) - def _add_to_processing(self, ts: TaskState, ws: WorkerState, stimulus_id: str) -> RecsMsgs: + def _add_to_processing( + self, ts: TaskState, ws: WorkerState, stimulus_id: str + ) -> RecsMsgs: """Set a task as processing on a worker and return the worker messages to send""" if self.validate: self._validate_ready(ts) @@ -3311,7 +3410,11 @@ def _add_to_processing(self, ts: TaskState, ws: WorkerState, stimulus_id: str) - ws.actors.add(ts) ndep_bytes = sum(dts.nbytes for dts in ts.dependencies) - if ws.memory_limit and ndep_bytes > ws.memory_limit and dask.config.get("distributed.worker.memory.terminate"): + if ( + ws.memory_limit + and ndep_bytes > ws.memory_limit + and dask.config.get("distributed.worker.memory.terminate") + ): # Note # ---- # This is a crude safety system, only meant to prevent order-of-magnitude @@ -3519,7 +3622,10 @@ def _task_to_msg(self, ts: TaskState) -> dict[str, Any]: "run_id": ts.run_id, "priority": ts.priority, "stimulus_id": f"compute-task-{time()}", - "who_has": {dts.key: tuple(ws.address for ws in (dts.who_has or ())) for dts in ts.dependencies}, + "who_has": { + dts.key: tuple(ws.address for ws in (dts.who_has or ())) + for dts in ts.dependencies + }, "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies}, "run_spec": ToPickle(ts.run_spec), "resource_restrictions": ts.resource_restrictions, @@ -3691,10 +3797,16 @@ def __init__( self.services = {} self.scheduler_file = scheduler_file - self.worker_ttl = parse_timedelta(worker_ttl or dask.config.get("distributed.scheduler.worker-ttl")) - self.idle_timeout = parse_timedelta(idle_timeout or dask.config.get("distributed.scheduler.idle-timeout")) + self.worker_ttl = parse_timedelta( + worker_ttl or dask.config.get("distributed.scheduler.worker-ttl") + ) + self.idle_timeout = parse_timedelta( + idle_timeout or dask.config.get("distributed.scheduler.idle-timeout") + ) self.idle_since = time() - self.no_workers_timeout = parse_timedelta(dask.config.get("distributed.scheduler.no-workers-timeout")) + self.no_workers_timeout = parse_timedelta( + dask.config.get("distributed.scheduler.no-workers-timeout") + ) self._no_workers_since = None self.time_started = self.idle_since # compatibility for dask-gateway @@ -3742,17 +3854,24 @@ def __init__( except ImportError: show_dashboard = False http_server_modules.append("distributed.http.scheduler.missing_bokeh") - routes = get_handlers(server=self, modules=http_server_modules, prefix=http_prefix) + routes = get_handlers( + server=self, modules=http_server_modules, prefix=http_prefix + ) self.start_http_server(routes, dashboard_address, default_port=8787) self.jupyter = jupyter if show_dashboard: - distributed.dashboard.scheduler.connect(self.http_application, self.http_server, self, prefix=http_prefix) + distributed.dashboard.scheduler.connect( + self.http_application, 
self.http_server, self, prefix=http_prefix + ) scheduler = self if self.jupyter: try: from jupyter_server.serverapp import ServerApp except ImportError: - raise ImportError("In order to use the Dask jupyter option you need to have jupyterlab installed") + raise ImportError( + "In order to use the Dask jupyter option you " + "need to have jupyterlab installed" + ) from traitlets.config import Config """HTTP handler to shut down the Jupyter server. @@ -3800,7 +3919,9 @@ async def post(self) -> None: argv=[], ) self._jupyter_server_application = j - shutdown_app = tornado.web.Application([(r"/jupyter/api/shutdown", ShutdownHandler)]) + shutdown_app = tornado.web.Application( + [(r"/jupyter/api/shutdown", ShutdownHandler)] + ) shutdown_app.settings = j.web_app.settings self.http_application.add_application(shutdown_app) self.http_application.add_application(j.web_app) @@ -4020,7 +4141,8 @@ def identity(self, n_workers: int = -1) -> dict[str, Any]: "total_threads": self.total_nthreads, "total_memory": self.total_memory, "workers": { - worker.address: worker.identity() for worker in itertools.islice(self.workers.values(), n_workers) + worker.address: worker.identity() + for worker in itertools.islice(self.workers.values(), n_workers) }, } return d @@ -4078,7 +4200,10 @@ async def get_cluster_state( workers_future.cancel() # Convert any RPC errors to strings - worker_states = {k: repr(v) if isinstance(v, Exception) else v for k, v in worker_states.items()} + worker_states = { + k: repr(v) if isinstance(v, Exception) else v + for k, v in worker_states.items() + } return { "scheduler": scheduler_state, @@ -4094,7 +4219,9 @@ async def dump_cluster_state_to_url( **storage_options: dict[str, Any], ) -> None: "Write a cluster state dump to an fsspec-compatible URL." - await cluster_dump.write_state(partial(self.get_cluster_state, exclude), url, format, **storage_options) + await cluster_dump.write_state( + partial(self.get_cluster_state, exclude), url, format, **storage_options + ) def get_worker_service_addr( self, worker: str, service_name: str, protocol: bool = False @@ -4162,7 +4289,9 @@ async def start_unsafe(self) -> Self: # formatting dashboard link can fail if distributed.dashboard.link # refers to non-existent env vars. 
except KeyError as e: - logger.warning(f"Failed to format dashboard link, unknown value: {e}") + logger.warning( + f"Failed to format dashboard link, unknown value: {e}" + ) link = f":{server.port}" else: link = f"{listen_ip}:{server.port}" @@ -4188,7 +4317,9 @@ def del_scheduler_file() -> None: await self.listen("tcp://localhost:0") os.environ["DASK_SCHEDULER_ADDRESS"] = self.listeners[-1].contact_address - await asyncio.gather(*[plugin.start(self) for plugin in list(self.plugins.values())]) + await asyncio.gather( + *[plugin.start(self) for plugin in list(self.plugins.values())] + ) self.start_periodic_callbacks() @@ -4220,11 +4351,15 @@ async def log_errors(func: Callable) -> None: except Exception: logger.exception("Plugin call failed during scheduler.close") - await asyncio.gather(*[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]) + await asyncio.gather( + *[log_errors(plugin.before_close) for plugin in list(self.plugins.values())] + ) await self.preloads.teardown() - await asyncio.gather(*[log_errors(plugin.close) for plugin in list(self.plugins.values())]) + await asyncio.gather( + *[log_errors(plugin.close) for plugin in list(self.plugins.values())] + ) for pc in self.periodic_callbacks.values(): pc.stop() @@ -4297,21 +4432,25 @@ def heartbeat_worker( dh["last-seen"] = local_now frac = 1 / len(self.workers) - self.bandwidth = self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + self.bandwidth = ( + self.bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + ) for other, (bw, count) in metrics["bandwidth"]["workers"].items(): if (address, other) not in self.bandwidth_workers: self.bandwidth_workers[address, other] = bw / count else: alpha = (1 - frac) ** count - self.bandwidth_workers[address, other] = self.bandwidth_workers[address, other] * alpha + bw * ( - 1 - alpha - ) + self.bandwidth_workers[address, other] = self.bandwidth_workers[ + address, other + ] * alpha + bw * (1 - alpha) for typ, (bw, count) in metrics["bandwidth"]["types"].items(): if typ not in self.bandwidth_types: self.bandwidth_types[typ] = bw / count else: alpha = (1 - frac) ** count - self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * (1 - alpha) + self.bandwidth_types[typ] = self.bandwidth_types[typ] * alpha + bw * ( + 1 - alpha + ) ws.last_seen = local_now if executing is not None: @@ -4342,7 +4481,9 @@ def heartbeat_worker( # ws._nbytes is updated at a different time and sizeof() may not be accurate, # so size may be (temporarily) negative; floor it to zero. - size = max(0, metrics["memory"] - ws.nbytes + metrics["spilled_bytes"]["memory"]) + size = max( + 0, metrics["memory"] - ws.nbytes + metrics["spilled_bytes"]["memory"] + ) ws._memory_unmanaged_history.append((local_now, size)) if not memory_unmanaged_old: @@ -4487,7 +4628,9 @@ async def add_worker( logger.exception(exc, exc_info=exc) if ws.status == Status.running: - self.transitions(self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id) + self.transitions( + self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id + ) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) logger.info("Register worker addr: %s name: %s", ws.address, ws.name) @@ -4671,7 +4814,9 @@ def _create_taskstate_from_graph( # _generate_taskstates is not the only thing that calls new_task(). A # TaskState may have also been created by client_desires_keys or scatter, # and only later gained a run_spec. 
- span_annotations = spans_ext.observe_tasks(touched_tasks, span_metadata=span_metadata, code=code) + span_annotations = spans_ext.observe_tasks( + touched_tasks, span_metadata=span_metadata, code=code + ) # In case of TaskGroup collision, spans may have changed # FIXME: Is this used anywhere besides tests? if span_annotations: @@ -4777,7 +4922,9 @@ async def update_graph( }, client=client, ) - self.client_releases_keys(keys=keys, client=client, stimulus_id=stimulus_id) + self.client_releases_keys( + keys=keys, client=client, stimulus_id=stimulus_id + ) evt_msg = { "action": "update-graph", "stimulus_id": stimulus_id, @@ -4810,7 +4957,8 @@ async def update_graph( "start_timestamp_seconds": start, "materialization_duration_seconds": materialization_done - start, "ordering_duration_seconds": materialization_done - ordering_done, - "state_initialization_duration_seconds": ordering_done - task_state_created, + "state_initialization_duration_seconds": ordering_done + - task_state_created, "duration_seconds": task_state_created - start, } ) @@ -5063,7 +5211,9 @@ def _set_priorities( ) if self.validate and istask(ts.run_spec): - assert isinstance(ts.priority, tuple) and all(isinstance(el, (int, float)) for el in ts.priority) + assert isinstance(ts.priority, tuple) and all( + isinstance(el, (int, float)) for el in ts.priority + ) def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """Respond to an event which may have opened spots on worker threadpools @@ -5082,7 +5232,10 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """ if not self.queued: return - slots_available = sum(_task_slots_available(ws, self.WORKER_SATURATION) for ws in self.idle_task_count) + slots_available = sum( + _task_slots_available(ws, self.WORKER_SATURATION) + for ws in self.idle_task_count + ) if slots_available == 0: return @@ -5104,7 +5257,9 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: assert qts.state == "processing" assert not self.queued or self.queued.peek() != qts - def stimulus_task_finished(self, key: Key, worker: str, stimulus_id: str, run_id: int, **kwargs: Any) -> RecsMsgs: + def stimulus_task_finished( + self, key: Key, worker: str, stimulus_id: str, run_id: int, **kwargs: Any + ) -> RecsMsgs: """Mark that a task has finished execution on a particular worker""" logger.debug("Stimulus task finished %s[%d] %s", key, run_id, worker) @@ -5115,7 +5270,8 @@ def stimulus_task_finished(self, key: Key, worker: str, stimulus_id: str, run_id ts = self.tasks.get(key) if ts is None or ts.state in ("released", "queued", "no-worker"): logger.debug( - "Received already computed task, worker: %s, state: %s, key: %s, who_has: %s", + "Received already computed task, worker: %s, state: %s" + ", key: %s, who_has: %s", worker, ts.state if ts else "forgotten", key, @@ -5130,7 +5286,7 @@ def stimulus_task_finished(self, key: Key, worker: str, stimulus_id: str, run_id ] elif ts.state == "erred": logger.debug( - "Received already erred task, worker: %s, key: %s", + "Received already erred task, worker: %s" ", key: %s", worker, key, ) @@ -5207,7 +5363,9 @@ def stimulus_task_erred( **kwargs, ) - def stimulus_retry(self, keys: Collection[Key], client: str | None = None) -> tuple[Key, ...]: + def stimulus_retry( + self, keys: Collection[Key], client: str | None = None + ) -> tuple[Key, ...]: logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -5285,10 +5443,14 @@ 
async def remove_worker( ws = self.workers[address] - logger.info(f"Remove worker addr: {ws.address} name: {ws.name} ({stimulus_id=})") + logger.info( + f"Remove worker addr: {ws.address} name: {ws.name} ({stimulus_id=})" + ) if close: with suppress(AttributeError, CommClosedError): - self.stream_comms[address].send({"op": "close", "reason": "scheduler-remove-worker"}) + self.stream_comms[address].send( + {"op": "close", "reason": "scheduler-remove-worker"} + ) self.remove_resources(address) @@ -5342,7 +5504,8 @@ async def remove_worker( ) recommendations.update(r) logger.error( - "Task %s marked as failed because %d workers died while trying to run it", + "Task %s marked as failed because %d workers died" + " while trying to run it", ts.key, ts.suspicious, ) @@ -5395,7 +5558,9 @@ async def remove_worker( for plugin in list(self.plugins.values()): try: try: - result = plugin.remove_worker(scheduler=self, worker=address, stimulus_id=stimulus_id) + result = plugin.remove_worker( + scheduler=self, worker=address, stimulus_id=stimulus_id + ) except TypeError: parameters = inspect.signature(plugin.remove_worker).parameters if "stimulus_id" not in parameters and not any( @@ -5427,9 +5592,13 @@ async def remove_worker_from_events() -> None: if address not in self.workers: self._broker.truncate(address) - cleanup_delay = parse_timedelta(dask.config.get("distributed.scheduler.events-cleanup-delay")) + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") + ) - self._ongoing_background_tasks.call_later(cleanup_delay, remove_worker_from_events) + self._ongoing_background_tasks.call_later( + cleanup_delay, remove_worker_from_events + ) logger.debug("Removed worker %s", ws) for w in self.workers: @@ -5444,7 +5613,9 @@ async def remove_worker_from_events() -> None: return "OK" - def stimulus_cancel(self, keys: Collection[Key], client: str, force: bool, reason: str, msg: str) -> None: + def stimulus_cancel( + self, keys: Collection[Key], client: str, force: bool, reason: str, msg: str + ) -> None: """Stop execution on a list of keys""" logger.info("Client %s requests to cancel %d keys", client, len(keys)) self.log_event(client, {"action": "cancel", "count": len(keys), "force": force}) @@ -5506,7 +5677,9 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None: if ts.state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) - def client_releases_keys(self, keys: Collection[Key], client: str, stimulus_id: str | None = None) -> None: + def client_releases_keys( + self, keys: Collection[Key], client: str, stimulus_id: str | None = None + ) -> None: """Remove keys from client desired list""" stimulus_id = stimulus_id or f"client-releases-keys-{time()}" if not isinstance(keys, list): @@ -5565,7 +5738,9 @@ def validate_queued(self, key: Key) -> None: assert not ts.waiting_on assert not ts.who_has assert not ts.processing_on - assert not (ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions) + assert not ( + ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions + ) for dts in ts.dependencies: assert dts.who_has assert ts in (dts.waiters or ()) @@ -5591,7 +5766,9 @@ def validate_memory(self, key: Key) -> None: assert ts not in self.unrunnable assert ts not in self.queued for dts in ts.dependents: - assert (dts in (ts.waiters or ())) == (dts.state in ("waiting", "queued", "processing", "no-worker")) + assert (dts in (ts.waiters or ())) == ( + dts.state in ("waiting", "queued", "processing", 
"no-worker") + ) assert ts not in (dts.waiting_on or ()) def validate_no_worker(self, key: Key) -> None: @@ -5622,7 +5799,9 @@ def validate_key(self, key: Key, ts: TaskState | None = None) -> None: try: func = getattr(self, "validate_" + ts.state.replace("-", "_")) except AttributeError: - logger.error("self.validate_%s not found", ts.state.replace("-", "_")) + logger.error( + "self.validate_%s not found", ts.state.replace("-", "_") + ) else: func(key) except Exception as e: @@ -5688,9 +5867,9 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert task_prefix_counts.keys() == self._task_prefix_count_global.keys() for name, global_count in self._task_prefix_count_global.items(): - assert task_prefix_counts[name] == global_count, ( - f"{name}: {task_prefix_counts[name]} (wss), {global_count} (global)" - ) + assert ( + task_prefix_counts[name] == global_count + ), f"{name}: {task_prefix_counts[name]} (wss), {global_count} (global)" for ws in self.running: assert ws.status == Status.running @@ -5713,7 +5892,10 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert cs.client_key == c a = {w: ws.nbytes for w, ws in self.workers.items()} - b = {w: sum(ts.get_nbytes() for ts in ws.has_what) for w, ws in self.workers.items()} + b = { + w: sum(ts.get_nbytes() for ts in ws.has_what) + for w, ws in self.workers.items() + } assert a == b, (a, b) if self.transition_counter_max: @@ -5723,7 +5905,9 @@ def validate_state(self, allow_overlap: bool = False) -> None: # Manage Messages # ################### - def report(self, msg: dict, ts: TaskState | None = None, client: str | None = None) -> None: + def report( + self, msg: dict, ts: TaskState | None = None, client: str | None = None + ) -> None: """ Publish updates to all listening Queues and Comms @@ -5745,7 +5929,9 @@ def report(self, msg: dict, ts: TaskState | None = None, client: str | None = No # Notify clients interested in key (including `client`) # Note that, if report() was called by update_graph(), `client` won't be in # ts.who_wants yet. - client_keys = [cs.client_key for cs in ts.who_wants or () if cs.client_key != client] + client_keys = [ + cs.client_key for cs in ts.who_wants or () if cs.client_key != client + ] if client is not None: client_keys.append(client) @@ -5758,9 +5944,13 @@ def report(self, msg: dict, ts: TaskState | None = None, client: str | None = No # logger.debug("Scheduler sends message to client %s: %s", k, msg) except CommClosedError: if self.status == Status.running: - logger.critical("Closed comm %r while trying to write %s", c, msg, exc_info=True) + logger.critical( + "Closed comm %r while trying to write %s", c, msg, exc_info=True + ) - async def add_client(self, comm: Comm, client: str, versions: dict[str, Any]) -> None: + async def add_client( + self, comm: Comm, client: str, versions: dict[str, Any] + ) -> None: """Add client to network We listen to all future messages from this Comm. 
@@ -5838,9 +6028,13 @@ async def remove_client_from_events() -> None: if client not in self.clients: self._broker.truncate(client) - cleanup_delay = parse_timedelta(dask.config.get("distributed.scheduler.events-cleanup-delay")) + cleanup_delay = parse_timedelta( + dask.config.get("distributed.scheduler.events-cleanup-delay") + ) if not self._ongoing_background_tasks.closed: - self._ongoing_background_tasks.call_later(cleanup_delay, remove_client_from_events) + self._ongoing_background_tasks.call_later( + cleanup_delay, remove_client_from_events + ) def send_task_to_worker(self, worker: str, ts: TaskState) -> None: """Send a single computational task to a worker""" @@ -5858,13 +6052,17 @@ def send_task_to_worker(self, worker: str, ts: TaskState) -> None: def handle_uncaught_error(self, **msg: Any) -> None: logger.exception(clean_exception(**msg)[1]) - def handle_task_finished(self, key: Key, worker: str, stimulus_id: str, **msg: Any) -> None: + def handle_task_finished( + self, key: Key, worker: str, stimulus_id: str, **msg: Any + ) -> None: if worker not in self.workers: return if self.validate: self.validate_key(key) - r: tuple = self.stimulus_task_finished(key=key, worker=worker, stimulus_id=stimulus_id, **msg) + r: tuple = self.stimulus_task_finished( + key=key, worker=worker, stimulus_id=stimulus_id, **msg + ) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) @@ -5903,7 +6101,9 @@ def handle_long_running( duration accounting as if the task has stopped. """ if worker not in self.workers: - logger.debug("Received long-running signal from unknown worker %s. Ignoring.", worker) + logger.debug( + "Received long-running signal from unknown worker %s. Ignoring.", worker + ) return if key not in self.tasks: @@ -5941,7 +6141,9 @@ def handle_long_running( self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) - def handle_worker_status_change(self, status: str | Status, worker: str | WorkerState, stimulus_id: str) -> None: + def handle_worker_status_change( + self, status: str | Status, worker: str | WorkerState, stimulus_id: str + ) -> None: ws = self.workers.get(worker) if isinstance(worker, str) else worker if not ws: return @@ -5964,7 +6166,9 @@ def handle_worker_status_change(self, status: str | Status, worker: str | Worker if ws.status == Status.running: self.running.add(ws) self.check_idle_saturated(ws) - self.transitions(self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id) + self.transitions( + self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id + ) self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) else: self.running.discard(ws) @@ -5973,7 +6177,9 @@ def handle_worker_status_change(self, status: str | Status, worker: str | Worker self.saturated.discard(ws) self._refresh_no_workers_since() - def handle_request_refresh_who_has(self, keys: Iterable[Key], worker: str, stimulus_id: str) -> None: + def handle_request_refresh_who_has( + self, keys: Iterable[Key], worker: str, stimulus_id: str + ) -> None: """Request from a Worker to refresh the who_has for some keys. Not to be confused with scheduler.who_has, which is a dedicated comm RPC request from a Client. 
@@ -6022,7 +6228,9 @@ async def handle_worker(self, comm: Comm, worker: str) -> None: finally: if worker in self.stream_comms: worker_comm.abort() - await self.remove_worker(worker, stimulus_id=f"handle-worker-cleanup-{time()}") + await self.remove_worker( + worker, stimulus_id=f"handle-worker-cleanup-{time()}" + ) def add_plugin( self, @@ -6083,7 +6291,9 @@ def remove_plugin(self, name: str | None = None) -> None: try: del self.plugins[name] except KeyError: - raise ValueError(f"Could not find plugin {name!r} among the current scheduler plugins") + raise ValueError( + f"Could not find plugin {name!r} among the current scheduler plugins" + ) async def register_scheduler_plugin( self, @@ -6145,7 +6355,9 @@ def client_send(self, client: str, msg: dict) -> None: c.send(msg) except CommClosedError: if self.status == Status.running: - logger.critical("Closed comm %r while trying to write %s", c, msg, exc_info=True) + logger.critical( + "Closed comm %r while trying to write %s", c, msg, exc_info=True + ) def send_all(self, client_msgs: Msgs, worker_msgs: Msgs) -> None: """Send messages to client and workers""" @@ -6223,10 +6435,14 @@ async def scatter( n = len(workers) if broadcast is True else broadcast await self.replicate(keys=keys, workers=workers, n=n) - self.log_event([client, "all"], {"action": "scatter", "client": client, "count": len(data)}) + self.log_event( + [client, "all"], {"action": "scatter", "client": client, "count": len(data)} + ) return keys - async def gather(self, keys: Collection[Key], serializers: list[str] | None = None) -> dict[Key, object]: + async def gather( + self, keys: Collection[Key], serializers: list[str] | None = None + ) -> dict[Key, object]: """Collect data from workers to the scheduler""" data = {} missing_keys = list(keys) @@ -6247,7 +6463,9 @@ async def gather(self, keys: Collection[Key], serializers: list[str] | None = No missing_keys, new_failed_keys, new_missing_workers, - ) = await gather_from_workers(who_has, rpc=self.rpc, serializers=serializers) + ) = await gather_from_workers( + who_has, rpc=self.rpc, serializers=serializers + ) data.update(new_data) failed_keys += new_failed_keys missing_workers.update(new_missing_workers) @@ -6257,7 +6475,10 @@ async def gather(self, keys: Collection[Key], serializers: list[str] | None = No if not failed_keys: return {"status": "OK", "data": data} - failed_states = {key: self.tasks[key].state if key in self.tasks else "forgotten" for key in failed_keys} + failed_states = { + key: self.tasks[key].state if key in self.tasks else "forgotten" + for key in failed_keys + } logger.error("Couldn't gather keys: %s", failed_states) return {"status": "error", "keys": list(failed_keys)} @@ -6367,16 +6588,24 @@ async def restart_workers( workers = list(set(workers).intersection(self.workers)) logger.info(f"Restarting {len(workers)} workers: {workers} ({stimulus_id=}") - nanny_workers = {addr: self.workers[addr].nanny for addr in workers if self.workers[addr].nanny} + nanny_workers = { + addr: self.workers[addr].nanny + for addr in workers + if self.workers[addr].nanny + } # Close non-Nanny workers. We have no way to restart them, so we just let them # go, and assume a deployment system is going to restart them for us. 
no_nanny_workers = [addr for addr in workers if addr not in nanny_workers] if no_nanny_workers: logger.warning( - f"Workers {no_nanny_workers} do not use a nanny and will be terminated without restarting them" + f"Workers {no_nanny_workers} do not use a nanny and will be terminated " + "without restarting them" ) await asyncio.gather( - *(self.remove_worker(address=addr, stimulus_id=stimulus_id) for addr in no_nanny_workers) + *( + self.remove_worker(address=addr, stimulus_id=stimulus_id) + for addr in no_nanny_workers + ) ) out: dict[str, Literal["OK", "removed", "timed out"]] out = {addr: "removed" for addr in no_nanny_workers} @@ -6386,7 +6615,9 @@ async def restart_workers( async with contextlib.AsyncExitStack() as stack: nannies = await asyncio.gather( *( - stack.enter_async_context(rpc(nanny_address, connection_args=self.connection_args)) + stack.enter_async_context( + rpc(nanny_address, connection_args=self.connection_args) + ) for nanny_address in nanny_workers.values() ) ) @@ -6422,8 +6653,16 @@ async def restart_workers( raise resp if bad_nannies: - logger.error(f"Workers {list(bad_nannies)} did not shut down within {timeout}s; force closing") - await asyncio.gather(*(self.remove_worker(addr, stimulus_id=stimulus_id) for addr in bad_nannies)) + logger.error( + f"Workers {list(bad_nannies)} did not shut down within {timeout}s; " + "force closing" + ) + await asyncio.gather( + *( + self.remove_worker(addr, stimulus_id=stimulus_id) + for addr in bad_nannies + ) + ) if on_error == "raise": raise TimeoutError( f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not " @@ -6432,10 +6671,15 @@ async def restart_workers( if client: self.log_event(client, {"action": "restart-workers", "workers": workers}) - self.log_event("all", {"action": "restart-workers", "workers": workers, "client": client}) + self.log_event( + "all", {"action": "restart-workers", "workers": workers, "client": client} + ) if not wait_for_workers: - logger.info(f"Workers restart finished (did not wait for new workers) ({stimulus_id=}") + logger.info( + "Workers restart finished (did not wait for new workers) " + f"({stimulus_id=}" + ) return out # NOTE: if new (unrelated) workers join while we're waiting, we may return @@ -6504,7 +6748,9 @@ async def broadcast( ERROR = object() - reuse_broadcast_comm = dask.config.get("distributed.scheduler.reuse-broadcast-comm", False) + reuse_broadcast_comm = dask.config.get( + "distributed.scheduler.reuse-broadcast-comm", False + ) close = not reuse_broadcast_comm async def send_message(addr: str) -> Any: @@ -6512,7 +6758,9 @@ async def send_message(addr: str) -> Any: comm = await self.rpc.connect(addr) comm.name = "Scheduler Broadcast" try: - resp = await send_recv(comm, close=close, serializers=serializers, **msg) + resp = await send_recv( + comm, close=close, serializers=serializers, **msg + ) finally: self.rpc.reuse(addr, comm) return resp @@ -6528,7 +6776,8 @@ async def send_message(addr: str) -> Any: return ERROR else: raise ValueError( - f"on_error must be 'raise', 'return', 'return_pickle', or 'ignore'; got {on_error!r}" + "on_error must be 'raise', 'return', 'return_pickle', " + f"or 'ignore'; got {on_error!r}" ) results = await All([send_message(address) for address in addresses]) @@ -6544,7 +6793,9 @@ async def proxy( d = await self.broadcast(msg=msg, workers=[worker], serializers=serializers) return d[worker] - async def gather_on_worker(self, worker_address: str, who_has: dict[Key, list[str]]) -> set: + async def gather_on_worker( + self, worker_address: str, 
who_has: dict[Key, list[str]] + ) -> set: """Peer-to-peer copy of keys from multiple workers to a single worker Parameters @@ -6560,12 +6811,15 @@ async def gather_on_worker(self, worker_address: str, who_has: dict[Key, list[st set of keys that failed to be copied """ try: - result = await retry_operation(self.rpc(addr=worker_address).gather, who_has=who_has) + result = await retry_operation( + self.rpc(addr=worker_address).gather, who_has=who_has + ) except OSError as e: # This can happen e.g. if the worker is going through controlled shutdown; # it doesn't necessarily mean that it went unexpectedly missing logger.warning( - f"Communication with worker {worker_address} failed during replication: {e.__class__.__name__}: {e}" + f"Communication with worker {worker_address} failed during " + f"replication: {e.__class__.__name__}: {e}" ) return set(who_has) @@ -6580,7 +6834,9 @@ async def gather_on_worker(self, worker_address: str, who_has: dict[Key, list[st elif result["status"] == "partial-fail": keys_failed = set(result["keys"]) keys_ok = who_has.keys() - keys_failed - logger.warning(f"Worker {worker_address} failed to acquire keys: {result['keys']}") + logger.warning( + f"Worker {worker_address} failed to acquire keys: {result['keys']}" + ) else: # pragma: nocover raise ValueError(f"Unexpected message from {worker_address}: {result}") @@ -6593,7 +6849,9 @@ async def gather_on_worker(self, worker_address: str, who_has: dict[Key, list[st return keys_failed - async def delete_worker_data(self, worker_address: str, keys: Collection[Key], stimulus_id: str) -> None: + async def delete_worker_data( + self, worker_address: str, keys: Collection[Key], stimulus_id: str + ) -> None: """Delete data from a worker and update the corresponding worker/task states Parameters @@ -6613,7 +6871,8 @@ async def delete_worker_data(self, worker_address: str, keys: Collection[Key], s # This can happen e.g. if the worker is going through controlled shutdown; # it doesn't necessarily mean that it went unexpectedly missing logger.warning( - f"Communication with worker {worker_address} failed during replication: {e.__class__.__name__}: {e}" + f"Communication with worker {worker_address} failed during " + f"replication: {e.__class__.__name__}: {e}" ) return @@ -6719,7 +6978,9 @@ async def rebalance( keys = set(keys) # unless already a set-like if not keys: return {"status": "OK"} - missing_data = [k for k in keys if k not in self.tasks or not self.tasks[k].who_has] + missing_data = [ + k for k in keys if k not in self.tasks or not self.tasks[k].who_has + ] if missing_data: return {"status": "partial-fail", "keys": missing_data} @@ -6794,7 +7055,9 @@ def _rebalance_find_msgs( # unmanaged memory that appeared over the last 30 seconds # (distributed.worker.memory.recent-to-old-time). # This lets us ignore temporary spikes caused by task heap usage. 
- memory_by_worker = [(ws, getattr(ws.memory, self.MEMORY_REBALANCE_MEASURE)) for ws in workers] + memory_by_worker = [ + (ws, getattr(ws.memory, self.MEMORY_REBALANCE_MEASURE)) for ws in workers + ] mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker) for ws, ws_memory in memory_by_worker: @@ -6807,12 +7070,18 @@ def _rebalance_find_msgs( sender_min = 0.0 recipient_max = math.inf - if ws._has_what and ws_memory >= mean_memory + half_gap and ws_memory >= sender_min: + if ( + ws._has_what + and ws_memory >= mean_memory + half_gap + and ws_memory >= sender_min + ): # This may send the worker below sender_min (by design) snd_bytes_max = mean_memory - ws_memory # negative snd_bytes_min = snd_bytes_max + half_gap # negative # See definition of senders above - senders.append((snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what))) + senders.append( + (snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what)) + ) elif ws_memory < mean_memory - half_gap and ws_memory < recipient_max: # This may send the worker above recipient_max (by design) rec_bytes_max = ws_memory - mean_memory # negative @@ -6930,7 +7199,9 @@ async def _rebalance_move_data( FIXME this method is not robust when the cluster is not idle. """ # {recipient address: {key: [sender address, ...]}} - to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict(lambda: defaultdict(list)) + to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict( + lambda: defaultdict(list) + ) for snd_ws, rec_ws, ts in msgs: to_recipients[rec_ws.address][ts.key].append(snd_ws.address) failed_keys_by_recipient = dict( @@ -6952,7 +7223,9 @@ async def _rebalance_move_data( to_senders[snd_ws.address].append(ts.key) # Note: this never raises exceptions - await asyncio.gather(*(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items())) + await asyncio.gather( + *(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items()) + ) for r, v in to_recipients.items(): self.log_event(r, {"action": "rebalance", "who_has": v}) @@ -7031,13 +7304,17 @@ async def replicate( assert ts.who_has is not None del_candidates = tuple(ts.who_has & workers) if len(del_candidates) > n: - for ws in random.sample(del_candidates, len(del_candidates) - n): + for ws in random.sample( + del_candidates, len(del_candidates) - n + ): del_worker_tasks[ws].add(ts) # Note: this never raises exceptions await asyncio.gather( *[ - self.delete_worker_data(ws.address, [t.key for t in tasks], stimulus_id) + self.delete_worker_data( + ws.address, [t.key for t in tasks], stimulus_id + ) for ws, tasks in del_worker_tasks.items() ] ) @@ -7062,7 +7339,9 @@ async def replicate( assert count > 0 for ws in random.sample(tuple(workers - ts.who_has), count): - gathers[ws.address][ts.key] = [wws.address for wws in ts.who_has] + gathers[ws.address][ts.key] = [ + wws.address for wws in ts.who_has + ] await asyncio.gather( *( @@ -7318,7 +7597,8 @@ async def retire_workers( raise TypeError("names and workers are mutually exclusive") if (names is not None or workers is not None) and kwargs: raise TypeError( - f"Parameters for workers_to_close() are mutually exclusive with names and workers: {kwargs}" + "Parameters for workers_to_close() are mutually exclusive with " + f"names and workers: {kwargs}" ) stimulus_id = stimulus_id or f"retire-workers-{time()}" @@ -7338,16 +7618,24 @@ async def retire_workers( stimulus_id, workers, ) - wss = {self.workers[address] for address in workers if address in self.workers} + wss = { + 
self.workers[address] + for address in workers + if address in self.workers + } else: - wss = {self.workers[address] for address in self.workers_to_close(**kwargs)} + wss = { + self.workers[address] for address in self.workers_to_close(**kwargs) + } if not wss: return {} stop_amm = False amm: ActiveMemoryManagerExtension | None = self.extensions.get("amm") if not amm or not amm.running: - amm = ActiveMemoryManagerExtension(self, policies=set(), register=False, start=True, interval=2.0) + amm = ActiveMemoryManagerExtension( + self, policies=set(), register=False, start=True, interval=2.0 + ) stop_amm = True try: @@ -7359,7 +7647,9 @@ async def retire_workers( # Change Worker.status to closing_gracefully. Immediately set # the same on the scheduler to prevent race conditions. prev_status = ws.status - self.handle_worker_status_change(Status.closing_gracefully, ws, stimulus_id) + self.handle_worker_status_change( + Status.closing_gracefully, ws, stimulus_id + ) # FIXME: We should send a message to the nanny first; # eventually workers won't be able to close their own nannies. self.stream_comms[ws.address].send( @@ -7450,10 +7740,14 @@ async def _track_retire_worker( ) return ws.address, "no-recipients", ws.identity() - logger.debug(f"All unique keys on worker {ws.address!r} have been replicated elsewhere") + logger.debug( + f"All unique keys on worker {ws.address!r} have been replicated elsewhere" + ) if remove: - await self.remove_worker(ws.address, expected=True, close=close, stimulus_id=stimulus_id) + await self.remove_worker( + ws.address, expected=True, close=close, stimulus_id=stimulus_id + ) elif close: self.close_worker(ws.address) @@ -7595,7 +7889,9 @@ async def feed( if teardown: teardown(self, state) # type: ignore - def log_worker_event(self, worker: str, topic: str | Collection[str], msg: Any) -> None: + def log_worker_event( + self, worker: str, topic: str | Collection[str], msg: Any + ) -> None: if isinstance(msg, dict) and worker != topic: msg["worker"] = worker self.log_event(topic, msg) @@ -7608,25 +7904,46 @@ def subscribe_worker_status(self, comm: Comm) -> dict[str, Any]: del v["last_seen"] return ident - def get_processing(self, workers: Iterable[str] | None = None) -> dict[str, list[Key]]: + def get_processing( + self, workers: Iterable[str] | None = None + ) -> dict[str, list[Key]]: if workers is not None: workers = set(map(self.coerce_address, workers)) return {w: [ts.key for ts in self.workers[w].processing] for w in workers} else: - return {w: [ts.key for ts in ws.processing] for w, ws in self.workers.items()} + return { + w: [ts.key for ts in ws.processing] for w, ws in self.workers.items() + } def get_who_has(self, keys: Iterable[Key] | None = None) -> dict[Key, list[str]]: if keys is not None: return { - key: ([ws.address for ws in self.tasks[key].who_has or ()] if key in self.tasks else []) for key in keys + key: ( + [ws.address for ws in self.tasks[key].who_has or ()] + if key in self.tasks + else [] + ) + for key in keys } else: - return {key: [ws.address for ws in ts.who_has or ()] for key, ts in self.tasks.items()} + return { + key: [ws.address for ws in ts.who_has or ()] + for key, ts in self.tasks.items() + } - def get_has_what(self, workers: Iterable[str] | None = None) -> dict[str, list[Key]]: + def get_has_what( + self, workers: Iterable[str] | None = None + ) -> dict[str, list[Key]]: if workers is not None: workers = map(self.coerce_address, workers) - return {w: ([ts.key for ts in self.workers[w].has_what] if w in self.workers else []) for w in 
workers} + return { + w: ( + [ts.key for ts in self.workers[w].has_what] + if w in self.workers + else [] + ) + for w in workers + } else: return {w: [ts.key for ts in ws.has_what] for w, ws in self.workers.items()} @@ -7637,9 +7954,13 @@ def get_ncores(self, workers: Iterable[str] | None = None) -> dict[str, int]: else: return {w: ws.nthreads for w, ws in self.workers.items()} - def get_ncores_running(self, workers: Iterable[str] | None = None) -> dict[str, int]: + def get_ncores_running( + self, workers: Iterable[str] | None = None + ) -> dict[str, int]: ncores = self.get_ncores(workers=workers) - return {w: n for w, n in ncores.items() if self.workers[w].status == Status.running} + return { + w: n for w, n in ncores.items() if self.workers[w].status == Status.running + } async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, Any]: workers: dict[str, list[Key] | None] @@ -7666,7 +7987,9 @@ async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, A if not workers: return {} - results = await asyncio.gather(*(self.rpc(w).call_stack(keys=v) for w, v in workers.items())) + results = await asyncio.gather( + *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) + ) response = {w: r for w, r in zip(workers, results) if r} return response @@ -7707,7 +8030,9 @@ async def benchmark_hardware(self) -> dict[str, dict[str, float]]: # implementing logic based on IP addresses would not necessarily help. # Randomize the connections to even out the mean measures. random.shuffle(workers) - futures = [self.rpc(a).benchmark_network(address=b) for a, b in partition(2, workers)] + futures = [ + self.rpc(a).benchmark_network(address=b) for a, b in partition(2, workers) + ] responses = await asyncio.gather(*futures) for d in responses: @@ -7716,12 +8041,17 @@ async def benchmark_hardware(self) -> dict[str, dict[str, float]]: result = {} for mode in out: - result[mode] = {size: sum(durations) / len(durations) for size, durations in out[mode].items()} + result[mode] = { + size: sum(durations) / len(durations) + for size, durations in out[mode].items() + } return result @log_errors - def get_nbytes(self, keys: Iterable[Key] | None = None, summary: bool = True) -> dict[Key, int]: + def get_nbytes( + self, keys: Iterable[Key] | None = None, summary: bool = True + ) -> dict[Key, int]: if keys is not None: result = {k: self.tasks[k].nbytes for k in keys} else: @@ -7807,7 +8137,9 @@ def get_task_prefix_states(self) -> dict[str, dict[str, int]]: return state def get_task_status(self, keys: Iterable[Key]) -> dict[Key, TaskStateState | None]: - return {key: (self.tasks[key].state if key in self.tasks else None) for key in keys} + return { + key: (self.tasks[key].state if key in self.tasks else None) for key in keys + } def get_task_stream( self, @@ -7830,11 +8162,14 @@ def start_task_metadata(self, name: str) -> None: def stop_task_metadata(self, name: str | None = None) -> dict: plugins = [ - p for p in list(self.plugins.values()) if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name + p + for p in list(self.plugins.values()) + if isinstance(p, CollectTaskMetaDataPlugin) and p.name == name ] if len(plugins) != 1: raise ValueError( - f"Expected to find exactly one CollectTaskMetaDataPlugin with name {name} but found {len(plugins)}." + "Expected to find exactly one CollectTaskMetaDataPlugin " + f"with name {name} but found {len(plugins)}." 
) plugin = plugins[0] @@ -7859,10 +8194,14 @@ async def register_worker_plugin( self.worker_plugins[name] = plugin - responses = await self.broadcast(msg=dict(op="plugin-add", plugin=plugin, name=name)) + responses = await self.broadcast( + msg=dict(op="plugin-add", plugin=plugin, name=name) + ) return responses - async def unregister_worker_plugin(self, comm: None, name: str) -> dict[str, ErrorMessage | OKMessage]: + async def unregister_worker_plugin( + self, comm: None, name: str + ) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.worker_plugins.pop(name) @@ -7894,21 +8233,27 @@ async def register_nanny_plugin( async with self._starting_nannies_cond: if self._starting_nannies: logger.info("Waiting for Nannies to start %s", self._starting_nannies) - await self._starting_nannies_cond.wait_for(lambda: not self._starting_nannies) + await self._starting_nannies_cond.wait_for( + lambda: not self._starting_nannies + ) responses = await self.broadcast( msg=dict(op="plugin_add", plugin=plugin, name=name), nanny=True, ) return responses - async def unregister_nanny_plugin(self, comm: None, name: str) -> dict[str, ErrorMessage | OKMessage]: + async def unregister_nanny_plugin( + self, comm: None, name: str + ) -> dict[str, ErrorMessage | OKMessage]: """Unregisters a worker plugin""" try: self.nanny_plugins.pop(name) except KeyError: raise ValueError(f"The nanny plugin {name} does not exist") - responses = await self.broadcast(msg=dict(op="plugin_remove", name=name), nanny=True) + responses = await self.broadcast( + msg=dict(op="plugin_remove", name=name), nanny=True + ) return responses def transition( @@ -7933,7 +8278,9 @@ def transition( -------- Scheduler.transitions: transitive version of this function """ - recommendations, client_msgs, worker_msgs = self._transition(key, finish, stimulus_id, **kwargs) + recommendations, client_msgs, worker_msgs = self._transition( + key, finish, stimulus_id, **kwargs + ) self.send_all(client_msgs, worker_msgs) return recommendations @@ -7956,7 +8303,9 @@ async def get_story(self, keys_or_stimuli: Iterable[Key | str]) -> list[Transiti """ return self.story(*keys_or_stimuli) - def _reschedule(self, key: Key, worker: str | None = None, *, stimulus_id: str) -> None: + def _reschedule( + self, key: Key, worker: str | None = None, *, stimulus_id: str + ) -> None: """Reschedule a task. This function should only be used when the task has already been released in @@ -7968,7 +8317,8 @@ def _reschedule(self, key: Key, worker: str | None = None, *, stimulus_id: str) ts = self.tasks[key] except KeyError: logger.warning( - f"Attempting to reschedule task {key!r}, which was not found on the scheduler. Aborting reschedule." + f"Attempting to reschedule task {key!r}, which was not " + "found on the scheduler. Aborting reschedule." 
) return if ts.state != "processing": @@ -7983,7 +8333,9 @@ def _reschedule(self, key: Key, worker: str | None = None, *, stimulus_id: str) # Utility functions # ##################### - def add_resources(self, worker: str, resources: dict | None = None) -> Literal["OK"]: + def add_resources( + self, worker: str, resources: dict | None = None + ) -> Literal["OK"]: ws = self.workers[worker] if resources: ws.resources.update(resources) @@ -8065,7 +8417,10 @@ async def get_profile( ) results = await asyncio.gather( - *(self.rpc(w).profile(start=start, stop=stop, key=key, server=server) for w in workers), + *( + self.rpc(w).profile(start=start, stop=stop, key=key, server=server) + for w in workers + ), return_exceptions=True, ) @@ -8085,7 +8440,9 @@ async def get_profile_metadata( stop: float | None = None, profile_cycle_interval: str | float | None = None, ) -> dict[str, Any]: - dt = profile_cycle_interval or dask.config.get("distributed.worker.profile.cycle") + dt = profile_cycle_interval or dask.config.get( + "distributed.worker.profile.cycle" + ) dt = parse_timedelta(dt, default="ms") if workers is None: @@ -8108,7 +8465,9 @@ async def get_profile_metadata( ) ] - keys: dict[Key, list[list]] = {k: [] for v in results for t, d in v["keys"] for k in d} + keys: dict[Key, list[list]] = { + k: [] for v in results for t, d in v["keys"] for k in d + } groups1 = [v["keys"] for v in results] groups2 = list(merge_sorted(*groups1, key=first)) @@ -8125,7 +8484,9 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} - async def performance_report(self, start: float, last_count: int, code: str = "", mode: str | None = None) -> str: + async def performance_report( + self, start: float, last_count: int, code: str = "", mode: str | None = None + ) -> str: stop = time() # Profiles compute_d, scheduler_d, workers_d = await asyncio.gather( @@ -8142,7 +8503,9 @@ def profile_to_figure(state: object) -> object: figure, source = profile.plot_figure(data, sizing_mode="stretch_both") return figure - compute, scheduler, workers = map(profile_to_figure, (compute_d, scheduler_d, workers_d)) + compute, scheduler, workers = map( + profile_to_figure, (compute_d, scheduler_d, workers_d) + ) del compute_d, scheduler_d, workers_d # Task stream @@ -8226,10 +8589,16 @@ def profile_to_figure(state: object) -> object: html = TabPanel(child=html, title="Summary") compute = TabPanel(child=compute, title="Worker Profile (compute)") workers = TabPanel(child=workers, title="Worker Profile (administrative)") - scheduler = TabPanel(child=scheduler, title="Scheduler Profile (administrative)") + scheduler = TabPanel( + child=scheduler, title="Scheduler Profile (administrative)" + ) task_stream = TabPanel(child=task_stream, title="Task Stream") - bandwidth_workers = TabPanel(child=bandwidth_workers.root, title="Bandwidth (Workers)") - bandwidth_types = TabPanel(child=bandwidth_types.root, title="Bandwidth (Types)") + bandwidth_workers = TabPanel( + child=bandwidth_workers.root, title="Bandwidth (Workers)" + ) + bandwidth_types = TabPanel( + child=bandwidth_types.root, title="Bandwidth (Types)" + ) system = TabPanel(child=sysmon.root, title="System") logs = TabPanel(child=logs.root, title="Scheduler Logs") @@ -8253,7 +8622,9 @@ def profile_to_figure(state: object) -> object: with tmpfile(extension=".html") as fn: output_file(filename=fn, title="Dask Performance Report", mode=mode) - template_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates") + template_directory = 
os.path.join( + os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates" + ) template_environment = get_env() template_environment.loader.searchpath.append(template_directory) template = template_environment.get_template("performance_report.html") @@ -8264,8 +8635,12 @@ def profile_to_figure(state: object) -> object: return data - async def get_worker_logs(self, n: int | None = None, workers: list | None = None, nanny: bool = False) -> dict: - results = await self.broadcast(msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny) + async def get_worker_logs( + self, n: int | None = None, workers: list | None = None, nanny: bool = False + ) -> dict: + results = await self.broadcast( + msg={"op": "get_logs", "n": n}, workers=workers, nanny=nanny + ) return results def log_event(self, topic: str | Collection[str], msg: Any) -> None: @@ -8302,11 +8677,16 @@ def get_events( ) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]: return self._broker.get_events(topic) - async def get_worker_monitor_info(self, recent: bool = False, starts: dict | None = None) -> dict: + async def get_worker_monitor_info( + self, recent: bool = False, starts: dict | None = None + ) -> dict: if starts is None: starts = {} results = await asyncio.gather( - *(self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) for w in self.workers) + *( + self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) + for w in self.workers + ) ) return dict(zip(self.workers, results)) @@ -8359,7 +8739,11 @@ def check_idle(self) -> float | None: self.idle_since = None return None - if self.queued or self.unrunnable or any(ws.processing for ws in self.workers.values()): + if ( + self.queued + or self.unrunnable + or any(ws.processing for ws in self.workers.values()) + ): self.idle_since = None return None @@ -8368,7 +8752,9 @@ def check_idle(self) -> float | None: return self.idle_since if self.jupyter: - last_activity = self._jupyter_server_application.web_app.last_activity().timestamp() + last_activity = ( + self._jupyter_server_application.web_app.last_activity().timestamp() + ) if last_activity > self.idle_since: self.idle_since = last_activity return self.idle_since @@ -8380,11 +8766,16 @@ def check_idle(self) -> float | None: "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - self._ongoing_background_tasks.call_soon(self.close, reason="idle-timeout-exceeded") + self._ongoing_background_tasks.call_soon( + self.close, reason="idle-timeout-exceeded" + ) return self.idle_since def _check_no_workers(self) -> None: - if self.status in (Status.closing, Status.closed) or self.no_workers_timeout is None: + if ( + self.status in (Status.closing, Status.closed) + or self.no_workers_timeout is None + ): return now = monotonic() @@ -8394,9 +8785,15 @@ def _check_no_workers(self) -> None: self._refresh_no_workers_since(now) - affected = self._check_unrunnable_task_timeouts(now, recommendations=recommendations, stimulus_id=stimulus_id) + affected = self._check_unrunnable_task_timeouts( + now, recommendations=recommendations, stimulus_id=stimulus_id + ) - affected.update(self._check_queued_task_timeouts(now, recommendations=recommendations, stimulus_id=stimulus_id)) + affected.update( + self._check_queued_task_timeouts( + now, recommendations=recommendations, stimulus_id=stimulus_id + ) + ) self.transitions(recommendations, stimulus_id=stimulus_id) if affected: self.log_event( @@ -8404,7 +8801,9 @@ def _check_no_workers(self) -> None: {"action": 
"no-workers-timeout-exceeded", "keys": affected}, ) - def _check_unrunnable_task_timeouts(self, timestamp: float, recommendations: Recs, stimulus_id: str) -> set[Key]: + def _check_unrunnable_task_timeouts( + self, timestamp: float, recommendations: Recs, stimulus_id: str + ) -> set[Key]: assert self.no_workers_timeout unsatisfied = [] no_workers = [] @@ -8413,7 +8812,10 @@ def _check_unrunnable_task_timeouts(self, timestamp: float, recommendations: Rec # unrunnable is insertion-ordered, which means that unrunnable_since will # be monotonically increasing in this loop. break - if self._no_workers_since is None or self._no_workers_since >= unrunnable_since: + if ( + self._no_workers_since is None + or self._no_workers_since >= unrunnable_since + ): unsatisfied.append(ts) else: no_workers.append(ts) @@ -8439,13 +8841,18 @@ def _check_unrunnable_task_timeouts(self, timestamp: float, recommendations: Rec ) recommendations.update(r) logger.error( - "Task %s marked as failed because it timed out waiting for its restrictions to become satisfied.", + "Task %s marked as failed because it timed out waiting " + "for its restrictions to become satisfied.", ts.key, ) - self._fail_tasks_after_no_workers_timeout(no_workers, recommendations, stimulus_id) + self._fail_tasks_after_no_workers_timeout( + no_workers, recommendations, stimulus_id + ) return {ts.key for ts in concat([unsatisfied, no_workers])} - def _check_queued_task_timeouts(self, timestamp: float, recommendations: Recs, stimulus_id: str) -> set[Key]: + def _check_queued_task_timeouts( + self, timestamp: float, recommendations: Recs, stimulus_id: str + ) -> set[Key]: assert self.no_workers_timeout if self._no_workers_since is None: @@ -8454,7 +8861,9 @@ def _check_queued_task_timeouts(self, timestamp: float, recommendations: Recs, s if timestamp <= self._no_workers_since + self.no_workers_timeout: return set() affected = list(self.queued) - self._fail_tasks_after_no_workers_timeout(affected, recommendations, stimulus_id) + self._fail_tasks_after_no_workers_timeout( + affected, recommendations, stimulus_id + ) return {ts.key for ts in affected} def _fail_tasks_after_no_workers_timeout( @@ -8478,7 +8887,8 @@ def _fail_tasks_after_no_workers_timeout( ) recommendations.update(r) logger.error( - "Task %s marked as failed because it timed out waiting without any running workers.", + "Task %s marked as failed because it timed out waiting " + "without any running workers.", ts.key, ) @@ -8553,7 +8963,9 @@ def adaptive_target(self, target_duration: float | None = None) -> int: to_close = self.workers_to_close() return len(self.workers) - len(to_close) - def request_acquire_replicas(self, addr: str, keys: Iterable[Key], *, stimulus_id: str) -> None: + def request_acquire_replicas( + self, addr: str, keys: Iterable[Key], *, stimulus_id: str + ) -> None: """Asynchronously ask a worker to acquire a replica of the listed keys from other workers. This is a fire-and-forget operation which offers no feedback for success or failure, and is intended for housekeeping and not for computation. @@ -8575,7 +8987,9 @@ def request_acquire_replicas(self, addr: str, keys: Iterable[Key], *, stimulus_i }, ) - def request_remove_replicas(self, addr: str, keys: list[Key], *, stimulus_id: str) -> None: + def request_remove_replicas( + self, addr: str, keys: list[Key], *, stimulus_id: str + ) -> None: """Asynchronously ask a worker to discard its replica of the listed keys. This must never be used to destroy the last replica of a key. 
        This is a fire-and-forget operation, intended for housekeeping and not for computation.
@@ -8759,7 +9173,13 @@ def validate_task_state(ts: TaskState) -> None:
     if ts.run_spec:  # was computed
         assert ts.type
         assert isinstance(ts.type, str)
-        assert not any([ts in dts.waiting_on for dts in ts.dependents if dts.waiting_on is not None])
+        assert not any(
+            [
+                ts in dts.waiting_on
+                for dts in ts.dependents
+                if dts.waiting_on is not None
+            ]
+        )
     for ws in ts.who_has:
         assert ts in ws.has_what, (
             "not in who_has' has_what",
@@ -8862,7 +9282,9 @@ def heartbeat_interval(n: int) -> float:
 def _task_slots_available(ws: WorkerState, saturation_factor: float) -> int:
     """Number of tasks that can be sent to this worker without oversaturating it"""
     assert not math.isinf(saturation_factor)
-    return max(math.ceil(saturation_factor * ws.nthreads), 1) - (len(ws.processing) - len(ws.long_running))
+    return max(math.ceil(saturation_factor * ws.nthreads), 1) - (
+        len(ws.processing) - len(ws.long_running)
+    )


 def _worker_full(ws: WorkerState, saturation_factor: float) -> bool:
@@ -8906,7 +9328,9 @@ def __init__(
         resource_restrictions: dict[str, float],
         timeout: float,
     ):
-        super().__init__(task, host_restrictions, worker_restrictions, resource_restrictions, timeout)
+        super().__init__(
+            task, host_restrictions, worker_restrictions, resource_restrictions, timeout
+        )

     @property
     def task(self) -> Key:

From 21ba1ab0c6281e7ca85f5f2d2f6e34347b306042 Mon Sep 17 00:00:00 2001
From: nadzhou
Date: Fri, 5 Dec 2025 19:05:37 -0800
Subject: [PATCH 5/7] Update condition.py,test_condition.py

---
 distributed/condition.py            |   3 +-
 distributed/tests/test_condition.py | 150 +++++++++++++---------------
 2 files changed, 73 insertions(+), 80 deletions(-)

diff --git a/distributed/condition.py b/distributed/condition.py
index ad31630815..4d250b7f2a 100644
--- a/distributed/condition.py
+++ b/distributed/condition.py
@@ -6,8 +6,7 @@
 from collections import defaultdict
 from contextlib import suppress

-from distributed.utils import log_errors, wait_for, TimeoutError
-from distributed.utils import SyncMethodMixin
+from distributed.utils import SyncMethodMixin, log_errors
 from distributed.worker import get_client

 logger = logging.getLogger(__name__)
diff --git a/distributed/tests/test_condition.py b/distributed/tests/test_condition.py
index ea8e051796..102e629fbf 100644
--- a/distributed/tests/test_condition.py
+++ b/distributed/tests/test_condition.py
@@ -1,16 +1,17 @@
 import asyncio
+
 import pytest

-from distributed import Condition, Client, wait
-from distributed.utils_test import gen_cluster, inc
+from distributed import Condition
 from distributed.metrics import time
+from distributed.utils_test import gen_cluster


 @gen_cluster(client=True)
 async def test_condition_acquire_release(c, s, a, b):
     """Test basic lock acquire/release"""
     condition = Condition("test-lock")
-
+
     assert not condition.locked()
     await condition.acquire()
     assert condition.locked()
@@ -22,7 +23,7 @@ async def test_condition_context_manager(c, s, a, b):
     """Test context manager interface"""
     condition = Condition("test-context")
-
+
     assert not condition.locked()
     async with condition:
         assert condition.locked()
@@ -34,19 +35,19 @@ async def test_condition_wait_notify(c, s, a, b):
     """Test basic wait/notify"""
     condition = Condition("test-notify")
     results = []
-
+
     async def waiter():
         async with condition:
             results.append("waiting")
             await condition.wait()
results.append("notified") - + async def notifier(): await asyncio.sleep(0.2) async with condition: results.append("notifying") condition.notify() - + await asyncio.gather(waiter(), notifier()) assert results == ["waiting", "notifying", "notified"] @@ -56,20 +57,18 @@ async def test_condition_notify_all(c, s, a, b): """Test notify_all wakes all waiters""" condition = Condition("test-notify-all") results = [] - + async def waiter(i): async with condition: await condition.wait() results.append(i) - + async def notifier(): await asyncio.sleep(0.2) async with condition: condition.notify_all() - - await asyncio.gather( - waiter(1), waiter(2), waiter(3), notifier() - ) + + await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) assert sorted(results) == [1, 2, 3] @@ -78,12 +77,12 @@ async def test_condition_notify_n(c, s, a, b): """Test notify with specific count""" condition = Condition("test-notify-n") results = [] - + async def waiter(i): async with condition: await condition.wait() results.append(i) - + async def notifier(): await asyncio.sleep(0.2) async with condition: @@ -91,10 +90,8 @@ async def notifier(): await asyncio.sleep(0.2) async with condition: condition.notify() # Wake remaining waiter - - await asyncio.gather( - waiter(1), waiter(2), waiter(3), notifier() - ) + + await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier()) assert sorted(results) == [1, 2, 3] @@ -102,12 +99,12 @@ async def notifier(): async def test_condition_wait_timeout(c, s, a, b): """Test wait with timeout""" condition = Condition("test-timeout") - + start = time() async with condition: result = await condition.wait(timeout=0.5) elapsed = time() - start - + assert result is False assert 0.4 < elapsed < 0.7 @@ -117,21 +114,21 @@ async def test_condition_wait_timeout_then_notify(c, s, a, b): """Test that timeout doesn't prevent subsequent notifications""" condition = Condition("test-timeout-notify") results = [] - + async def waiter(): async with condition: result = await condition.wait(timeout=0.2) results.append(f"timeout: {result}") - + async with condition: result = await condition.wait() results.append(f"notified: {result}") - + async def notifier(): await asyncio.sleep(0.5) async with condition: condition.notify() - + await asyncio.gather(waiter(), notifier()) assert results == ["timeout: False", "notified: True"] @@ -140,13 +137,13 @@ async def notifier(): async def test_condition_error_without_lock(c, s, a, b): """Test errors when calling wait/notify without holding lock""" condition = Condition("test-error") - + with pytest.raises(RuntimeError, match="without holding the lock"): await condition.wait() - + with pytest.raises(RuntimeError, match="Cannot notify"): await condition.notify() - + with pytest.raises(RuntimeError, match="Cannot notify"): await condition.notify_all() @@ -155,7 +152,7 @@ async def test_condition_error_without_lock(c, s, a, b): async def test_condition_error_release_without_acquire(c, s, a, b): """Test error when releasing without acquiring""" condition = Condition("test-release-error") - + with pytest.raises(RuntimeError, match="Cannot release"): await condition.release() @@ -165,14 +162,14 @@ async def test_condition_producer_consumer(c, s, a, b): """Test classic producer-consumer pattern""" condition = Condition("prod-cons") queue = [] - + async def producer(): for i in range(5): await asyncio.sleep(0.1) async with condition: queue.append(i) condition.notify() - + async def consumer(): results = [] for _ in range(5): @@ -181,13 +178,13 @@ async def consumer(): 
                     await condition.wait()
                 results.append(queue.pop(0))
         return results
-
+
     prod_task = asyncio.create_task(producer())
     cons_task = asyncio.create_task(consumer())
-
+
     await prod_task
     results = await cons_task
-
+
     assert results == [0, 1, 2, 3, 4]
@@ -196,14 +193,14 @@ async def test_condition_multiple_producers_consumers(c, s, a, b):
     """Test multiple producers and consumers"""
     condition = Condition("multi-prod-cons")
     queue = []
-
+
     async def producer(start):
         for i in range(start, start + 3):
             await asyncio.sleep(0.05)
             async with condition:
                 queue.append(i)
                 await condition.notify()
-
+
     async def consumer():
         results = []
         for _ in range(3):
             async with condition:
                 await condition.wait()
                 results.append(queue.pop(0))
         return results
-
-    results = await asyncio.gather(
-        producer(0), producer(10),
-        consumer(), consumer()
-    )
-
+
+    results = await asyncio.gather(producer(0), producer(10), consumer(), consumer())
+
     # Last two results are from consumers
     consumed = results[2] + results[3]
     assert sorted(consumed) == [0, 1, 2, 10, 11, 12]
@@ -226,39 +220,43 @@ async def consumer():
 async def test_condition_from_worker(c, s, a, b):
     """Test condition accessed from worker tasks"""
+
     def wait_on_condition(name):
-        from distributed import Condition
-        import asyncio
-
+        from distributed import Condition
+        from distributed.utils import sync
+        from distributed.worker import get_worker
+
+        # Resolve the worker in the task thread, then drive the coroutine
+        # on the worker's event loop via sync()
+        worker = get_worker()
+
         async def _wait():
             condition = Condition(name, scheduler_rpc=worker.scheduler)
             async with condition:
                 await condition.wait()
             return "worker_notified"
-
-        from distributed.worker import get_worker
-        worker = get_worker()
-        return worker.loop.run_until_complete(_wait())
+
+        return sync(worker.loop, _wait)

     def notify_condition(name):
-        from distributed import Condition
-        import asyncio
-
+        import asyncio
+
+        from distributed import Condition
+        from distributed.utils import sync
+        from distributed.worker import get_worker
+
+        worker = get_worker()
+
         async def _notify():
             await asyncio.sleep(0.2)
             condition = Condition(name, scheduler_rpc=worker.scheduler)
             async with condition:
                 await condition.notify()
             return "notified"
-
-        from distributed.worker import get_worker
-        worker = get_worker()
-        return worker.loop.run_until_complete(_notify())
+
+        return sync(worker.loop, _notify)

     name = "worker-condition"
     f1 = c.submit(wait_on_condition, name, workers=[a.address])
     f2 = c.submit(notify_condition, name, workers=[b.address])
-
+
     results = await c.gather([f1, f2])
     assert results == ["worker_notified", "notified"]
@@ -269,21 +267,21 @@ async def test_condition_same_name_different_instances(c, s, a, b):
     """Test two Condition objects with same name share state"""
     name = "shared-condition"
     cond1 = Condition(name)
     cond2 = Condition(name)
-
+
     results = []
-
+
     async def waiter():
         async with cond1:
             results.append("waiting")
             await cond1.wait()
             results.append("notified")
-
+
     async def notifier():
         await asyncio.sleep(0.2)
         async with cond2:
             results.append("notifying")
             await cond2.notify()
-
+
     await asyncio.gather(waiter(), notifier())
     assert results == ["waiting", "notifying", "notified"]
@@ -293,11 +291,11 @@ async def test_condition_unique_names_independent(c, s, a, b):
     """Test conditions with different names are independent"""
     cond1 = Condition("cond-1")
     cond2 = Condition("cond-2")
-
+
     async with cond1:
         assert cond1.locked()
         assert not cond2.locked()
-
+
     async with cond2:
         assert not cond1.locked()
         assert cond2.locked()
@@ -307,15 +305,15 @@ async def test_condition_cleanup(c, s, a, b):
     """Test that condition state is cleaned up after use"""
     condition = Condition("cleanup-test")
-
+
     # Check initial state
     assert "cleanup-test" not in s.extensions["conditions"]._conditions
     assert "cleanup-test" not in s.extensions["conditions"]._waiters
-
+
     # Use condition
     async with condition:
         await condition.notify()
-
+
     # No waiter state should be left behind
     await asyncio.sleep(0.1)
     assert "cleanup-test" not in s.extensions["conditions"]._waiters
@@ -327,7 +325,7 @@ async def test_condition_barrier_pattern(c, s, a, b):
     condition = Condition("barrier")
     arrived = []
     n_workers = 3
-
+
     async def worker(i):
         async with condition:
             arrived.append(i)
             if len(arrived) < n_workers:
                 await condition.wait()
             else:
                 await condition.notify_all()
             return f"worker-{i}-done"
-
-    results = await asyncio.gather(
-        worker(0), worker(1), worker(2)
-    )
-
+
+    results = await asyncio.gather(worker(0), worker(1), worker(2))
+
     assert sorted(results) == ["worker-0-done", "worker-1-done", "worker-2-done"]
     assert len(arrived) == 3
@@ -349,12 +345,12 @@ def test_condition_sync_interface(client):
     """Test synchronous interface via SyncMethodMixin"""
     condition = Condition("sync-test")
     results = []
-
+
     def worker():
         with condition:
             results.append("locked")
         results.append("released")
-
+
     worker()
     assert results == ["locked", "released"]
@@ -364,12 +360,12 @@ async def test_condition_multiple_notify_calls(c, s, a, b):
     """Test multiple notify calls in sequence"""
     condition = Condition("multi-notify")
     results = []
-
+
     async def waiter(i):
         async with condition:
             await condition.wait()
             results.append(i)
-
+
     async def notifier():
         await asyncio.sleep(0.2)
         async with condition:
             await condition.notify(n=2)
         await asyncio.sleep(0.1)
         async with condition:
             await condition.notify()
-
-    await asyncio.gather(
-        waiter(1), waiter(2), waiter(3), notifier()
-    )
+
+    await asyncio.gather(waiter(1), waiter(2), waiter(3), notifier())
     assert sorted(results) == [1, 2, 3]
@@ -392,20 +386,20 @@ async def test_condition_predicate_loop(c, s, a, b):
     """Test typical predicate-based wait loop pattern"""
     condition = Condition("predicate")
     state = {"value": 0, "target": 5}
-
+
     async def waiter():
         async with condition:
             while state["value"] < state["target"]:
                 await condition.wait()
             return state["value"]
-
+
     async def updater():
         for i in range(1, 6):
             await asyncio.sleep(0.1)
             async with condition:
                 state["value"] = i
                 await condition.notify_all()
-
+
     result, _ = await asyncio.gather(waiter(), updater())
     assert result == 5

From e978821ec3eef65f08da17ca7e08fb9b4b584f1f Mon Sep 17 00:00:00 2001
From: nadzhou
Date: Fri, 5 Dec 2025 19:09:19 -0800
Subject: [PATCH 6/7] Update scheduler.py

---
 distributed/scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index f2ad5080a8..07951263f4 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -97,6 +97,7 @@
 )
 from distributed.comm.addressing import addresses_from_user_args
 from distributed.compatibility import PeriodicCallback
+from distributed.condition import ConditionExtension
 from distributed.core import (
     ErrorMessage,
     OKMessage,
@@ -144,7 +145,6 @@
     scatter_to_workers,
 )
 from distributed.variable import VariableExtension
-from distributed.condition import ConditionExtension

 if TYPE_CHECKING:
     from typing import TypeAlias, TypeVar

From 34f5cb45ce13e69eb7996bcf9bd26dddcf3eb8ef Mon Sep 17 00:00:00 2001
From: nadzhou
Date: Fri, 5 Dec 2025 19:23:33 -0800
Subject: [PATCH 7/7] Update condition.py

---
 distributed/condition.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/distributed/condition.py b/distributed/condition.py
index 4d250b7f2a..24d9a8a022 100644
--- a/distributed/condition.py
+++ b/distributed/condition.py
@@ -158,7 +158,9 @@ async def wait(self, timeout=None):
             raise RuntimeError("Cannot wait on un-acquired condition")

         scheduler = self._get_scheduler_rpc()
-        result = await scheduler.condition_wait(name=self.name, id=self.id, timeout=timeout)
+        result = await scheduler.condition_wait(
+            name=self.name, id=self.id, timeout=timeout
+        )
         return result

     async def notify(self, n=1):
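
Illustrative usage, not part of the patches above: a minimal end-to-end sketch of the wait/notify API this series adds, assuming a local in-process cluster. The condition name "demo" and the shared "ready" flag are arbitrary stand-ins for demonstration.

    import asyncio

    from distributed import Client, Condition


    async def main():
        # In-process cluster for demonstration; any running scheduler works
        # the same way, since the Condition state lives in the scheduler
        # extension.
        async with Client(processes=False, asynchronous=True):
            condition = Condition("demo")
            state = {"ready": False}

            async def waiter():
                async with condition:
                    # Predicate loop, mirroring test_condition_predicate_loop
                    while not state["ready"]:
                        await condition.wait()
                return "done"

            async def notifier():
                await asyncio.sleep(0.1)
                async with condition:
                    state["ready"] = True
                    await condition.notify()  # wake the single waiter

            done, _ = await asyncio.gather(waiter(), notifier())
            assert done == "done"


    if __name__ == "__main__":
        asyncio.run(main())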