Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 41 additions & 18 deletions ignite/handlers/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,28 @@ class EarlyStopping(Serializable):
Args:
patience: Number of events to wait if no improvement and then stop the training.
score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine`
object, and return a score `float`. An improvement is considered if the score is higher.
object, and return a score `float`. An improvement is considered if the score is higher (for ``mode='max'``)
or lower (for ``mode='min'``).
trainer: Trainer engine to stop the run if no improvement.
min_delta: A minimum increase in the score to qualify as an improvement,
i.e. an increase of less than or equal to the minimum delta threshold (as determined by min_delta and min_delta_mode), will count as no improvement.
cumulative_delta: It True, `min_delta` defines an increase since the last `patience` reset, otherwise,
it defines an increase after the last event. Default value is False.
min_delta_mode: Determine whether `min_delta` is an absolute increase or a relative increase.
In 'abs' mode, the threshold is min_delta,
i.e. an increase of less than or equal to min_delta, will count as no improvement.
In 'rel' mode, the threshold is abs(best_score) * min_delta,
i.e. an increase of less than or equal to abs(best_score) * min_delta, will count as no improvement.
min_delta: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum
increase; for ``mode='min'``, it's a minimum decrease. An improvement is only considered if the change
exceeds the threshold determined by `min_delta` and `min_delta_mode`.
cumulative_delta: If True, `min_delta` defines the change since the last `patience` reset, otherwise,
it defines the change after the last event. Default value is False.
min_delta_mode: Determines whether `min_delta` is an absolute change or a relative change.

- In 'abs' mode:

- For ``mode='max'``: improvement if score > best_score + min_delta
- For ``mode='min'``: improvement if score < best_score - min_delta

- In 'rel' mode:

- For ``mode='max'``: improvement if score > best_score * (1 + min_delta)
- For ``mode='min'``: improvement if score < best_score * (1 - min_delta)

Possible values are "abs" and "rel". Default value is "abs".
mode: Whether to maximize ('max') or minimize ('min') the score. Default is 'max'.

Examples:
.. code-block:: python
Expand All @@ -41,6 +51,10 @@ def score_function(engine):
# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
evaluator.add_event_handler(Events.COMPLETED, handler)

.. versionchanged:: 0.6.0
Added `mode` parameter to support minimization in addition to maximization.
Added `min_delta_mode` parameter to support both absolute and relative improvements.

"""

_state_dict_all_req_keys = (
Expand All @@ -56,6 +70,7 @@ def __init__(
min_delta: float = 0.0,
cumulative_delta: bool = False,
min_delta_mode: Literal["abs", "rel"] = "abs",
mode: Literal["min", "max"] = "max",
):
if not callable(score_function):
raise TypeError("Argument score_function should be a function.")
Expand All @@ -72,6 +87,9 @@ def __init__(
if min_delta_mode not in ("abs", "rel"):
raise ValueError("Argument min_delta_mode should be either 'abs' or 'rel'.")

if mode not in ("min", "max"):
raise ValueError("Argument mode should be either 'min' or 'max'.")

self.score_function = score_function
self.patience = patience
self.min_delta = min_delta
Expand All @@ -81,21 +99,26 @@ def __init__(
self.best_score: Optional[float] = None
self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
self.min_delta_mode = min_delta_mode
self.mode = mode

def __call__(self, engine: Engine) -> None:
score = self.score_function(engine)

if self.best_score is None:
self.best_score = score
return
upper_bound = (
self.best_score + self.min_delta
if self.min_delta_mode == "abs"
else self.best_score + abs(self.best_score) * self.min_delta
)
if score <= upper_bound:
if not self.cumulative_delta and score > self.best_score:
self.best_score = score

min_delta = -self.min_delta if self.mode == "min" else self.min_delta
if self.min_delta_mode == "abs":
improvement_threshold = self.best_score + min_delta
else:
improvement_threshold = self.best_score * (1 + min_delta)

no_improvement = score <= improvement_threshold if self.mode == "max" else score >= improvement_threshold

if no_improvement:
if not self.cumulative_delta:
self.best_score = max(score, self.best_score) if self.mode == "max" else min(score, self.best_score)
self.counter += 1
self.logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience))
if self.counter >= self.patience:
Expand Down
90 changes: 90 additions & 0 deletions tests/ignite/handlers/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def test_args_validation():
with pytest.raises(ValueError, match=r"Argument min_delta_mode should be either 'abs' or 'rel'."):
EarlyStopping(patience=2, min_delta_mode="invalid_mode", score_function=lambda engine: 0, trainer=trainer)

with pytest.raises(ValueError, match=r"Argument mode should be either 'min' or 'max'."):
EarlyStopping(patience=2, mode="invalid_mode", score_function=lambda engine: 0, trainer=trainer)


def test_simple_early_stopping():
scores = iter([1.0, 0.8, 0.88])
Expand Down Expand Up @@ -296,6 +299,93 @@ def evaluation(engine):
assert trainer.state.epoch == 10


def test_simple_early_stopping_min_mode():
scores = iter([1.0, 1.2, 0.9])

def score_function(engine):
return next(scores)

trainer = Engine(do_nothing_update_fn)

h = EarlyStopping(patience=2, score_function=score_function, trainer=trainer, mode="min")
# Call 3 times and check if stopped
assert not trainer.should_terminate
h(None) # best_score=1.0
assert not trainer.should_terminate
h(None) # score=1.2 (no improvement)
assert not trainer.should_terminate
h(None) # score=0.9 (improvement)
assert not trainer.should_terminate


def test_early_stopping_min_mode_with_delta():
scores = iter([1.1, 0.95, 0.94, 0.93])

trainer = Engine(do_nothing_update_fn)

h = EarlyStopping(patience=2, min_delta=0.1, score_function=lambda _: next(scores), trainer=trainer, mode="min")

assert not trainer.should_terminate
h(None) # best_score=1.1
assert not trainer.should_terminate
h(None) # score=0.95 (improvement: 0.95 < 1.1 - 0.1 = 1.0)
assert not trainer.should_terminate
h(None) # score=0.94 (no improvement: 0.94 >= 0.95 - 0.1 = 0.85)
assert not trainer.should_terminate
h(None) # score=0.93 (no improvement: 0.93 >= 0.95 - 0.1 = 0.85)
assert trainer.should_terminate


def test_early_stopping_min_mode_with_delta_cumulative():
scores = iter([1.1, 0.95, 0.94, 0.93])

trainer = Engine(do_nothing_update_fn)

h = EarlyStopping(
patience=2,
min_delta=0.1,
score_function=lambda _: next(scores),
trainer=trainer,
cumulative_delta=True,
mode="min",
)

assert not trainer.should_terminate
h(None) # best_score=1.1
assert not trainer.should_terminate
h(None) # score=0.95 (improvement: 0.95 < 1.1 - 0.1 = 1.0)
assert not trainer.should_terminate
h(None) # score=0.94 (no improvement: 0.94 >= 0.95 - 0.1 = 0.85)
assert not trainer.should_terminate
h(None) # score=0.93 (no improvement: 0.93 >= 0.94 - 0.1 = 0.84)
assert trainer.should_terminate


def test_early_stopping_min_mode_rel_delta():
scores = iter([1.0, 0.8, 0.79, 0.78])

trainer = Engine(do_nothing_update_fn)

h = EarlyStopping(
patience=2,
min_delta=0.1,
min_delta_mode="rel",
score_function=lambda _: next(scores),
trainer=trainer,
mode="min",
)

assert not trainer.should_terminate
h(None) # best_score=1.0
assert not trainer.should_terminate
h(None) # score=0.8 (improvement: 0.8 < 1.0 * (1 - 0.1) = 0.9)
assert not trainer.should_terminate
h(None) # score=0.79 (no improvement: 0.79 >= 0.8 * (1 - 0.1) = 0.72)
assert not trainer.should_terminate
h(None) # score=0.78 (no improvement)
assert trainer.should_terminate


def _test_distrib_with_engine_early_stopping(device):
if device is None:
device = idist.device()
Expand Down
Loading