Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handle KeyboardInterrupt, SIGINT, SIGTERM gracefully #129

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 71 additions & 10 deletions neps/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import logging
import os
import shutil
import signal
import time
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand All @@ -22,6 +23,8 @@
TypeVar,
)

from pandas.core.common import contextlib

from neps.exceptions import (
NePSError,
VersionMismatchError,
Expand All @@ -45,6 +48,12 @@ def _default_worker_name() -> str:
return f"{os.getpid()}-{isoformat}"


# Names (as found on `signal.Signals`) of the signals the worker tries to install
# a handler for. A name missing on the current platform is simply skipped.
SIGNALS_TO_HANDLE_IF_AVAILABLE = ["SIGINT", "SIGTERM"]


# Retry budgets. NOTE(review): their consumers are outside this excerpt; judging by
# the names they bound attempts at fetching a pending trial / setting a trial's
# state before an error is raised -- confirm against the worker loop.
N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 10
N_FAILED_TO_SET_TRIAL_STATE = 10

Expand Down Expand Up @@ -145,6 +154,8 @@ class DefaultWorker(Generic[Loc]):
worker_cumulative_evaluation_time_seconds: float = 0.0
"""The time spent evaluating configurations by this worker."""

_PREVIOUS_SIGNAL_HANDLERS: dict[int, signal._HANDLER] = field(default_factory=dict)

@classmethod
def new(
cls,
Expand Down Expand Up @@ -337,12 +348,23 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911

return False

def _set_signal_handlers(self) -> None:
    """Install `_emergency_cleanup` as the handler for every supported signal.

    For each name in `SIGNALS_TO_HANDLE_IF_AVAILABLE` that exists on
    `signal.Signals`, the handler is swapped in and the handler that was
    installed before is remembered in `_PREVIOUS_SIGNAL_HANDLERS`, so that
    `_emergency_cleanup` can chain to it.

    Signals that cannot be hooked (``signal.signal`` raises `ValueError`) are
    skipped silently; nothing is recorded for them.
    """
    for name in SIGNALS_TO_HANDLE_IF_AVAILABLE:
        sig = getattr(signal.Signals, name, None)
        if sig is None:
            # The signal is not defined on this platform.
            continue

        # HACK: Despite what python documentation says, the existence of a signal
        # is not enough to guarantee that it can be caught.
        try:
            previous_signal_handler = signal.signal(sig, self._emergency_cleanup)
        except ValueError:
            continue

        # If the handlers are installed a second time, the "previous" handler is
        # our own. Recording it would make `_emergency_cleanup` call itself
        # recursively, so keep whatever was recorded the first time instead.
        if previous_signal_handler != self._emergency_cleanup:
            self._PREVIOUS_SIGNAL_HANDLERS[sig] = previous_signal_handler

def run(self) -> None: # noqa: C901, PLR0915
"""Run the worker.

Will keep running until one of the criterion defined by the `WorkerSettings`
is met.
"""
self._set_signal_handlers()
_set_workers_neps_state(self.state)

logger.info("Launching NePS")
Expand Down Expand Up @@ -416,15 +438,21 @@ def run(self) -> None: # noqa: C901, PLR0915
continue

# We (this worker) have managed to set it to evaluating, now we can evaluate it
with _set_global_trial(trial_to_eval):
evaluated_trial, report = evaluate_trial(
trial=trial_to_eval,
evaluation_fn=self.evaluation_fn,
default_report_values=self.settings.default_report_values,
)
evaluation_duration = evaluated_trial.metadata.evaluation_duration
assert evaluation_duration is not None
self.worker_cumulative_evaluation_time_seconds += evaluation_duration
try:
with _set_global_trial(trial_to_eval):
evaluated_trial, report = evaluate_trial(
trial=trial_to_eval,
evaluation_fn=self.evaluation_fn,
default_report_values=self.settings.default_report_values,
)
except KeyboardInterrupt as e:
# This throws and we have stopped the worker at this point
self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e)
return

evaluation_duration = evaluated_trial.metadata.evaluation_duration
assert evaluation_duration is not None
self.worker_cumulative_evaluation_time_seconds += evaluation_duration

self.worker_cumulative_eval_count += 1

Expand Down Expand Up @@ -460,6 +488,39 @@ def run(self) -> None: # noqa: C901, PLR0915
"Learning Curve %s: %s", evaluated_trial.id, report.learning_curve
)

def _emergency_cleanup(
    self,
    signum: int,
    frame: Any,
    rethrow: KeyboardInterrupt | None = None,
) -> None:
    """Release the in-flight trial, chain to the previous handler, then abort.

    This has the `(signum, frame)` signature of a signal handler and is also
    called directly with a caught `KeyboardInterrupt` passed as `rethrow`.

    If a trial is registered in `_CURRENTLY_RUNNING_TRIAL_IN_PROCESS`, it is
    reset and written back to the shared state (a `NePSError` while doing so is
    logged, not raised), and the global is cleared in every case.

    Args:
        signum: Number of the signal that triggered the cleanup.
        frame: Frame object supplied with the signal; it is only forwarded to
            the previously installed handler.
        rethrow: An already raised `KeyboardInterrupt` to re-raise instead of
            creating a new one.

    Raises:
        KeyboardInterrupt: Always, unless the previously installed handler
            raises (or exits) first. Either `rethrow` or a new exception
            naming `signum`.
    """
    global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS  # noqa: PLW0603
    # `%d` renders the signal number exactly as the f-string did, including when
    # an `IntEnum` member such as `signal.SIGINT` is passed.
    logger.error(
        "Worker '%s' received signal %d. Stopping worker now!",
        self.worker_id,
        signum,
    )
    if _CURRENTLY_RUNNING_TRIAL_IN_PROCESS is not None:
        logger.error(
            "Worker '%s' was interrupted while evaluating trial: %s. Setting"
            " trial to pending!",
            self.worker_id,
            _CURRENTLY_RUNNING_TRIAL_IN_PROCESS.id,
        )
        _CURRENTLY_RUNNING_TRIAL_IN_PROCESS.reset()
        try:
            self.state.put_updated_trial(_CURRENTLY_RUNNING_TRIAL_IN_PROCESS)
        except NePSError as e:
            # Best effort: the worker is going down anyway, so only log.
            logger.exception(e)
        finally:
            _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None

    # Give whoever owned the signal before us a chance to react as well. Entries
    # that are not callable (e.g. `signal.SIG_DFL` / `signal.SIG_IGN`) are skipped.
    # NOTE: the previous handler may itself raise, in which case the lines below
    # are never reached.
    previous_handler = self._PREVIOUS_SIGNAL_HANDLERS.get(signum)
    if previous_handler is not None and callable(previous_handler):
        previous_handler(signum, frame)
    if rethrow is not None:
        raise rethrow
    raise KeyboardInterrupt(f"Worker was interrupted by signal {signum}.")


# TODO: This should be done directly in `api.run` at some point to make it clearer at an
# entry point how the worker is set up to run if someone reads the entry point code.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pre-commit = "^3"
mypy = "^1"
pytest = "^7"
pytest-cases = "^3"
psutil = "^6"
types-PyYAML = "^6"
mkdocs-material = "*"
mkdocs-autorefs = "*"
Expand Down
84 changes: 75 additions & 9 deletions tests/test_runtime/test_error_handling_strategies.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

import pytest
import os
from dataclasses import dataclass
from pandas.core.common import contextlib
import signal
from pathlib import Path
from pytest_cases import fixture, parametrize
import multiprocessing
import psutil
import time

from neps.optimizers.random_search.optimizer import RandomSearch
from neps.runtime import DefaultWorker
from neps.runtime import DefaultWorker, SIGNALS_TO_HANDLE_IF_AVAILABLE
from neps.search_spaces.search_space import SearchSpace
from neps.state.err_dump import SerializedError
from neps.state.filebased import create_or_load_filebased_neps_state
Expand Down Expand Up @@ -126,14 +131,7 @@ def evaler(*args, **kwargs) -> float:
assert len(neps_state.get_errors()) == 1


@pytest.mark.parametrize(
"on_error",
[OnErrorPossibilities.IGNORE, OnErrorPossibilities.RAISE_WORKER_ERROR],
)
def test_worker_does_not_raise_when_error_in_other_worker(
neps_state: NePSState,
on_error: OnErrorPossibilities,
) -> None:
def test_worker_does_not_raise_when_error_in_other_worker(neps_state: NePSState) -> None:
optimizer = RandomSearch(pipeline_space=SearchSpace(a=FloatParameter(0, 1)))
settings = WorkerSettings(
on_error=OnErrorPossibilities.RAISE_WORKER_ERROR, # <- Highlight
Expand Down Expand Up @@ -198,3 +196,71 @@ def __call__(self, *args, **kwargs) -> float:

assert neps_state.get_next_pending_trial() is None
assert len(neps_state.get_errors()) == 1


def sleep_function(*args, **kwargs) -> float:
    """Stand-in evaluation function: block for 20 seconds, then return 10.

    All arguments are accepted and ignored.
    """
    seconds_to_block = 20
    time.sleep(seconds_to_block)
    return 10


# The members of `signal.Signals` that exist on this platform, out of the names
# listed in `SIGNALS_TO_HANDLE_IF_AVAILABLE`; used to parametrize the test below.
SIGNALS: list[signal.Signals] = [
    getattr(signal.Signals, name)
    for name in SIGNALS_TO_HANDLE_IF_AVAILABLE
    if hasattr(signal.Signals, name)
]


@pytest.mark.ci_examples
@pytest.mark.parametrize("signum", SIGNALS)
def test_worker_reset_evaluating_to_pending_on_ctrl_c(
    signum: signal.Signals,
    neps_state: NePSState,
) -> None:
    """A worker signalled mid-evaluation must leave its trial in the PENDING state.

    The worker runs `sleep_function` in a child process, is sent `signum` while
    the single trial is EVALUATING, and the trial is expected to be PENDING once
    the child has exited.
    """
    random_search = RandomSearch(pipeline_space=SearchSpace(a=FloatParameter(0, 1)))
    worker_settings = WorkerSettings(
        on_error=OnErrorPossibilities.IGNORE,  # <- Highlight
        default_report_values=DefaultReportValues(),
        max_evaluations_total=None,
        include_in_progress_evaluations_towards_maximum=False,
        max_cost_total=None,
        max_evaluations_for_worker=1,
        max_evaluation_time_total_seconds=None,
        max_wallclock_time_for_worker_seconds=None,
        max_evaluation_time_for_worker_seconds=None,
        max_cost_for_worker=None,
    )
    worker = DefaultWorker.new(
        state=neps_state,
        optimizer=random_search,
        evaluation_fn=sleep_function,
        settings=worker_settings,
        _pre_sample_hooks=None,
    )

    # Run the worker in its own process so that it can be signalled from here.
    child = multiprocessing.Process(target=worker.run)
    child.start()

    # Give the child time to sample a trial and start evaluating it.
    time.sleep(5)
    assert child.pid is not None
    assert child.is_alive()

    # Exactly one trial should exist, and it should be mid-evaluation.
    trials_before = neps_state.get_all_trials()
    assert len(trials_before) == 1
    assert list(trials_before.values())[0].state == Trial.State.EVALUATING

    # Interrupt the child while it is evaluating and wait for it to exit.
    psutil.Process(child.pid).send_signal(signum)
    child.join(timeout=10)

    if child.is_alive():
        # The signal did not stop the worker: clean up the child and fail.
        child.terminate()
        child.join()
        pytest.fail("Worker did not terminate after receiving signal!")

    # The same single trial must have been handed back as pending.
    trials_after = neps_state.get_all_trials()
    assert len(trials_after) == 1
    assert list(trials_after.values())[0].state == Trial.State.PENDING
Loading