From b636374108e425f86c8e050528d8fd240694806e Mon Sep 17 00:00:00 2001 From: Fabio Bonassi Date: Mon, 9 Dec 2024 14:20:29 +0100 Subject: [PATCH] Allow to terminate an epoch without firing `Events.EPOCH_COMPLETED` (#3313) * Added optional flag skip_epoch_completed to Engine.terminate_epoch() * Improved docs for terminate() and terminate_epoch() * Make the internal attribute skip_completed_after_termination private * - Merged flags "should_terminate" and "_skip_completed_after_termination". - Merged flags "should_terminate_single_epoch" and "_skip_epoch_completed_after_termination". * Union[bool, str] instead of the pipe operator for compatibility with older Python versions * Raise a RuntimeError when terminate_epoch() is called on Events.STARTED or Events.EPOCH_STARTED * Ignoring comparison-overlap warning from mypy to keep the code simple * Apply suggestions from code review * Update engine.py --------- Co-authored-by: vfdev --- ignite/engine/engine.py | 79 +++++++++++++++++++++--------- ignite/engine/events.py | 12 +++-- tests/ignite/engine/test_engine.py | 71 ++++++++++++++++++++------- 3 files changed, 119 insertions(+), 43 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index e2a14898607..f3f95c9a2e2 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -139,9 +139,12 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]): self.logger = logging.getLogger(__name__ + "." 
+ self.__class__.__name__) self._process_function = process_function self.last_event_name: Optional[Events] = None - self.should_terminate = False - self.skip_completed_after_termination = False - self.should_terminate_single_epoch = False + # should_terminate flag: False - don't terminate, True - terminate, + # "skip_completed" - terminate and skip the event "COMPLETED" + self.should_terminate: Union[bool, str] = False + # should_terminate_single_epoch flag: False - don't terminate, True - terminate, + # "skip_epoch_completed" - terminate and skip the event "EPOCH_COMPLETED" + self.should_terminate_single_epoch: Union[bool, str] = False self.should_interrupt = False self.state = State() self._state_dict_user_keys: List[str] = [] @@ -546,7 +549,7 @@ def terminate(self, skip_completed: bool = False) -> None: - ... - Terminating event - :attr:`~ignite.engine.events.Events.TERMINATE` - - :attr:`~ignite.engine.events.Events.COMPLETED` + - :attr:`~ignite.engine.events.Events.COMPLETED` (unless `skip_completed=True`) Args: skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after @@ -625,25 +628,31 @@ def terminate(): Added `skip_completed` flag """ self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.") - self.should_terminate = True - self.skip_completed_after_termination = skip_completed + self.should_terminate = "skip_completed" if skip_completed else True - def terminate_epoch(self) -> None: + def terminate_epoch(self, skip_epoch_completed: bool = False) -> None: """Sends terminate signal to the engine, so that it terminates the current epoch. The run continues from the next epoch. The following events are triggered: - ... 
- Event on which ``terminate_epoch`` method is called - :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH` - - :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED` + - :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED` (unless `skip_epoch_completed=True`) - :attr:`~ignite.engine.events.Events.EPOCH_STARTED` - ... + + Args: + skip_epoch_completed: if True, the event :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED` + is not fired after :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`. Default is False. + + .. versionchanged:: 0.5.2 + Added `skip_epoch_completed` flag """ self.logger.info( "Terminate current epoch is signaled. " "Current epoch iteration will stop after current iteration is finished." ) - self.should_terminate_single_epoch = True + self.should_terminate_single_epoch = "skip_epoch_completed" if skip_epoch_completed else True def _handle_exception(self, e: BaseException) -> None: if Events.EXCEPTION_RAISED in self._event_handlers: @@ -982,11 +991,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]: # time is available for handlers but must be updated after fire self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken - handlers_start_time = time.time() - self._fire_event(Events.EPOCH_COMPLETED) - epoch_time_taken += time.time() - handlers_start_time - # update time wrt handlers - self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken + if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap] + handlers_start_time = time.time() + self._fire_event(Events.EPOCH_COMPLETED) + epoch_time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken + + if self.should_terminate_single_epoch: + # We skip raising _EngineTerminateSingleEpochException exception on Events.EPOCH_COMPLETED + # as epoch is already completed and nothing to terminate + self.should_terminate_single_epoch = False 
yield from self._maybe_terminate_or_interrupt() hours, mins, secs = _to_hours_mins_secs(epoch_time_taken) @@ -997,12 +1012,19 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]: except _EngineTerminateException: self._fire_event(Events.TERMINATE) + except _EngineTerminateSingleEpochException: + raise RuntimeError( + "The method terminate_epoch() should not be called on Event.STARTED or Event.EPOCH_STARTED. " + "If this is a desired behaviour, please open a feature request on " + "https://github.com/pytorch/ignite/issues/new/choose" + ) + time_taken = time.time() - start_time # time is available for handlers but must be updated after fire self.state.times[Events.COMPLETED.name] = time_taken # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True` - if not (self.should_terminate and self.skip_completed_after_termination): + if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap] handlers_start_time = time.time() self._fire_event(Events.COMPLETED) time_taken += time.time() - handlers_start_time @@ -1121,7 +1143,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]: except _EngineTerminateSingleEpochException: self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter) - self.should_terminate_single_epoch = False self._setup_dataloader_iter() except _EngineTerminateException as e: @@ -1167,11 +1188,17 @@ def _internal_run_legacy(self) -> State: # time is available for handlers but must be updated after fire self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken - handlers_start_time = time.time() - self._fire_event(Events.EPOCH_COMPLETED) - epoch_time_taken += time.time() - handlers_start_time - # update time wrt handlers - self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken + if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap] + handlers_start_time = time.time() + 
self._fire_event(Events.EPOCH_COMPLETED) + epoch_time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken + + if self.should_terminate_single_epoch: + # We skip raising _EngineTerminateSingleEpochException exception on Events.EPOCH_COMPLETED + # as epoch is already completed and nothing to terminate + self.should_terminate_single_epoch = False self._maybe_terminate_legacy() hours, mins, secs = _to_hours_mins_secs(epoch_time_taken) @@ -1182,12 +1209,19 @@ def _internal_run_legacy(self) -> State: except _EngineTerminateException: self._fire_event(Events.TERMINATE) + except _EngineTerminateSingleEpochException: + raise RuntimeError( + "The method terminate_epoch() should not be called on Event.STARTED or Event.EPOCH_STARTED. " + "If this is a desired behaviour, please open a feature request on " + "https://github.com/pytorch/ignite/issues/new/choose" + ) + time_taken = time.time() - start_time # time is available for handlers but must be updated after fire self.state.times[Events.COMPLETED.name] = time_taken # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True` - if not (self.should_terminate and self.skip_completed_after_termination): + if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap] handlers_start_time = time.time() self._fire_event(Events.COMPLETED) time_taken += time.time() - handlers_start_time @@ -1292,7 +1326,6 @@ def _run_once_on_dataset_legacy(self) -> float: except _EngineTerminateSingleEpochException: self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter) - self.should_terminate_single_epoch = False self._setup_dataloader_iter() except _EngineTerminateException as e: diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 87622d3415c..7a348f94762 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -259,8 +259,9 @@ class Events(EventEnum): - 
TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch, after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or :meth:`~ignite.engine.engine.Engine.terminate()` call. - - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even - when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called. + - EPOCH_COMPLETED : triggered when the epoch is ended. This is triggered even + when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called, + unless the flag `skip_epoch_completed` is set to True. - TERMINATE : triggered when the run is about to end completely, after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call. @@ -272,7 +273,7 @@ class Events(EventEnum): The table below illustrates which events are triggered when various termination methods are called. .. list-table:: - :widths: 35 38 28 20 20 + :widths: 38 38 28 20 20 :header-rows: 1 * - Method @@ -290,6 +291,11 @@ class Events(EventEnum): - ✔ - ✗ - ✔ + * - :meth:`~ignite.engine.engine.Engine.terminate_epoch()` with `skip_epoch_completed=True` + - ✔ + - ✗ + - ✗ + - ✔ * - :meth:`~ignite.engine.engine.Engine.terminate()` - ✗ - ✔ diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index fcb0299aa22..76e1ad83760 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -44,10 +44,13 @@ def set_interrupt_resume_enabled(self, interrupt_resume_enabled): def test_terminate(self, skip_completed): engine = Engine(lambda e, b: 1) assert not engine.should_terminate - assert not engine.skip_completed_after_termination + engine.terminate(skip_completed) - assert engine.should_terminate - assert engine.skip_completed_after_termination == skip_completed + + if skip_completed: + assert engine.should_terminate == "skip_completed" + else: + assert engine.should_terminate == True # noqa: E712 def test_invalid_process_raises_with_invalid_signature(self): with 
pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"): @@ -292,8 +295,11 @@ def assert_no_exceptions(ee): assert engine.called_events[0] == (0, 0, Events.STARTED) assert engine._dataloader_iter is None - @pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)]) - def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length): + @pytest.mark.parametrize( + "data, epoch_length, skip_epoch_completed", + [(None, 10, False), (range(10), None, False), (None, 10, True), (range(10), None, True)], + ) + def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length, skip_epoch_completed): real_epoch_length = epoch_length if data is None else len(data) iteration_to_stop = real_epoch_length + 4 @@ -301,7 +307,7 @@ def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length): def start_of_iteration_handler(engine): if engine.state.iteration == iteration_to_stop: - engine.terminate_epoch() + engine.terminate_epoch(skip_epoch_completed) max_epochs = 3 engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler) @@ -312,15 +318,23 @@ def start_of_iteration_handler(engine): assert state.epoch == max_epochs @pytest.mark.parametrize( - "terminate_epoch_event, i", + "terminate_epoch_event, i, skip_epoch_completed", [ - (Events.GET_BATCH_STARTED(once=12), 12), - (Events.GET_BATCH_COMPLETED(once=12), 12), - (Events.ITERATION_STARTED(once=14), 14), - (Events.ITERATION_COMPLETED(once=14), 14), + (Events.GET_BATCH_STARTED(once=12), 12, False), + (Events.GET_BATCH_COMPLETED(once=12), 12, False), + (Events.ITERATION_STARTED(once=14), 14, False), + (Events.ITERATION_COMPLETED(once=14), 14, False), + (Events.GET_BATCH_STARTED(once=12), 12, True), + (Events.GET_BATCH_COMPLETED(once=12), 12, True), + (Events.ITERATION_STARTED(once=14), 14, True), + (Events.ITERATION_COMPLETED(once=14), 14, True), + (Events.STARTED, 30, False), + (Events.STARTED, 30, True), + (Events.EPOCH_STARTED(once=2), 10, False), + 
(Events.EPOCH_STARTED(once=2), 10, True), ], ) - def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i): + def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i, skip_epoch_completed): engine = RecordedEngine(MagicMock(return_value=1)) data = range(10) max_epochs = 3 @@ -331,31 +345,54 @@ def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i): @engine.on(terminate_epoch_event) def call_terminate_epoch(): + assert not engine.should_terminate_single_epoch nonlocal call_count if call_count < 1: - engine.terminate_epoch() + engine.terminate_epoch(skip_epoch_completed) + if skip_epoch_completed: + assert engine.should_terminate_single_epoch == "skip_epoch_completed" + else: + assert engine.should_terminate_single_epoch == True # noqa: E712 + call_count += 1 + @engine.on(Events.EPOCH_STARTED) + def check_skip_reset(): + if terminate_epoch_event != Events.EPOCH_STARTED: + assert engine.should_terminate_single_epoch == False # noqa: E712 + @engine.on(Events.TERMINATE_SINGLE_EPOCH) def check_previous_events(iter_counter): e = i // len(data) + 1 - assert engine.called_events[0] == (0, 0, Events.STARTED) assert engine.called_events[-2] == (e, i, terminate_epoch_event) assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH) + if skip_epoch_completed: + assert engine.should_terminate_single_epoch == "skip_epoch_completed" + else: + assert engine.should_terminate_single_epoch == True # noqa: E712 @engine.on(Events.EPOCH_COMPLETED) def check_previous_events2(): e = i // len(data) + 1 if e == engine.state.epoch and i == engine.state.iteration: + assert not skip_epoch_completed + assert isinstance(engine.should_terminate_single_epoch, bool) assert engine.called_events[-3] == (e, i, terminate_epoch_event) assert engine.called_events[-2] == (e, i, Events.TERMINATE_SINGLE_EPOCH) assert engine.called_events[-1] == (e, i, Events.EPOCH_COMPLETED) - engine.run(data, max_epochs=max_epochs) + if terminate_epoch_event in 
[Events.STARTED, Events.EPOCH_STARTED]: + with pytest.raises(RuntimeError): + engine.run(data, max_epochs=max_epochs) + else: + engine.run(data, max_epochs=max_epochs) + + assert engine.state.epoch == max_epochs + assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data) - assert engine.state.epoch == max_epochs - assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data) + epoch_completed_events = [e for e in engine.called_events if e[2] == Events.EPOCH_COMPLETED.name] + assert len(epoch_completed_events) == max_epochs - skip_epoch_completed @pytest.mark.parametrize("data", [None, "mock_data_loader"]) def test_iteration_events_are_fired(self, data):