Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial MultiRunManager #242

Merged
merged 38 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
66f6fc7
Add MultiRunManager
basnijholt Oct 18, 2024
06f8e01
test
basnijholt Oct 18, 2024
cc3e3eb
Merge remote-tracking branch 'origin/main' into multi-run-manager
basnijholt Oct 18, 2024
e2fe1e0
.
basnijholt Oct 18, 2024
75bdfae
.
basnijholt Oct 18, 2024
a490bf2
.
basnijholt Oct 18, 2024
4f14ab9
.
basnijholt Oct 19, 2024
bb61bda
executor
basnijholt Oct 23, 2024
ac33e8d
fix types
basnijholt Oct 23, 2024
a9d7126
awaitable
basnijholt Oct 23, 2024
4d5b22b
throttle
basnijholt Oct 23, 2024
bfe3d59
await
basnijholt Oct 23, 2024
fc2fe45
set_result
basnijholt Oct 23, 2024
b4b2844
throttle load
basnijholt Oct 24, 2024
94ea467
set
basnijholt Oct 24, 2024
e78c5d2
slots
basnijholt Oct 24, 2024
f41cea6
.
basnijholt Oct 24, 2024
b8aec70
order
basnijholt Oct 24, 2024
5f2a8b6
.
basnijholt Oct 24, 2024
6930961
docs
basnijholt Oct 24, 2024
59558ca
.
basnijholt Oct 24, 2024
24b475f
comment
basnijholt Oct 24, 2024
956b9aa
add str
basnijholt Oct 24, 2024
67af153
TaskID
basnijholt Oct 24, 2024
9d2378d
cleanup
basnijholt Oct 24, 2024
678fd22
rename
basnijholt Oct 24, 2024
728d330
add new
basnijholt Oct 24, 2024
860e5f9
revert
basnijholt Oct 24, 2024
a4f485c
check filesize instead
basnijholt Oct 30, 2024
6390f15
.
basnijholt Oct 30, 2024
e8a8784
fix load
basnijholt Oct 30, 2024
4da2e42
Auto-scale load time
basnijholt Oct 30, 2024
b7f3efa
.
basnijholt Oct 30, 2024
6fd8c64
Merge remote-tracking branch 'origin/main' into multi-run-manager
basnijholt Oct 30, 2024
a212dbb
fix
basnijholt Oct 30, 2024
272ec5d
multiple args
basnijholt Oct 30, 2024
ce7dcbc
rename
basnijholt Oct 30, 2024
483cd33
Merge remote-tracking branch 'origin/main' into multi-run-manager
basnijholt Oct 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 290 additions & 0 deletions adaptive_scheduler/_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
from __future__ import annotations

import abc
import asyncio
import os
import time
import uuid
from concurrent.futures import Executor, Future
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

from adaptive import SequenceLearner

import adaptive_scheduler

if TYPE_CHECKING:
from collections.abc import Callable, Iterable

from adaptive_scheduler.utils import (
_DATAFRAME_FORMATS,
EXECUTOR_TYPES,
LOKY_START_METHODS,
GoalTypes,
)


class AdaptiveSchedulerExecutorBase(Executor):
_run_manager: adaptive_scheduler.RunManager | None

@abc.abstractmethod
def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future:
pass

@abc.abstractmethod
def finalize(self, *, start: bool = True) -> adaptive_scheduler.RunManager:
"""Finalize the executor and return the RunManager."""

def map( # type: ignore[override]
self,
fn: Callable[..., Any],
/,
*iterables: Iterable[Any],
timeout: float | None = None,
chunksize: int = 1,
) -> list[Future]:
tasks = []
if timeout is not None:
msg = "Timeout not implemented"
raise NotImplementedError(msg)
if chunksize != 1:
msg = "Chunksize not implemented"
raise NotImplementedError(msg)
for args in zip(*iterables, strict=True):
task = self.submit(fn, *args)
tasks.append(task)
return tasks

def shutdown(
self,
wait: bool = True, # noqa: FBT001, FBT002
*,
cancel_futures: bool = False,
) -> None:
if not wait:
msg = "Non-waiting shutdown not implemented"
raise NotImplementedError(msg)
if cancel_futures:
msg = "Cancelling futures not implemented"
raise NotImplementedError(msg)
if self._run_manager is not None:
self._run_manager.cancel()


class SLURMTask(Future):
    """A `Future` that loads its result from a `SequenceLearner`'s data file.

    The task is identified by ``id_ = (i_learner, index)``: the learner's
    position in the executor's run and the index of this point within that
    learner's sequence.  The result becomes available once the running job
    saves the learner's data to disk and it is (re)loaded here.
    """

    __slots__ = ("executor", "id_", "_state", "_last_mtime", "min_load_interval")

    def __init__(
        self,
        executor: SLURMExecutor,
        id_: tuple[int, int],
        min_load_interval: float = 1.0,
    ) -> None:
        """Create a task for the point ``id_ = (i_learner, index)``.

        Parameters
        ----------
        executor
            The `SLURMExecutor` this task belongs to.
        id_
            ``(learner index, point index within the learner's sequence)``.
        min_load_interval
            Minimum number of seconds between reloads of the same learner's
            data file, to throttle disk access.
        """
        super().__init__()
        self.executor = executor
        self.id_ = id_
        self._state: Literal["PENDING", "RUNNING", "FINISHED", "CANCELLED"] = "PENDING"
        self._last_mtime: float = 0
        self.min_load_interval: float = min_load_interval

    def _get(self) -> Any | None:
        """Update the state of the task; return the result if finished, else ``None``."""
        i_learner, index = self.id_
        learner, fname = self._learner_and_fname(load=False)
        # If the result was already recorded, return it immediately.  This
        # must happen *before* the load-throttle below: otherwise `result()`
        # would transiently return `None` for an already-finished task
        # whenever the learner's file was loaded within `min_load_interval`.
        if self._state == "FINISHED":
            return learner.data[index]
        assert self.executor._run_manager is not None
        last_load_time = self.executor._run_manager._last_load_time.get(i_learner, 0)
        now = time.monotonic()
        time_since_last_load = now - last_load_time
        if time_since_last_load < self.min_load_interval:
            # This learner's file was loaded too recently; skip the disk hit.
            return None

        try:
            mtime = os.path.getmtime(fname)  # noqa: PTH204
        except FileNotFoundError:
            # The job has not saved any data yet.
            return None

        if self._last_mtime == mtime:
            # File unchanged since our last load; nothing new to pick up.
            return None

        self._last_mtime = mtime
        learner.load(fname)
        self.executor._run_manager._last_load_time[i_learner] = now

        if index in learner.data:
            result = learner.data[index]
            self.set_result(result)  # flips `_state` to "FINISHED"
            return result
        return None

    def __repr__(self) -> str:
        if self._state == "PENDING":
            # Refresh state so the repr reflects work that finished on disk.
            self._get()
        return f"SLURMTask(id_={self.id_}, state={self._state})"

    def _learner_and_fname(self, *, load: bool = True) -> tuple[SequenceLearner, str | Path]:
        """Return this task's learner and its data filename, optionally loading the file."""
        i_learner, _ = self.id_
        run_manager = self.executor._run_manager
        assert run_manager is not None
        learner = run_manager.learners[i_learner]
        fname = run_manager.fnames[i_learner]
        if load and not learner.done():
            learner.load(fname)
        return learner, fname

    def result(self, timeout: float | None = None) -> Any:
        """Return the result, raising `RuntimeError` if the task is not finished.

        Unlike `concurrent.futures.Future.result`, this does not block;
        ``timeout`` is not supported.
        """
        if timeout is not None:
            msg = "Timeout not implemented for SLURMTask"
            raise NotImplementedError(msg)
        if self.executor._run_manager is None:
            msg = "RunManager not initialized. Call finalize() first."
            raise RuntimeError(msg)
        result = self._get()
        if self._state == "FINISHED":
            return result
        msg = "Task not finished"
        raise RuntimeError(msg)

    def __await__(self) -> Any:
        """Make the task awaitable by polling `_get` once per second on the event loop."""

        def wakeup() -> None:
            if not self.done():
                self._get()
                loop.call_later(1, wakeup)  # Schedule next check after 1 second
            else:
                fut.set_result(self.result())

        loop = asyncio.get_event_loop()
        fut = loop.create_future()
        loop.call_soon(wakeup)
        yield from fut
        return self.result()

    async def __aiter__(self) -> Any:
        # NOTE(review): despite the name this is not an async-iterator
        # protocol method — it simply awaits the task and returns the result.
        await self
        return self.result()


@dataclass
class SLURMExecutor(AdaptiveSchedulerExecutorBase):
    """Executor that maps submitted calls onto `SequenceLearner`s run via `slurm_run`.

    Same as `slurm_run`, except it has no dependencies and initializers.
    Additionally, the type hints for scheduler arguments are singular instead of tuples.
    """

    # Specific to slurm_run
    name: str = "adaptive-scheduler"
    # NOTE(review): declared as `str | Path` with default "", but `__post_init__`
    # checks for `None` — the auto-generated-folder branch only triggers when a
    # caller explicitly passes `folder=None`; confirm which default is intended.
    folder: str | Path = ""
    # SLURM scheduler arguments
    partition: str | None = None
    nodes: int | None = 1
    cores_per_node: int | None = None
    num_threads: int = 1
    exclusive: bool = False
    executor_type: EXECUTOR_TYPES = "process-pool"
    extra_scheduler: list[str] | None = None
    # Same as RunManager below (except dependencies and initializers)
    goal: GoalTypes | None = None
    check_goal_on_start: bool = True
    runner_kwargs: dict | None = None
    url: str | None = None
    save_interval: float = 300
    log_interval: float = 300
    job_manager_interval: float = 60
    kill_interval: float = 60
    kill_on_error: str | Callable[[list[str]], bool] | None = "srun: error:"
    overwrite_db: bool = True
    job_manager_kwargs: dict[str, Any] | None = None
    kill_manager_kwargs: dict[str, Any] | None = None
    loky_start_method: LOKY_START_METHODS = "loky"
    cleanup_first: bool = True
    save_dataframe: bool = True
    dataframe_format: _DATAFRAME_FORMATS = "pickle"
    max_log_lines: int = 500
    max_fails_per_job: int = 50
    max_simultaneous_jobs: int = 100
    quiet: bool = True  # `slurm_run` defaults to `False`
    # RunManager arguments
    extra_run_manager_kwargs: dict[str, Any] | None = None
    extra_scheduler_kwargs: dict[str, Any] | None = None
    # Internal: per-function list of submitted ``args`` tuples, the function's
    # learner index, and the RunManager created by `finalize`.
    _sequences: dict[Callable[..., Any], list[Any]] = field(default_factory=dict)
    _sequence_mapping: dict[Callable[..., Any], int] = field(default_factory=dict)
    _run_manager: adaptive_scheduler.RunManager | None = None

    def __post_init__(self) -> None:
        """Normalize ``folder`` to a `Path`, generating a unique one if it is ``None``."""
        if self.folder is None:
            self.folder = Path.cwd() / ".adaptive_scheduler" / uuid.uuid4().hex  # type: ignore[operator]
        else:
            self.folder = Path(self.folder)

    def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> SLURMTask:
        """Queue ``fn(*args)`` and return a `SLURMTask` for its eventual result.

        Keyword arguments are rejected because each call is stored as a bare
        ``args`` tuple in a `SequenceLearner` sequence.  The task id is
        ``(index of fn's learner, index of this call within its sequence)``.
        """
        if kwargs:
            msg = "Keyword arguments are not supported"
            raise ValueError(msg)
        # First submission of `fn` assigns it the next learner index.
        if fn not in self._sequence_mapping:
            self._sequence_mapping[fn] = len(self._sequence_mapping)
        sequence = self._sequences.setdefault(fn, [])
        i = len(sequence)
        sequence.append(args)
        id_ = (self._sequence_mapping[fn], i)
        return SLURMTask(self, id_)

    def _to_learners(self) -> tuple[list[SequenceLearner], list[Path]]:
        """Build one `SequenceLearner` (and its data filename) per submitted function.

        NOTE(review): filenames derive from ``func.__name__``, so two distinct
        functions with the same name would share a data file — verify callers
        cannot hit this collision.
        """
        learners = []
        fnames = []
        for func, args_kwargs_list in self._sequences.items():
            learner = SequenceLearner(func, args_kwargs_list)
            learners.append(learner)
            assert isinstance(self.folder, Path)
            fnames.append(self.folder / f"{func.__name__}.pickle")
        return learners, fnames

    def finalize(self, *, start: bool = True) -> adaptive_scheduler.RunManager:
        """Create the `RunManager` via `slurm_run`, optionally start it, and return it."""
        learners, fnames = self._to_learners()
        assert self.folder is not None
        self._run_manager = adaptive_scheduler.slurm_run(
            learners=learners,
            fnames=fnames,
            # Specific to slurm_run
            name=self.name,
            folder=self.folder,
            # SLURM scheduler arguments
            partition=self.partition,
            nodes=self.nodes,
            cores_per_node=self.cores_per_node,
            num_threads=self.num_threads,
            exclusive=self.exclusive,
            executor_type=self.executor_type,
            extra_scheduler=self.extra_scheduler,
            # Same as RunManager below (except job_name, move_old_logs_to, and db_fname)
            goal=self.goal,
            check_goal_on_start=self.check_goal_on_start,
            runner_kwargs=self.runner_kwargs,
            url=self.url,
            save_interval=self.save_interval,
            log_interval=self.log_interval,
            job_manager_interval=self.job_manager_interval,
            kill_interval=self.kill_interval,
            kill_on_error=self.kill_on_error,
            overwrite_db=self.overwrite_db,
            job_manager_kwargs=self.job_manager_kwargs,
            kill_manager_kwargs=self.kill_manager_kwargs,
            loky_start_method=self.loky_start_method,
            cleanup_first=self.cleanup_first,
            save_dataframe=self.save_dataframe,
            dataframe_format=self.dataframe_format,
            max_log_lines=self.max_log_lines,
            max_fails_per_job=self.max_fails_per_job,
            max_simultaneous_jobs=self.max_simultaneous_jobs,
            quiet=self.quiet,
            # RunManager arguments
            extra_run_manager_kwargs=self.extra_run_manager_kwargs,
            extra_scheduler_kwargs=self.extra_scheduler_kwargs,
        )
        if start:
            self._run_manager.start()
        return self._run_manager

    # Currently a no-op placeholder; presumably reserved for future resource
    # cleanup — TODO confirm intended behavior.
    def cleanup(self) -> None: ...
Fixed Show fixed Hide fixed
5 changes: 3 additions & 2 deletions adaptive_scheduler/_server_support/database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __init__(
dependencies: dict[int, list[int]] | None = None,
overwrite_db: bool = True,
initializers: list[Callable[[], None]] | None = None,
with_progress_bar: bool = True,
) -> None:
super().__init__()
self.url = url
Expand All @@ -194,7 +195,7 @@ def __init__(
self.dependencies = dependencies or {}
self.overwrite_db = overwrite_db
self.initializers = initializers

self.with_progress_bar = with_progress_bar
self._last_reply: str | list[str] | Exception | None = None
self._last_request: tuple[str, ...] | None = None
self.failed: list[dict[str, Any]] = []
Expand All @@ -210,7 +211,7 @@ def _setup(self) -> None:
self.learners,
self.fnames,
initializers=self.initializers,
with_progress_bar=True,
with_progress_bar=self.with_progress_bar,
)

def update(self, queue: dict[str, dict[str, str]] | None = None) -> None:
Expand Down
Loading
Loading