diff --git a/adaptive_scheduler/_executor.py b/adaptive_scheduler/_executor.py
index df098461..19a0f91f 100644
--- a/adaptive_scheduler/_executor.py
+++ b/adaptive_scheduler/_executor.py
@@ -10,7 +10,7 @@
 from concurrent.futures import Executor, Future
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, NamedTuple
+from typing import TYPE_CHECKING, Any, NamedTuple
 
 from adaptive import SequenceLearner
 
@@ -82,7 +82,7 @@ class TaskID(NamedTuple):
 class SlurmTask(Future):
     """A `Future` that loads the result from a `SequenceLearner`."""
 
-    __slots__ = ("executor", "task_id", "_state", "_last_size", "min_load_interval")
+    __slots__ = ("executor", "task_id", "_state", "_last_size", "min_load_interval", "_load_time")
 
     def __init__(
         self,
@@ -94,8 +94,8 @@ def __init__(
         self.executor = executor
         self.task_id = task_id
         self.min_load_interval: float = min_load_interval
-        self._state: Literal["PENDING", "RUNNING", "FINISHED", "CANCELLED"] = "PENDING"
         self._last_size: float = 0
+        self._load_time: float = 0
 
     def _get(self) -> Any | None:  # noqa: PLR0911
         """Updates the state of the task and returns the result if the task is finished."""
@@ -126,6 +126,8 @@ def _get(self) -> Any | None:  # noqa: PLR0911
             self._last_size = size
             learner.load(fname)
+            self._load_time = time.monotonic() - now
+            self.min_load_interval = max(1.0, 20 * self._load_time)
             self.executor._run_manager._last_load_time[idx_learner] = now
 
         if idx_data in learner.data:
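
The core change in the last hunk is an adaptive throttle: after each `learner.load(fname)` the task records how long the load took and sets `min_load_interval` to 20x that time, floored at 1 second, so expensive files are reloaded far less often than cheap ones. Below is a minimal, standalone sketch of that pattern for illustration only; the class `ThrottledLoader`, its method names, and the use of `read_bytes()` in place of `learner.load()` are hypothetical and not part of `adaptive_scheduler`.

```python
import time
from pathlib import Path


class ThrottledLoader:
    """Sketch of the reload-throttling idea from the diff above.

    The interval between reloads scales with the measured cost of the
    previous reload (20x the load time, never below 1 second).
    """

    def __init__(self, fname: Path, min_load_interval: float = 1.0) -> None:
        self.fname = fname
        self.min_load_interval = min_load_interval
        self._last_load: float = 0.0  # monotonic timestamp of the last reload

    def maybe_load(self) -> bytes | None:
        now = time.monotonic()
        if now - self._last_load < self.min_load_interval:
            return None  # too soon since the last reload; skip this poll
        data = self.fname.read_bytes()  # stand-in for learner.load(fname)
        load_time = time.monotonic() - now
        # Grow the throttle with the observed cost, floored at 1 second.
        self.min_load_interval = max(1.0, 20 * load_time)
        self._last_load = now
        return data
```

With this rule, a load that takes 5 ms keeps the 1-second floor, while a load that takes 2 seconds pushes the interval out to 40 seconds, bounding the fraction of time spent reloading.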