Skip to content

Commit

Permalink
Allow to terminate an epoch without firing Events.EPOCH_COMPLETED (#…
Browse files Browse the repository at this point in the history
…3313)

* Added optional flag skip_epoch_completed to Engine.terminate_epoch()

* Improved docs for terminate() and terminate_epoch()

* Make the internal attribute skip_completed_after_termination private

* - Merged flags "should_terminate" and "_skip_completed_after_termination".
- Merged flags "should_terminate_single_epoch" and "_skip_epoch_completed_after_termination".

* Union[bool, str] instead of the pipe operator for compatibility with older Python versions

* Raise a RuntimeError when terminate_epoch() is called on Events.STARTED or Events.EPOCH_STARTED

* Ignoring comparison-overlap warning from mypy to keep the code simple

* Apply suggestions from code review

* Update engine.py

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
bonassifabio and vfdev-5 authored Dec 9, 2024
1 parent 6f8ad2a commit b636374
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 43 deletions.
79 changes: 56 additions & 23 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,12 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self._process_function = process_function
self.last_event_name: Optional[Events] = None
self.should_terminate = False
self.skip_completed_after_termination = False
self.should_terminate_single_epoch = False
# should_terminate flag: False - don't terminate, True - terminate,
# "skip_completed" - terminate and skip the event "COMPLETED"
self.should_terminate: Union[bool, str] = False
# should_terminate_single_epoch flag: False - don't terminate, True - terminate,
# "skip_epoch_completed" - terminate and skip the event "EPOCH_COMPLETED"
self.should_terminate_single_epoch: Union[bool, str] = False
self.should_interrupt = False
self.state = State()
self._state_dict_user_keys: List[str] = []
Expand Down Expand Up @@ -546,7 +549,7 @@ def terminate(self, skip_completed: bool = False) -> None:
- ...
- Terminating event
- :attr:`~ignite.engine.events.Events.TERMINATE`
- :attr:`~ignite.engine.events.Events.COMPLETED`
- :attr:`~ignite.engine.events.Events.COMPLETED` (unless `skip_completed=True`)
Args:
skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after
Expand Down Expand Up @@ -625,25 +628,31 @@ def terminate():
Added `skip_completed` flag
"""
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.should_terminate = True
self.skip_completed_after_termination = skip_completed
self.should_terminate = "skip_completed" if skip_completed else True

def terminate_epoch(self) -> None:
def terminate_epoch(self, skip_epoch_completed: bool = False) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
continues from the next epoch. The following events are triggered:
- ...
- Event on which ``terminate_epoch`` method is called
- :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED` (unless `skip_epoch_completed=True`)
- :attr:`~ignite.engine.events.Events.EPOCH_STARTED`
- ...
Args:
skip_epoch_completed: if True, the event :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
is not fired after :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`. Default is False.
.. versionchanged:: 0.5.2
Added `skip_epoch_completed` flag
"""
self.logger.info(
"Terminate current epoch is signaled. "
"Current epoch iteration will stop after current iteration is finished."
)
self.should_terminate_single_epoch = True
self.should_terminate_single_epoch = "skip_epoch_completed" if skip_epoch_completed else True

def _handle_exception(self, e: BaseException) -> None:
if Events.EXCEPTION_RAISED in self._event_handlers:
Expand Down Expand Up @@ -982,11 +991,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

if self.should_terminate_single_epoch:
# We skip raising _EngineTerminateSingleEpochException on Events.EPOCH_COMPLETED
# as the epoch is already completed and there is nothing left to terminate
self.should_terminate_single_epoch = False
yield from self._maybe_terminate_or_interrupt()

hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
Expand All @@ -997,12 +1012,19 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
except _EngineTerminateException:
self._fire_event(Events.TERMINATE)

except _EngineTerminateSingleEpochException:
raise RuntimeError(
"The method terminate_epoch() should not be called on Event.STARTED or Event.EPOCH_STARTED."
"If this is a desired behaviour, please open a feature request on"
"https://github.com/pytorch/ignite/issues/new/choose"
)

time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if not (self.should_terminate and self.skip_completed_after_termination):
if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
Expand Down Expand Up @@ -1121,7 +1143,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:

except _EngineTerminateSingleEpochException:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self._setup_dataloader_iter()

except _EngineTerminateException as e:
Expand Down Expand Up @@ -1167,11 +1188,17 @@ def _internal_run_legacy(self) -> State:
# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

if self.should_terminate_single_epoch:
# We skip raising _EngineTerminateSingleEpochException on Events.EPOCH_COMPLETED
# as the epoch is already completed and there is nothing left to terminate
self.should_terminate_single_epoch = False
self._maybe_terminate_legacy()

hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
Expand All @@ -1182,12 +1209,19 @@ def _internal_run_legacy(self) -> State:
except _EngineTerminateException:
self._fire_event(Events.TERMINATE)

except _EngineTerminateSingleEpochException:
raise RuntimeError(
"The method terminate_epoch() should not be called on Event.STARTED or Event.EPOCH_STARTED."
"If this is a desired behaviour, please open a feature request on"
"https://github.com/pytorch/ignite/issues/new/choose"
)

time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if not (self.should_terminate and self.skip_completed_after_termination):
if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
Expand Down Expand Up @@ -1292,7 +1326,6 @@ def _run_once_on_dataset_legacy(self) -> float:

except _EngineTerminateSingleEpochException:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self._setup_dataloader_iter()

except _EngineTerminateException as e:
Expand Down
12 changes: 9 additions & 3 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ class Events(EventEnum):
- TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
:meth:`~ignite.engine.engine.Engine.terminate()` call.
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- EPOCH_COMPLETED : triggered when the epoch is ended. This is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called,
unless the flag `skip_epoch_completed` is set to True.
- TERMINATE : triggered when the run is about to end completely,
after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call.
Expand All @@ -272,7 +273,7 @@ class Events(EventEnum):
The table below illustrates which events are triggered when various termination methods are called.
.. list-table::
:widths: 35 38 28 20 20
:widths: 38 38 28 20 20
:header-rows: 1
* - Method
Expand All @@ -290,6 +291,11 @@ class Events(EventEnum):
- ✔
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate_epoch()` with `skip_epoch_completed=True`
- ✔
- ✗
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate()`
- ✗
- ✔
Expand Down
71 changes: 54 additions & 17 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ def set_interrupt_resume_enabled(self, interrupt_resume_enabled):
def test_terminate(self, skip_completed):
engine = Engine(lambda e, b: 1)
assert not engine.should_terminate
assert not engine.skip_completed_after_termination

engine.terminate(skip_completed)
assert engine.should_terminate
assert engine.skip_completed_after_termination == skip_completed

if skip_completed:
assert engine.should_terminate == "skip_completed"
else:
assert engine.should_terminate == True # noqa: E712

def test_invalid_process_raises_with_invalid_signature(self):
with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"):
Expand Down Expand Up @@ -292,16 +295,19 @@ def assert_no_exceptions(ee):
assert engine.called_events[0] == (0, 0, Events.STARTED)
assert engine._dataloader_iter is None

@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length):
@pytest.mark.parametrize(
"data, epoch_length, skip_epoch_completed",
[(None, 10, False), (range(10), None, False), (None, 10, True), (range(10), None, True)],
)
def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length, skip_epoch_completed):
real_epoch_length = epoch_length if data is None else len(data)
iteration_to_stop = real_epoch_length + 4

engine = Engine(MagicMock(return_value=1))

def start_of_iteration_handler(engine):
if engine.state.iteration == iteration_to_stop:
engine.terminate_epoch()
engine.terminate_epoch(skip_epoch_completed)

max_epochs = 3
engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
Expand All @@ -312,15 +318,23 @@ def start_of_iteration_handler(engine):
assert state.epoch == max_epochs

@pytest.mark.parametrize(
"terminate_epoch_event, i",
"terminate_epoch_event, i, skip_epoch_completed",
[
(Events.GET_BATCH_STARTED(once=12), 12),
(Events.GET_BATCH_COMPLETED(once=12), 12),
(Events.ITERATION_STARTED(once=14), 14),
(Events.ITERATION_COMPLETED(once=14), 14),
(Events.GET_BATCH_STARTED(once=12), 12, False),
(Events.GET_BATCH_COMPLETED(once=12), 12, False),
(Events.ITERATION_STARTED(once=14), 14, False),
(Events.ITERATION_COMPLETED(once=14), 14, False),
(Events.GET_BATCH_STARTED(once=12), 12, True),
(Events.GET_BATCH_COMPLETED(once=12), 12, True),
(Events.ITERATION_STARTED(once=14), 14, True),
(Events.ITERATION_COMPLETED(once=14), 14, True),
(Events.STARTED, 30, False),
(Events.STARTED, 30, True),
(Events.EPOCH_STARTED(once=2), 10, False),
(Events.EPOCH_STARTED(once=2), 10, True),
],
)
def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i):
def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i, skip_epoch_completed):
engine = RecordedEngine(MagicMock(return_value=1))
data = range(10)
max_epochs = 3
Expand All @@ -331,31 +345,54 @@ def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i):

@engine.on(terminate_epoch_event)
def call_terminate_epoch():
assert not engine.should_terminate_single_epoch
nonlocal call_count
if call_count < 1:
engine.terminate_epoch()
engine.terminate_epoch(skip_epoch_completed)
if skip_epoch_completed:
assert engine.should_terminate_single_epoch == "skip_epoch_completed"
else:
assert engine.should_terminate_single_epoch == True # noqa: E712

call_count += 1

@engine.on(Events.EPOCH_STARTED)
def check_skip_reset():
if terminate_epoch_event != Events.EPOCH_STARTED:
assert engine.should_terminate_single_epoch == False # noqa: E712

@engine.on(Events.TERMINATE_SINGLE_EPOCH)
def check_previous_events(iter_counter):
e = i // len(data) + 1

assert engine.called_events[0] == (0, 0, Events.STARTED)
assert engine.called_events[-2] == (e, i, terminate_epoch_event)
assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
if skip_epoch_completed:
assert engine.should_terminate_single_epoch == "skip_epoch_completed"
else:
assert engine.should_terminate_single_epoch == True # noqa: E712

@engine.on(Events.EPOCH_COMPLETED)
def check_previous_events2():
e = i // len(data) + 1
if e == engine.state.epoch and i == engine.state.iteration:
assert not skip_epoch_completed
assert isinstance(engine.should_terminate_single_epoch, bool)
assert engine.called_events[-3] == (e, i, terminate_epoch_event)
assert engine.called_events[-2] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
assert engine.called_events[-1] == (e, i, Events.EPOCH_COMPLETED)

engine.run(data, max_epochs=max_epochs)
if terminate_epoch_event in [Events.STARTED, Events.EPOCH_STARTED]:
with pytest.raises(RuntimeError):
engine.run(data, max_epochs=max_epochs)
else:
engine.run(data, max_epochs=max_epochs)

assert engine.state.epoch == max_epochs
assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data)

assert engine.state.epoch == max_epochs
assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data)
epoch_completed_events = [e for e in engine.called_events if e[2] == Events.EPOCH_COMPLETED.name]
assert len(epoch_completed_events) == max_epochs - skip_epoch_completed

@pytest.mark.parametrize("data", [None, "mock_data_loader"])
def test_iteration_events_are_fired(self, data):
Expand Down

0 comments on commit b636374

Please sign in to comment.