Skip to content

Commit

Permalink
Allow multiple args in *args for SlurmExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Oct 30, 2024
1 parent 95f64e6 commit 12b815f
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions adaptive_scheduler/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple

import cloudpickle
from adaptive import SequenceLearner

import adaptive_scheduler
Expand Down Expand Up @@ -191,6 +192,20 @@ def _uuid_with_datetime() -> str:
return f"{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}-{uuid.uuid4().hex}" # noqa: DTZ005


class _SerializableFunctionSplatter:
def __init__(self, func: Callable[..., Any]) -> None:
self.func = func

def __call__(self, args: Any) -> Any:
return self.func(*args)

def __getstate__(self) -> dict[str, Any]:
return cloudpickle.dumps(self.func)

def __setstate__(self, state: bytes) -> None:
self.func = cloudpickle.loads(state)


@dataclass
class SlurmExecutor(AdaptiveSchedulerExecutorBase):
"""An executor that runs jobs on SLURM.
Expand Down Expand Up @@ -347,20 +362,17 @@ def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> SlurmT
raise ValueError(msg)
if fn not in self._sequence_mapping:
self._sequence_mapping[fn] = len(self._sequence_mapping)
if len(args) != 1:
msg = "Exactly one argument is required"
raise ValueError(msg)
sequence = self._sequences.setdefault(fn, [])
i = len(sequence)
sequence.append(args[0])
sequence.append(args)
task_id = TaskID(self._sequence_mapping[fn], i)
return SlurmTask(self, task_id)

def _to_learners(self) -> tuple[list[SequenceLearner], list[Path]]:
learners = []
fnames = []
for func, args_kwargs_list in self._sequences.items():
learner = SequenceLearner(func, args_kwargs_list)
learner = SequenceLearner(_SerializableFunctionSplatter(func), args_kwargs_list)
learners.append(learner)
assert isinstance(self.folder, Path)
name = func.__name__ if hasattr(func, "__name__") else ""
Expand Down Expand Up @@ -420,10 +432,11 @@ def cleanup(self) -> None:
assert self._run_manager is not None
self._run_manager.cleanup(remove_old_logs_folder=True)

def new(self) -> SlurmExecutor:
def new(self, update: dict[str, Any]) -> SlurmExecutor:
"""Create a new SlurmExecutor with the same parameters."""
data = asdict(self)
data["_run_manager"] = None
data["_sequences"] = {}
data["_sequence_mapping"] = {}
data.update(update)
return SlurmExecutor(**data)

0 comments on commit 12b815f

Please sign in to comment.