From 8d82d40001d0140b083a958db4553d7e597070f0 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 5 Aug 2024 20:24:33 +0200 Subject: [PATCH 01/14] fix: Handle KeyboardInterrupt, SIGINT, SIGTERM gracefully --- neps/runtime.py | 53 ++++++++++++--- .../test_error_handling_strategies.py | 64 +++++++++++++++++++ 2 files changed, 108 insertions(+), 9 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index c9988f70..ec9b6843 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -6,6 +6,7 @@ import logging import os import shutil +import signal import time from contextlib import contextmanager from dataclasses import dataclass @@ -145,6 +146,8 @@ class DefaultWorker(Generic[Loc]): worker_cumulative_evaluation_time_seconds: float = 0.0 """The time spent evaluating configurations by this worker.""" + _SIGNAL_HANDLER_FIRED: bool = False + @classmethod def new( cls, @@ -344,6 +347,8 @@ def run(self) -> None: # noqa: C901, PLR0915 is met. """ _set_workers_neps_state(self.state) + signal.signal(signal.SIGINT, self._emergency_cleanup) + signal.signal(signal.SIGTERM, self._emergency_cleanup) logger.info("Launching NePS") @@ -416,15 +421,20 @@ def run(self) -> None: # noqa: C901, PLR0915 continue # We (this worker) has managed to set it to evaluating, now we can evaluate it - with _set_global_trial(trial_to_eval): - evaluated_trial, report = evaluate_trial( - trial=trial_to_eval, - evaluation_fn=self.evaluation_fn, - default_report_values=self.settings.default_report_values, - ) - evaluation_duration = evaluated_trial.metadata.evaluation_duration - assert evaluation_duration is not None - self.worker_cumulative_evaluation_time_seconds += evaluation_duration + try: + with _set_global_trial(trial_to_eval): + evaluated_trial, report = evaluate_trial( + trial=trial_to_eval, + evaluation_fn=self.evaluation_fn, + default_report_values=self.settings.default_report_values, + ) + except KeyboardInterrupt: + if not self._SIGNAL_HANDLER_FIRED: + self._emergency_cleanup(signum=signal.SIGINT, frame=None) + + evaluation_duration = evaluated_trial.metadata.evaluation_duration + assert evaluation_duration is not None + self.worker_cumulative_evaluation_time_seconds += evaluation_duration self.worker_cumulative_eval_count += 1 @@ -460,6 +470,31 @@ def run(self) -> None: # noqa: C901, PLR0915 "Learning Curve %s: %s", evaluated_trial.id, report.learning_curve ) + def _emergency_cleanup(self, signum: int, frame: Any) -> None: # noqa: ARG002 + """Handle signals.""" + self._SIGNAL_HANDLER_FIRED = True + + global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS # noqa: PLW0603 + logger.info( + f"Worker '{self.worker_id}' received signal {signum}. Stopping worker now!" + ) + if _CURRENTLY_RUNNING_TRIAL_IN_PROCESS is not None: + logger.info( + "Worker '%s' was interrupted while evaluating trial: %s. Setting" + " trial to pending!", + self.worker_id, + _CURRENTLY_RUNNING_TRIAL_IN_PROCESS.id, + ) + _CURRENTLY_RUNNING_TRIAL_IN_PROCESS.reset() + try: + self.state.put_updated_trial(_CURRENTLY_RUNNING_TRIAL_IN_PROCESS) + except NePSError as e: + logger.exception(e) + finally: + _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None + + raise KeyboardInterrupt(f"Worker was interrupted by signal {signum}.") + # TODO: This should be done directly in `api.run` at some point to make it clearer at an # entryy point how the woerer is set up to run if someone reads the entry point code. diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index 5e819448..c85ac651 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -1,10 +1,14 @@ from __future__ import annotations import pytest +import os from dataclasses import dataclass from pandas.core.common import contextlib +import signal from pathlib import Path from pytest_cases import fixture, parametrize +from multiprocessing import Process +import time from neps.optimizers.random_search.optimizer import RandomSearch from neps.runtime import DefaultWorker @@ -198,3 +202,63 @@ def __call__(self, *args, **kwargs) -> float: assert neps_state.get_next_pending_trial() is None assert len(neps_state.get_errors()) == 1 + + +def sleep_function(*args, **kwargs) -> float: + time.sleep(10) + return 10 + + +SIGNALS: list[signal.Signals] = [] +for name in ("SIGINT", "SIGTERM", "CTRL_C_EVENT"): + if hasattr(signal.Signals, name): + sig: signal.Signals = getattr(signal.Signals, name) + SIGNALS.append(sig) + + +@pytest.mark.parametrize("signal", SIGNALS) +def test_worker_reset_evaluating_to_pending_on_ctrl_c( + signal: signal.Signals, + neps_state: NePSState, +) -> None: + optimizer = RandomSearch(pipeline_space=SearchSpace(a=FloatParameter(0, 1))) + settings = WorkerSettings( + on_error=OnErrorPossibilities.IGNORE, # <- Highlight + default_report_values=DefaultReportValues(), + max_evaluations_total=None, + include_in_progress_evaluations_towards_maximum=False, + max_cost_total=None, + max_evaluations_for_worker=1, + max_evaluation_time_total_seconds=None, + max_wallclock_time_for_worker_seconds=None, + max_evaluation_time_for_worker_seconds=None, + max_cost_for_worker=None, + ) + + worker1 = DefaultWorker.new( + state=neps_state, + optimizer=optimizer, + evaluation_fn=sleep_function, + settings=settings, + _pre_sample_hooks=None, + ) + + p = Process(target=worker1.run) + p.start() + + time.sleep(0.5) + assert p.pid is not None + assert p.is_alive() + + # Should be evaluating at this stage + trials = neps_state.get_all_trials() + assert len(trials) == 1 + assert next(iter(trials.values())).state == Trial.State.EVALUATING + + # Kill the process while it's evaluting using signals + os.kill(p.pid, signal) + p.join() + + trials2 = neps_state.get_all_trials() + assert len(trials2) == 1 + assert next(iter(trials2.values())).state == Trial.State.PENDING From 0abddd9a1fc8b245ccfb19b11895ca6d5dada72a Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 5 Aug 2024 20:29:26 +0200 Subject: [PATCH 02/14] fix: Rethrow orgiinal keyboard interupt if caught --- neps/runtime.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index ec9b6843..90a94249 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -428,9 +428,10 @@ def run(self) -> None: # noqa: C901, PLR0915 evaluation_fn=self.evaluation_fn, default_report_values=self.settings.default_report_values, ) - except KeyboardInterrupt: + except KeyboardInterrupt as e: if not self._SIGNAL_HANDLER_FIRED: - self._emergency_cleanup(signum=signal.SIGINT, frame=None) + # This throws and we have stopped the worker at this point + self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e) evaluation_duration = evaluated_trial.metadata.evaluation_duration assert evaluation_duration is not None @@ -470,7 +471,12 @@ def run(self) -> None: # noqa: C901, PLR0915 "Learning Curve %s: %s", evaluated_trial.id, report.learning_curve ) - def _emergency_cleanup(self, signum: int, frame: Any) -> None: # noqa: ARG002 + def _emergency_cleanup( + self, + signum: int, + frame: Any, # noqa: ARG002 + rethrow: KeyboardInterrupt | None = None, + ) -> None: """Handle signals.""" self._SIGNAL_HANDLER_FIRED = True @@ -493,6 +499,8 @@ def _emergency_cleanup(self, signum: int, frame: Any) -> None: # noqa: ARG002 finally: _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None + if rethrow is not None: + raise rethrow raise KeyboardInterrupt(f"Worker was interrupted by signal {signum}.") From d2ced6c55251eb839b10c48dbfe1887420ed2854 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 5 Aug 2024 20:30:29 +0200 Subject: [PATCH 03/14] fix: Change `.info` log messages to `.error` --- neps/runtime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 90a94249..64fc5a49 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -481,11 +481,11 @@ def _emergency_cleanup( self._SIGNAL_HANDLER_FIRED = True global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS # noqa: PLW0603 - logger.info( + logger.error( f"Worker '{self.worker_id}' received signal {signum}. Stopping worker now!" ) if _CURRENTLY_RUNNING_TRIAL_IN_PROCESS is not None: - logger.info( + logger.error( "Worker '%s' was interrupted while evaluating trial: %s. Setting" " trial to pending!", self.worker_id, From ec9d489ebd907ce2beeb8e49ef0ea7d71c0a482c Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 11:30:00 +0200 Subject: [PATCH 04/14] fix: Longer sleep and `CTRL_C_EVENT` handling for Windows --- neps/runtime.py | 17 ++++++++++++++--- .../test_error_handling_strategies.py | 6 +++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 64fc5a49..7981875b 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -46,6 +46,13 @@ def _default_worker_name() -> str: return f"{os.getpid()}-{isoformat}" +SIGNALS_TO_HANDLE_IF_AVAILABLE = [ + "SIGINT", + "SIGTERM", + "CTRL_C_EVENT", +] + + N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 10 N_FAILED_TO_SET_TRIAL_STATE = 10 @@ -340,15 +347,19 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 return False - def run(self) -> None: # noqa: C901, PLR0915 + def run(self) -> None: # noqa: C901, PLR0915, PLR0912 """Run the worker. Will keep running until one of the criterion defined by the `WorkerSettings` is met. """ + for name in SIGNALS_TO_HANDLE_IF_AVAILABLE: + if hasattr(signal.Signals, name): + sig = getattr(signal.Signals, name) + signal.signal(sig, self._emergency_cleanup) + signal.signal(sig, self._emergency_cleanup) + _set_workers_neps_state(self.state) - signal.signal(signal.SIGINT, self._emergency_cleanup) - signal.signal(signal.SIGTERM, self._emergency_cleanup) logger.info("Launching NePS") diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index c85ac651..cb086b52 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -246,7 +246,11 @@ def test_worker_reset_evaluating_to_pending_on_ctrl_c( p = Process(target=worker1.run) p.start() - time.sleep(0.5) + # NOTE: Unfortunatly we have to wait a rather long time as windows does not + # have fork/fork-server available and it must start a new process and re-import + # everything, with which `torch`, taskes a long time + time.sleep(3) + assert p.pid is not None assert p.is_alive() From 5b9216dac4421f2ad9a5267b370d804d96f8f542 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 11:53:06 +0200 Subject: [PATCH 05/14] fix: More robustness around signals --- neps/runtime.py | 21 ++++++++++++------- .../test_error_handling_strategies.py | 13 +++++++----- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 7981875b..b8ce54db 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -23,6 +23,8 @@ TypeVar, ) +from pandas.core.common import contextlib + from neps.exceptions import ( NePSError, VersionMismatchError, @@ -50,6 +52,7 @@ def _default_worker_name() -> str: "SIGINT", "SIGTERM", "CTRL_C_EVENT", + "CTRL_BREAK_EVENT", ] @@ -347,18 +350,22 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 return False - def run(self) -> None: # noqa: C901, PLR0915, PLR0912 + def _set_signal_handlers(self) -> None: + for name in SIGNALS_TO_HANDLE_IF_AVAILABLE: + if hasattr(signal.Signals, name): + sig = getattr(signal.Signals, name) + # HACK: Despite what python documentation says, the existance of a signal + # is not enough to guarantee that it can be caught. + with contextlib.suppress(ValueError): + signal.signal(sig, self._emergency_cleanup) + + def run(self) -> None: # noqa: C901, PLR0915 """Run the worker. Will keep running until one of the criterion defined by the `WorkerSettings` is met. """ - for name in SIGNALS_TO_HANDLE_IF_AVAILABLE: - if hasattr(signal.Signals, name): - sig = getattr(signal.Signals, name) - signal.signal(sig, self._emergency_cleanup) - signal.signal(sig, self._emergency_cleanup) - + self._set_signal_handlers() _set_workers_neps_state(self.state) logger.info("Launching NePS") diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index cb086b52..ce9f7731 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -11,7 +11,7 @@ import time from neps.optimizers.random_search.optimizer import RandomSearch -from neps.runtime import DefaultWorker +from neps.runtime import DefaultWorker, SIGNALS_TO_HANDLE_IF_AVAILABLE from neps.search_spaces.search_space import SearchSpace from neps.state.err_dump import SerializedError from neps.state.filebased import create_or_load_filebased_neps_state @@ -210,15 +210,16 @@ def sleep_function(*args, **kwargs) -> float: SIGNALS: list[signal.Signals] = [] -for name in ("SIGINT", "SIGTERM", "CTRL_C_EVENT"): +for name in SIGNALS_TO_HANDLE_IF_AVAILABLE: if hasattr(signal.Signals, name): sig: signal.Signals = getattr(signal.Signals, name) SIGNALS.append(sig) +@pytest.mark.ci_examples @pytest.mark.parametrize("signal", SIGNALS) def test_worker_reset_evaluating_to_pending_on_ctrl_c( - signal: signal.Signals, + signum: signal.Signals, neps_state: NePSState, ) -> None: optimizer = RandomSearch(pipeline_space=SearchSpace(a=FloatParameter(0, 1))) @@ -249,7 +250,9 @@ def test_worker_reset_evaluating_to_pending_on_ctrl_c( # NOTE: Unfortunatly we have to wait a rather long time as windows does not # have fork/fork-server available and it must start a new process and re-import # everything, with which `torch`, taskes a long time - time.sleep(3) + # Also seems to happen sporadically on mac... let's bump it to 5, sorry + # I changed it to only run on `ci_examples` tag then. + time.sleep(5) assert p.pid is not None assert p.is_alive() @@ -260,7 +263,7 @@ def test_worker_reset_evaluating_to_pending_on_ctrl_c( assert next(iter(trials.values())).state == Trial.State.EVALUATING # Kill the process while it's evaluting using signals - os.kill(p.pid, signal) + signal.raise_signal(p.pid) p.join() trials2 = neps_state.get_all_trials() From acd19bca141178b8839c23344436cbed85b9ff17 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 11:59:46 +0200 Subject: [PATCH 06/14] test: Use `os.kill()` --- tests/test_runtime/test_error_handling_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index ce9f7731..a45c1e9e 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -263,7 +263,7 @@ def test_worker_reset_evaluating_to_pending_on_ctrl_c( assert next(iter(trials.values())).state == Trial.State.EVALUATING # Kill the process while it's evaluting using signals - signal.raise_signal(p.pid) + os.kill(p.pid, signum) p.join() trials2 = neps_state.get_all_trials() From 32b9b35465cc6623e8bba2ed8667305be1772bf5 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 12:00:11 +0200 Subject: [PATCH 07/14] test: Remove unused parametrization --- tests/test_runtime/test_error_handling_strategies.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index a45c1e9e..04fb3efe 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -130,14 +130,7 @@ def evaler(*args, **kwargs) -> float: assert len(neps_state.get_errors()) == 1 -@pytest.mark.parametrize( - "on_error", - [OnErrorPossibilities.IGNORE, OnErrorPossibilities.RAISE_WORKER_ERROR], -) -def test_worker_does_not_raise_when_error_in_other_worker( - neps_state: NePSState, - on_error: OnErrorPossibilities, -) -> None: +def test_worker_does_not_raise_when_error_in_other_worker(neps_state: NePSState) -> None: optimizer = RandomSearch(pipeline_space=SearchSpace(a=FloatParameter(0, 1))) settings = WorkerSettings( on_error=OnErrorPossibilities.RAISE_WORKER_ERROR, # <- Highlight From f5a4dafc90268bb685e8e0ab45b50ab651476e12 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 13:04:03 +0200 Subject: [PATCH 08/14] test: Pass in correct parameter name --- tests/test_runtime/test_error_handling_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index 04fb3efe..5ea4250c 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -210,7 +210,7 @@ def sleep_function(*args, **kwargs) -> float: @pytest.mark.ci_examples -@pytest.mark.parametrize("signal", SIGNALS) +@pytest.mark.parametrize("signum", SIGNALS) def test_worker_reset_evaluating_to_pending_on_ctrl_c( signum: signal.Signals, neps_state: NePSState, From 449e38a8f48fa3bb26d6310a27ad0f91c67fc85a Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 13:26:27 +0200 Subject: [PATCH 09/14] test: Use `psutil` for sending the signal --- pyproject.toml | 1 + .../test_error_handling_strategies.py | 34 +++++++++++-------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 06b4baa4..400aaec6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ pre-commit = "^3" mypy = "^1" pytest = "^7" pytest-cases = "^3" +psutil = "^6" types-PyYAML = "^6" mkdocs-material = "*" mkdocs-autorefs = "*" diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index 5ea4250c..bc43e5a2 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -7,7 +7,8 @@ import signal from pathlib import Path from pytest_cases import fixture, parametrize -from multiprocessing import Process +import multiprocessing +import psutil import time from neps.optimizers.random_search.optimizer import RandomSearch @@ -237,16 +238,13 @@ def test_worker_reset_evaluating_to_pending_on_ctrl_c( _pre_sample_hooks=None, ) - p = Process(target=worker1.run) + # Use multiprocessing.Process + p = multiprocessing.Process( + target=worker1.run, args=(neps_state, optimizer, settings) + ) p.start() - # NOTE: Unfortunatly we have to wait a rather long time as windows does not - # have fork/fork-server available and it must start a new process and re-import - # everything, with which `torch`, taskes a long time - # Also seems to happen sporadically on mac... let's bump it to 5, sorry - # I changed it to only run on `ci_examples` tag then. time.sleep(5) - assert p.pid is not None assert p.is_alive() @@ -255,10 +253,16 @@ def test_worker_reset_evaluating_to_pending_on_ctrl_c( assert len(trials) == 1 assert next(iter(trials.values())).state == Trial.State.EVALUATING - # Kill the process while it's evaluting using signals - os.kill(p.pid, signum) - p.join() - - trials2 = neps_state.get_all_trials() - assert len(trials2) == 1 - assert next(iter(trials2.values())).state == Trial.State.PENDING + # Kill the process while it's evaluating using signals + process = psutil.Process(p.pid) + process.send_signal(signum) + p.join(timeout=10) # Wait for the process to terminate + + if p.is_alive(): + p.terminate() # Force terminate if it's still alive + p.join() + pytest.fail("Worker did not terminate after receiving signal!") + else: + trials2 = neps_state.get_all_trials() + assert len(trials2) == 1 + assert next(iter(trials2.values())).state == Trial.State.PENDING From d40ad7c12a7515fa2a431b16ffab5e9140832f15 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 13:39:07 +0200 Subject: [PATCH 10/14] fix: Remove args from `worker.run` --- tests/test_runtime/test_error_handling_strategies.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index bc43e5a2..409ca210 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -239,9 +239,7 @@ def test_worker_reset_evaluating_to_pending_on_ctrl_c( ) # Use multiprocessing.Process - p = multiprocessing.Process( - target=worker1.run, args=(neps_state, optimizer, settings) - ) + p = multiprocessing.Process(target=worker1.run) p.start() time.sleep(5) From 5d8955b6be8112dae3c8c4f05d66413302568b7b Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 13:58:17 +0200 Subject: [PATCH 11/14] fix: Don't handle `CTRL_BREAK_EVENT` --- neps/runtime.py | 1 - 1 file changed, 1 deletion(-) diff --git a/neps/runtime.py b/neps/runtime.py index b8ce54db..0baa4b7a 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -52,7 +52,6 @@ def _default_worker_name() -> str: "SIGINT", "SIGTERM", "CTRL_C_EVENT", - "CTRL_BREAK_EVENT", ] From 6afd09d8be7fa21d9c6940cca852d63b22d80cb1 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 16:23:55 +0200 Subject: [PATCH 12/14] test: Try 10 seconds for windows `spawn` process --- tests/test_runtime/test_error_handling_strategies.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index 409ca210..bf18efdb 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -199,7 +199,7 @@ def __call__(self, *args, **kwargs) -> float: def sleep_function(*args, **kwargs) -> float: - time.sleep(10) + time.sleep(20) return 10 @@ -242,7 +242,11 @@ def test_worker_reset_evaluating_to_pending_on_ctrl_c( p = multiprocessing.Process(target=worker1.run) p.start() - time.sleep(5) + # Windows is exceptionally slow at starting processes + # due to it's spawn and the fact we import torch freshly in + # the worker... hence we give it 10 seconds to get there and + # only run this test in CI + time.sleep(10) assert p.pid is not None assert p.is_alive() From 812b84e0db14672e5edef91642f07bdaca5cb93d Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 16:31:50 +0200 Subject: [PATCH 13/14] fix: Don't handle CTRL_C_EVENT, call previous signal handler --- neps/runtime.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 0baa4b7a..cd733dc8 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -9,7 +9,7 @@ import signal import time from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import ( TYPE_CHECKING, @@ -51,7 +51,6 @@ def _default_worker_name() -> str: SIGNALS_TO_HANDLE_IF_AVAILABLE = [ "SIGINT", "SIGTERM", - "CTRL_C_EVENT", ] @@ -155,7 +154,7 @@ class DefaultWorker(Generic[Loc]): worker_cumulative_evaluation_time_seconds: float = 0.0 """The time spent evaluating configurations by this worker.""" - _SIGNAL_HANDLER_FIRED: bool = False + _PREVIOUS_SIGNAL_HANDLERS: dict[int, signal._HANDLER] = field(default_factory=dict) @classmethod def new( @@ -356,7 +355,8 @@ def _set_signal_handlers(self) -> None: # HACK: Despite what python documentation says, the existance of a signal # is not enough to guarantee that it can be caught. with contextlib.suppress(ValueError): - signal.signal(sig, self._emergency_cleanup) + previous_signal_handler = signal.signal(sig, self._emergency_cleanup) + self._PREVIOUS_SIGNAL_HANDLERS[sig] = previous_signal_handler def run(self) -> None: # noqa: C901, PLR0915 """Run the worker. @@ -446,9 +446,8 @@ def run(self) -> None: # noqa: C901, PLR0915 default_report_values=self.settings.default_report_values, ) except KeyboardInterrupt as e: - if not self._SIGNAL_HANDLER_FIRED: - # This throws and we have stopped the worker at this point - self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e) + # This throws and we have stopped the worker at this point + self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e) evaluation_duration = evaluated_trial.metadata.evaluation_duration assert evaluation_duration is not None @@ -491,12 +490,10 @@ def run(self) -> None: # noqa: C901, PLR0915 def _emergency_cleanup( self, signum: int, - frame: Any, # noqa: ARG002 + frame: Any, rethrow: KeyboardInterrupt | None = None, ) -> None: """Handle signals.""" - self._SIGNAL_HANDLER_FIRED = True - global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS # noqa: PLW0603 logger.error( f"Worker '{self.worker_id}' received signal {signum}. Stopping worker now!" @@ -516,6 +513,9 @@ def _emergency_cleanup( finally: _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None + previous_handler = self._PREVIOUS_SIGNAL_HANDLERS.get(signum) + if previous_handler is not None and callable(previous_handler): + previous_handler(signum, frame) if rethrow is not None: raise rethrow raise KeyboardInterrupt(f"Worker was interrupted by signal {signum}.") From 0c6c785def325b8338bb31e71ee3d2d22ebbd77c Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 16:32:17 +0200 Subject: [PATCH 14/14] test: Reduce sleep time --- neps/runtime.py | 1 + tests/test_runtime/test_error_handling_strategies.py | 6 +----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index cd733dc8..c7298530 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -448,6 +448,7 @@ def run(self) -> None: # noqa: C901, PLR0915 except KeyboardInterrupt as e: # This throws and we have stopped the worker at this point self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e) + return evaluation_duration = evaluated_trial.metadata.evaluation_duration assert evaluation_duration is not None diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index bf18efdb..d357ec90 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -242,11 +242,7 @@ def test_worker_reset_evaluating_to_pending_on_ctrl_c( p = multiprocessing.Process(target=worker1.run) p.start() - # Windows is exceptionally slow at starting processes - # due to it's spawn and the fact we import torch freshly in - # the worker... hence we give it 10 seconds to get there and - # only run this test in CI - time.sleep(10) + time.sleep(5) assert p.pid is not None assert p.is_alive()