diff --git a/neps/runtime.py b/neps/runtime.py
index c9988f70..c7298530 100644
--- a/neps/runtime.py
+++ b/neps/runtime.py
@@ -6,9 +6,10 @@
 import logging
 import os
 import shutil
+import signal
 import time
-from contextlib import contextmanager
-from dataclasses import dataclass
+from contextlib import contextmanager, suppress
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
@@ -45,6 +46,12 @@ def _default_worker_name() -> str:
     return f"{os.getpid()}-{isoformat}"
 
 
+SIGNALS_TO_HANDLE_IF_AVAILABLE = [
+    "SIGINT",
+    "SIGTERM",
+]
+
+
 N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 10
 N_FAILED_TO_SET_TRIAL_STATE = 10
 
@@ -145,6 +152,11 @@ class DefaultWorker(Generic[Loc]):
     worker_cumulative_evaluation_time_seconds: float = 0.0
     """The time spent evaluating configurations by this worker."""
 
+    # Handlers that were active before this worker installed its own, keyed by signal.
+    # NOTE(review): `signal._HANDLER` is a typeshed-only alias; this annotation relies
+    # on postponed evaluation (`from __future__ import annotations`) -- confirm.
+    _PREVIOUS_SIGNAL_HANDLERS: dict[int, signal._HANDLER] = field(default_factory=dict)
+
     @classmethod
     def new(
         cls,
@@ -337,12 +349,28 @@ def _check_if_should_stop(  # noqa: C901, PLR0912, PLR0911
 
         return False
 
+    def _set_signal_handlers(self) -> None:
+        """Install `_emergency_cleanup` for each listed signal this platform defines.
+
+        The handler that was registered before is remembered so that
+        `_emergency_cleanup` can chain to it.
+        """
+        for name in SIGNALS_TO_HANDLE_IF_AVAILABLE:
+            if hasattr(signal.Signals, name):
+                sig = getattr(signal.Signals, name)
+                # HACK: Despite what python documentation says, the existence of a signal
+                # is not enough to guarantee that it can be caught.
+                with suppress(ValueError):
+                    previous_signal_handler = signal.signal(sig, self._emergency_cleanup)
+                    self._PREVIOUS_SIGNAL_HANDLERS[sig] = previous_signal_handler
+
     def run(self) -> None:  # noqa: C901, PLR0915
         """Run the worker.
 
         Will keep running until one of the criterion defined by the `WorkerSettings`
         is met.
         """
+        self._set_signal_handlers()
         _set_workers_neps_state(self.state)
 
         logger.info("Launching NePS")
@@ -416,15 +444,21 @@ def run(self) -> None:  # noqa: C901, PLR0915
                 continue
 
             # We (this worker) has managed to set it to evaluating, now we can evaluate it
-            with _set_global_trial(trial_to_eval):
-                evaluated_trial, report = evaluate_trial(
-                    trial=trial_to_eval,
-                    evaluation_fn=self.evaluation_fn,
-                    default_report_values=self.settings.default_report_values,
-                )
-                evaluation_duration = evaluated_trial.metadata.evaluation_duration
-                assert evaluation_duration is not None
-                self.worker_cumulative_evaluation_time_seconds += evaluation_duration
+            try:
+                with _set_global_trial(trial_to_eval):
+                    evaluated_trial, report = evaluate_trial(
+                        trial=trial_to_eval,
+                        evaluation_fn=self.evaluation_fn,
+                        default_report_values=self.settings.default_report_values,
+                    )
+            except KeyboardInterrupt as e:
+                # This throws and we have stopped the worker at this point
+                self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e)
+                return
+
+            evaluation_duration = evaluated_trial.metadata.evaluation_duration
+            assert evaluation_duration is not None
+            self.worker_cumulative_evaluation_time_seconds += evaluation_duration
 
             self.worker_cumulative_eval_count += 1
 
@@ -460,6 +494,52 @@ def run(self) -> None:  # noqa: C901, PLR0915
                     "Learning Curve %s: %s", evaluated_trial.id, report.learning_curve
                 )
 
+    def _emergency_cleanup(
+        self,
+        signum: int,
+        frame: Any,
+        rethrow: KeyboardInterrupt | None = None,
+    ) -> None:
+        """Reset the in-flight trial to pending and stop the worker.
+
+        Args:
+            signum: The signal that was received.
+            frame: The stack frame at the time of the signal; passed on to the
+                previously installed handler.
+            rethrow: If given, this exception is re-raised instead of a fresh
+                `KeyboardInterrupt`.
+
+        Raises:
+            KeyboardInterrupt: Always, unless the previous handler raises first.
+        """
+        global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS  # noqa: PLW0603
+        logger.error(
+            "Worker '%s' received signal %d. Stopping worker now!",
+            self.worker_id,
+            signum,
+        )
+        if _CURRENTLY_RUNNING_TRIAL_IN_PROCESS is not None:
+            logger.error(
+                "Worker '%s' was interrupted while evaluating trial: %s. Setting"
+                " trial to pending!",
+                self.worker_id,
+                _CURRENTLY_RUNNING_TRIAL_IN_PROCESS.id,
+            )
+            _CURRENTLY_RUNNING_TRIAL_IN_PROCESS.reset()
+            try:
+                self.state.put_updated_trial(_CURRENTLY_RUNNING_TRIAL_IN_PROCESS)
+            except NePSError as e:
+                logger.exception(e)
+            finally:
+                _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None
+
+        previous_handler = self._PREVIOUS_SIGNAL_HANDLERS.get(signum)
+        if callable(previous_handler):
+            previous_handler(signum, frame)
+        if rethrow is not None:
+            raise rethrow
+        raise KeyboardInterrupt(f"Worker was interrupted by signal {signum}.")
+
 
 # TODO: This should be done directly in `api.run` at some point to make it clearer at an
 # entryy point how the woerer is set up to run if someone reads the entry point code.
diff --git a/pyproject.toml b/pyproject.toml
index 06b4baa4..400aaec6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,6 +72,7 @@ pre-commit = "^3"
 mypy = "^1"
 pytest = "^7"
 pytest-cases = "^3"
+psutil = "^6"
 types-PyYAML = "^6"
 mkdocs-material = "*"
 mkdocs-autorefs = "*"
diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py
index 5e819448..d357ec90 100644
--- a/tests/test_runtime/test_error_handling_strategies.py
+++ b/tests/test_runtime/test_error_handling_strategies.py
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
 import pytest
 from dataclasses import dataclass
 from pandas.core.common import contextlib
+import signal
 from pathlib import Path
 from pytest_cases import fixture, parametrize
+import multiprocessing
+import psutil
+import time
 
 from neps.optimizers.random_search.optimizer import RandomSearch
-from neps.runtime import DefaultWorker
+from neps.runtime import DefaultWorker, SIGNALS_TO_HANDLE_IF_AVAILABLE
 from neps.search_spaces.search_space import SearchSpace
 from neps.state.err_dump import SerializedError
 from neps.state.filebased import create_or_load_filebased_neps_state
@@ -126,14 +130,7 @@ def evaler(*args, **kwargs) -> float:
     assert len(neps_state.get_errors()) == 1
 
 
-@pytest.mark.parametrize(
-    "on_error",
-    [OnErrorPossibilities.IGNORE, OnErrorPossibilities.RAISE_WORKER_ERROR],
-)
-def test_worker_does_not_raise_when_error_in_other_worker(
-    neps_state: NePSState,
-    on_error: OnErrorPossibilities,
-) -> None:
+def test_worker_does_not_raise_when_error_in_other_worker(neps_state: NePSState) -> None:
     optimizer = RandomSearch(pipeline_space=SearchSpace(a=FloatParameter(0, 1)))
     settings = WorkerSettings(
         on_error=OnErrorPossibilities.RAISE_WORKER_ERROR,  # <- Highlight
@@ -198,3 +195,71 @@ def __call__(self, *args, **kwargs) -> float:
 
     assert neps_state.get_next_pending_trial() is None
     assert len(neps_state.get_errors()) == 1
+
+
+def sleep_function(*args, **kwargs) -> float:
+    time.sleep(20)
+    return 10
+
+
+SIGNALS: list[signal.Signals] = [
+    getattr(signal.Signals, name)
+    for name in SIGNALS_TO_HANDLE_IF_AVAILABLE
+    if hasattr(signal.Signals, name)
+]
+
+
+@pytest.mark.ci_examples
+@pytest.mark.parametrize("signum", SIGNALS)
+def test_worker_reset_evaluating_to_pending_on_ctrl_c(
+    signum: signal.Signals,
+    neps_state: NePSState,
+) -> None:
+    optimizer = RandomSearch(pipeline_space=SearchSpace(a=FloatParameter(0, 1)))
+    settings = WorkerSettings(
+        on_error=OnErrorPossibilities.IGNORE,  # <- Highlight
+        default_report_values=DefaultReportValues(),
+        max_evaluations_total=None,
+        include_in_progress_evaluations_towards_maximum=False,
+        max_cost_total=None,
+        max_evaluations_for_worker=1,
+        max_evaluation_time_total_seconds=None,
+        max_wallclock_time_for_worker_seconds=None,
+        max_evaluation_time_for_worker_seconds=None,
+        max_cost_for_worker=None,
+    )
+
+    worker1 = DefaultWorker.new(
+        state=neps_state,
+        optimizer=optimizer,
+        evaluation_fn=sleep_function,
+        settings=settings,
+        _pre_sample_hooks=None,
+    )
+
+    # Use multiprocessing.Process
+    p = multiprocessing.Process(target=worker1.run)
+    p.start()
+
+    time.sleep(5)
+    assert p.pid is not None
+    assert p.is_alive()
+
+    # Should be evaluating at this stage
+    trials = neps_state.get_all_trials()
+    assert len(trials) == 1
+    assert next(iter(trials.values())).state == Trial.State.EVALUATING
+
+    # Kill the process while it's evaluating using signals
+    process = psutil.Process(p.pid)
+    process.send_signal(signum)
+    p.join(timeout=10)  # Wait for the process to terminate
+
+    if p.is_alive():
+        p.terminate()  # Force terminate if it's still alive
+        p.join()
+        pytest.fail("Worker did not terminate after receiving signal!")
+    else:
+        trials2 = neps_state.get_all_trials()
+        assert len(trials2) == 1
+        assert next(iter(trials2.values())).state == Trial.State.PENDING