diff --git a/ignite/handlers/early_stopping.py b/ignite/handlers/early_stopping.py index 3da94ceb3857..95e4d49cb571 100644 --- a/ignite/handlers/early_stopping.py +++ b/ignite/handlers/early_stopping.py @@ -14,18 +14,28 @@ class EarlyStopping(Serializable): Args: patience: Number of events to wait if no improvement and then stop the training. score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine` - object, and return a score `float`. An improvement is considered if the score is higher. + object, and return a score `float`. An improvement is considered if the score is higher (for ``mode='max'``) + or lower (for ``mode='min'``). trainer: Trainer engine to stop the run if no improvement. - min_delta: A minimum increase in the score to qualify as an improvement, - i.e. an increase of less than or equal to the minimum delta threshold (as determined by min_delta and min_delta_mode), will count as no improvement. - cumulative_delta: It True, `min_delta` defines an increase since the last `patience` reset, otherwise, - it defines an increase after the last event. Default value is False. - min_delta_mode: Determine whether `min_delta` is an absolute increase or a relative increase. - In 'abs' mode, the threshold is min_delta, - i.e. an increase of less than or equal to min_delta, will count as no improvement. - In 'rel' mode, the threshold is abs(best_score) * min_delta, - i.e. an increase of less than or equal to abs(best_score) * min_delta, will count as no improvement. + min_delta: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum + increase; for ``mode='min'``, it's a minimum decrease. An improvement is only considered if the change + exceeds the threshold determined by `min_delta` and `min_delta_mode`. + cumulative_delta: If True, `min_delta` defines the change since the last `patience` reset, otherwise, + it defines the change after the last event. Default value is False. 
+ min_delta_mode: Determines whether `min_delta` is an absolute change or a relative change. + + - In 'abs' mode: + + - For ``mode='max'``: improvement if score > best_score + min_delta + - For ``mode='min'``: improvement if score < best_score - min_delta + + - In 'rel' mode: + + - For ``mode='max'``: improvement if score > best_score + abs(best_score) * min_delta + - For ``mode='min'``: improvement if score < best_score - abs(best_score) * min_delta + Possible values are "abs" and "rel". Default value is "abs". + mode: Whether to maximize ('max') or minimize ('min') the score. Default is 'max'. Examples: .. code-block:: python @@ -41,6 +51,10 @@ def score_function(engine): # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset). evaluator.add_event_handler(Events.COMPLETED, handler) + .. versionchanged:: 0.6.0 + Added `mode` parameter to support minimization in addition to maximization. + Added `min_delta_mode` parameter to support both absolute and relative improvements. + """ _state_dict_all_req_keys = ( @@ -56,6 +70,7 @@ def __init__( min_delta: float = 0.0, cumulative_delta: bool = False, min_delta_mode: Literal["abs", "rel"] = "abs", + mode: Literal["min", "max"] = "max", ): if not callable(score_function): raise TypeError("Argument score_function should be a function.") @@ -72,6 +87,9 @@ def __init__( if min_delta_mode not in ("abs", "rel"): raise ValueError("Argument min_delta_mode should be either 'abs' or 'rel'.") + if mode not in ("min", "max"): + raise ValueError("Argument mode should be either 'min' or 'max'.") + self.score_function = score_function self.patience = patience self.min_delta = min_delta @@ -81,6 +99,7 @@ def __init__( self.best_score: Optional[float] = None self.logger = setup_logger(__name__ + "." 
+ self.__class__.__name__) self.min_delta_mode = min_delta_mode + self.mode = mode def __call__(self, engine: Engine) -> None: score = self.score_function(engine) @@ -88,14 +107,18 @@ def __call__(self, engine: Engine) -> None: if self.best_score is None: self.best_score = score return - upper_bound = ( - self.best_score + self.min_delta - if self.min_delta_mode == "abs" - else self.best_score + abs(self.best_score) * self.min_delta - ) - if score <= upper_bound: - if not self.cumulative_delta and score > self.best_score: - self.best_score = score + + min_delta = -self.min_delta if self.mode == "min" else self.min_delta + if self.min_delta_mode == "abs": + improvement_threshold = self.best_score + min_delta + else: + improvement_threshold = self.best_score + abs(self.best_score) * min_delta # abs() handles best_score < 0 + + no_improvement = score <= improvement_threshold if self.mode == "max" else score >= improvement_threshold + + if no_improvement: + if not self.cumulative_delta: + self.best_score = max(score, self.best_score) if self.mode == "max" else min(score, self.best_score) self.counter += 1 self.logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience)) if self.counter >= self.patience: diff --git a/tests/ignite/handlers/test_early_stopping.py b/tests/ignite/handlers/test_early_stopping.py index 0dce101331ce..3ea3a1f4d3aa 100644 --- a/tests/ignite/handlers/test_early_stopping.py +++ b/tests/ignite/handlers/test_early_stopping.py @@ -30,6 +30,9 @@ def test_args_validation(): with pytest.raises(ValueError, match=r"Argument min_delta_mode should be either 'abs' or 'rel'."): EarlyStopping(patience=2, min_delta_mode="invalid_mode", score_function=lambda engine: 0, trainer=trainer) + with pytest.raises(ValueError, match=r"Argument mode should be either 'min' or 'max'."): + EarlyStopping(patience=2, mode="invalid_mode", score_function=lambda engine: 0, trainer=trainer) + def test_simple_early_stopping(): scores = iter([1.0, 0.8, 0.88]) @@ -296,6 +299,93 @@ def evaluation(engine): assert 
trainer.state.epoch == 10 +def test_simple_early_stopping_min_mode(): + scores = iter([1.0, 1.2, 0.9]) + + def score_function(engine): + return next(scores) + + trainer = Engine(do_nothing_update_fn) + + h = EarlyStopping(patience=2, score_function=score_function, trainer=trainer, mode="min") + # Call 3 times and check that training is not stopped + assert not trainer.should_terminate + h(None) # best_score=1.0 + assert not trainer.should_terminate + h(None) # score=1.2 (no improvement) + assert not trainer.should_terminate + h(None) # score=0.9 (improvement) + assert not trainer.should_terminate + + +def test_early_stopping_min_mode_with_delta(): + scores = iter([1.1, 0.95, 0.94, 0.93]) + + trainer = Engine(do_nothing_update_fn) + + h = EarlyStopping(patience=2, min_delta=0.1, score_function=lambda _: next(scores), trainer=trainer, mode="min") + + assert not trainer.should_terminate + h(None) # best_score=1.1 + assert not trainer.should_terminate + h(None) # score=0.95 (improvement: 0.95 < 1.1 - 0.1 = 1.0) + assert not trainer.should_terminate + h(None) # score=0.94 (no improvement: 0.94 >= 0.95 - 0.1 = 0.85) + assert not trainer.should_terminate + h(None) # score=0.93 (no improvement: 0.93 >= 0.94 - 0.1 = 0.84) + assert trainer.should_terminate + + +def test_early_stopping_min_mode_with_delta_cumulative(): + scores = iter([1.1, 0.95, 0.94, 0.93]) + + trainer = Engine(do_nothing_update_fn) + + h = EarlyStopping( + patience=2, + min_delta=0.1, + score_function=lambda _: next(scores), + trainer=trainer, + cumulative_delta=True, + mode="min", + ) + + assert not trainer.should_terminate + h(None) # best_score=1.1 + assert not trainer.should_terminate + h(None) # score=0.95 (improvement: 0.95 < 1.1 - 0.1 = 1.0) + assert not trainer.should_terminate + h(None) # score=0.94 (no improvement: 0.94 >= 0.95 - 0.1 = 0.85) + assert not trainer.should_terminate + h(None) # score=0.93 (no improvement: 0.93 >= 0.95 - 0.1 = 0.85) + assert trainer.should_terminate + + +def 
test_early_stopping_min_mode_rel_delta(): + scores = iter([1.0, 0.8, 0.79, 0.78]) + + trainer = Engine(do_nothing_update_fn) + + h = EarlyStopping( + patience=2, + min_delta=0.1, + min_delta_mode="rel", + score_function=lambda _: next(scores), + trainer=trainer, + mode="min", + ) + + assert not trainer.should_terminate + h(None) # best_score=1.0 + assert not trainer.should_terminate + h(None) # score=0.8 (improvement: 0.8 < 1.0 * (1 - 0.1) = 0.9) + assert not trainer.should_terminate + h(None) # score=0.79 (no improvement: 0.79 >= 0.8 * (1 - 0.1) = 0.72) + assert not trainer.should_terminate + h(None) # score=0.78 (no improvement: 0.78 >= 0.79 * (1 - 0.1) = 0.711) + assert trainer.should_terminate + + def _test_distrib_with_engine_early_stopping(device): if device is None: device = idist.device()