From 5ceacbb3ccea433317df2776556bd4a08d0a923d Mon Sep 17 00:00:00 2001 From: Sylvain Desroziers Date: Tue, 5 Jan 2021 22:51:08 +0100 Subject: [PATCH] Use events list for logger handlers (#1544) * use events list for loggers * autopep8 fix * add test * fix mypy * add test to catch error * improve docstring Co-authored-by: Desroziers Co-authored-by: sdesrozis --- ignite/contrib/handlers/base_logger.py | 24 ++++++++++++------- ignite/engine/__init__.py | 4 +++- ignite/engine/events.py | 2 +- .../contrib/handlers/test_base_logger.py | 19 ++++++++++++--- 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 44707d668d7..c3776835449 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -7,7 +7,7 @@ import torch.nn as nn from torch.optim import Optimizer -from ignite.engine import Engine, Events, State +from ignite.engine import Engine, Events, EventsList, State from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle @@ -147,7 +147,7 @@ class BaseLogger(metaclass=ABCMeta): """ def attach( - self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter] + self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter, EventsList] ) -> RemovableEventHandle: """Attach the logger to the engine and execute `log_handler` function at `event_name` events. @@ -155,18 +155,26 @@ def attach( engine (Engine): engine object. log_handler (callable): a logging handler to execute event_name: event to attach the logging handler to. Valid events are from - :class:`~ignite.engine.events.Events` or any `event_name` added by - :meth:`~ignite.engine.engine.Engine.register_events`. + :class:`~ignite.engine.events.Events` or class:`~ignite.engine.events.EventsList` or any `event_name` + added by :meth:`~ignite.engine.engine.Engine.register_events`. Returns: :class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler. """ - name = event_name + if isinstance(event_name, EventsList): + for name in event_name: + if name not in State.event_to_attr: + raise RuntimeError(f"Unknown event name '{name}'") + engine.add_event_handler(name, log_handler, self, name) + + return RemovableEventHandle(event_name, log_handler, engine) + + else: - if name not in State.event_to_attr: - raise RuntimeError(f"Unknown event name '{name}'") + if event_name not in State.event_to_attr: + raise RuntimeError(f"Unknown event name '{event_name}'") - return engine.add_event_handler(event_name, log_handler, self, name) + return engine.add_event_handler(event_name, log_handler, self, event_name) def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle: """Shortcut method to attach `OutputHandler` to the logger. diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py index 3fbc83861b3..5330393a2cc 100644 --- a/ignite/engine/__init__.py +++ b/ignite/engine/__init__.py @@ -6,7 +6,7 @@ import ignite.distributed as idist from ignite.engine.deterministic import DeterministicEngine from ignite.engine.engine import Engine -from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, State +from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, EventsList, RemovableEventHandle, State from ignite.metrics import Metric from ignite.utils import convert_tensor @@ -21,8 +21,10 @@ "Engine", "DeterministicEngine", "Events", + "EventsList", "EventEnum", "CallableEventWithFilter", + "RemovableEventHandle", ] diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 4f07f576508..aef9c97bb10 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from ignite.engine.engine import Engine -__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State"] +__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State", "EventsList", "RemovableEventHandle"] class CallableEventWithFilter: diff --git a/tests/ignite/contrib/handlers/test_base_logger.py b/tests/ignite/contrib/handlers/test_base_logger.py index 3abc162f693..4b6b8350dbe 100644 --- a/tests/ignite/contrib/handlers/test_base_logger.py +++ b/tests/ignite/contrib/handlers/test_base_logger.py @@ -1,11 +1,11 @@ import math -from unittest.mock import MagicMock +from unittest.mock import MagicMock, call import pytest import torch from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler -from ignite.engine import Engine, Events, State +from ignite.engine import Engine, Events, EventsList, State from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer @@ -122,7 +122,12 @@ def update_fn(engine, batch): trainer.run(data, max_epochs=n_epochs) - mock_log_handler.assert_called_with(trainer, logger, event) + if isinstance(event, EventsList): + events = [e for e in event] + else: + events = [event] + calls = [call(trainer, logger, e) for e in events] + mock_log_handler.assert_has_calls(calls) assert mock_log_handler.call_count == n_calls _test(Events.ITERATION_STARTED, len(data) * n_epochs) @@ -134,6 +139,8 @@ def update_fn(engine, batch): _test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs) + _test(Events.STARTED | Events.COMPLETED, 2) + def test_attach_wrong_event_name(): @@ -144,6 +151,12 @@ def test_attach_wrong_event_name(): with pytest.raises(RuntimeError, match="Unknown event name"): logger.attach(trainer, log_handler=mock_log_handler, event_name="unknown") + events_list = EventsList() + events_list._events = ["unknown"] + + with pytest.raises(RuntimeError, match="Unknown event name"): + logger.attach(trainer, log_handler=mock_log_handler, event_name=events_list) + def test_attach_on_custom_event(): n_epochs = 10