Skip to content

Commit

Permalink
TaskID
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Oct 24, 2024
1 parent 956b9aa commit 67af153
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions adaptive_scheduler/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from concurrent.futures import Executor, Future
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, NamedTuple

from adaptive import SequenceLearner

Expand Down Expand Up @@ -72,27 +72,32 @@ def shutdown(
self._run_manager.cancel()


class TaskID(NamedTuple):
learner_inded: int
sequence_index: int


class SLURMTask(Future):
"""A `Future` that loads the result from a `SequenceLearner`."""

__slots__ = ("executor", "id_", "_state", "_last_mtime", "min_load_interval")
__slots__ = ("executor", "task_id", "_state", "_last_mtime", "min_load_interval")

def __init__(
self,
executor: SLURMExecutor,
id_: tuple[int, int],
task_id: TaskID,
min_load_interval: float = 1.0,
) -> None:
super().__init__()
self.executor = executor
self.id_ = id_
self.task_id = task_id
self.min_load_interval: float = min_load_interval
self._state: Literal["PENDING", "RUNNING", "FINISHED", "CANCELLED"] = "PENDING"
self._last_mtime: float = 0
self.min_load_interval: float = min_load_interval

def _get(self) -> Any | None:
"""Updates the state of the task and returns the result if the task is finished."""
i_learner, index = self.id_
i_learner, index = self.task_id
learner, fname = self._learner_and_fname(load=False)
if self._state == "FINISHED":
return learner.data[index]
Expand Down Expand Up @@ -125,13 +130,13 @@ def _get(self) -> Any | None:
def __repr__(self) -> str:
if self._state == "PENDING":
self._get()
return f"SLURMTask(id_={self.id_}, state={self._state})"
return f"SLURMTask(task_id={self.task_id}, state={self._state})"

def __str__(self) -> str:
return self.__repr__()

def _learner_and_fname(self, *, load: bool = True) -> tuple[SequenceLearner, str | Path]:
i_learner, _ = self.id_
i_learner, _ = self.task_id
run_manager = self.executor._run_manager
assert run_manager is not None
learner: SequenceLearner = run_manager.learners[i_learner] # type: ignore[index]
Expand Down Expand Up @@ -331,8 +336,8 @@ def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> SLURMT
sequence = self._sequences.setdefault(fn, [])
i = len(sequence)
sequence.append(args)
id_ = (self._sequence_mapping[fn], i)
return SLURMTask(self, id_)
task_id = TaskID(self._sequence_mapping[fn], i)
return SLURMTask(self, task_id)

def _to_learners(self) -> tuple[list[SequenceLearner], list[Path]]:
learners = []
Expand Down

0 comments on commit 67af153

Please sign in to comment.