Skip to content

Commit

Permalink
Use events list for logger handlers (#1544)
Browse files Browse the repository at this point in the history
* use events list for loggers

* autopep8 fix

* add test

* fix mypy

* add test to catch error

* improve docstring

Co-authored-by: Desroziers <sylvain.desroziers@ifpen.fr>
Co-authored-by: sdesrozis <sdesrozis@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 5, 2021
1 parent f3fc875 commit 5ceacbb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 13 deletions.
24 changes: 16 additions & 8 deletions ignite/contrib/handlers/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn
from torch.optim import Optimizer

from ignite.engine import Engine, Events, State
from ignite.engine import Engine, Events, EventsList, State
from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle


Expand Down Expand Up @@ -147,26 +147,34 @@ class BaseLogger(metaclass=ABCMeta):
"""

def attach(
self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter]
self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter, EventsList]
) -> RemovableEventHandle:
"""Attach the logger to the engine and execute `log_handler` function at `event_name` events.
Args:
engine (Engine): engine object.
log_handler (callable): a logging handler to execute
event_name: event to attach the logging handler to. Valid events are from
:class:`~ignite.engine.events.Events` or any `event_name` added by
:meth:`~ignite.engine.engine.Engine.register_events`.
:class:`~ignite.engine.events.Events` or class:`~ignite.engine.events.EventsList` or any `event_name`
added by :meth:`~ignite.engine.engine.Engine.register_events`.
Returns:
:class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler.
"""
name = event_name
if isinstance(event_name, EventsList):
for name in event_name:
if name not in State.event_to_attr:
raise RuntimeError(f"Unknown event name '{name}'")
engine.add_event_handler(name, log_handler, self, name)

return RemovableEventHandle(event_name, log_handler, engine)

else:

if name not in State.event_to_attr:
raise RuntimeError(f"Unknown event name '{name}'")
if event_name not in State.event_to_attr:
raise RuntimeError(f"Unknown event name '{event_name}'")

return engine.add_event_handler(event_name, log_handler, self, name)
return engine.add_event_handler(event_name, log_handler, self, event_name)

def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle:
"""Shortcut method to attach `OutputHandler` to the logger.
Expand Down
4 changes: 3 additions & 1 deletion ignite/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ignite.distributed as idist
from ignite.engine.deterministic import DeterministicEngine
from ignite.engine.engine import Engine
from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, State
from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, EventsList, RemovableEventHandle, State
from ignite.metrics import Metric
from ignite.utils import convert_tensor

Expand All @@ -21,8 +21,10 @@
"Engine",
"DeterministicEngine",
"Events",
"EventsList",
"EventEnum",
"CallableEventWithFilter",
"RemovableEventHandle",
]


Expand Down
2 changes: 1 addition & 1 deletion ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
if TYPE_CHECKING:
from ignite.engine.engine import Engine

__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State"]
__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State", "EventsList", "RemovableEventHandle"]


class CallableEventWithFilter:
Expand Down
19 changes: 16 additions & 3 deletions tests/ignite/contrib/handlers/test_base_logger.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math
from unittest.mock import MagicMock
from unittest.mock import MagicMock, call

import pytest
import torch

from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
from ignite.engine import Engine, Events, State
from ignite.engine import Engine, Events, EventsList, State
from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer


Expand Down Expand Up @@ -122,7 +122,12 @@ def update_fn(engine, batch):

trainer.run(data, max_epochs=n_epochs)

mock_log_handler.assert_called_with(trainer, logger, event)
if isinstance(event, EventsList):
events = [e for e in event]
else:
events = [event]
calls = [call(trainer, logger, e) for e in events]
mock_log_handler.assert_has_calls(calls)
assert mock_log_handler.call_count == n_calls

_test(Events.ITERATION_STARTED, len(data) * n_epochs)
Expand All @@ -134,6 +139,8 @@ def update_fn(engine, batch):

_test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)

_test(Events.STARTED | Events.COMPLETED, 2)


def test_attach_wrong_event_name():

Expand All @@ -144,6 +151,12 @@ def test_attach_wrong_event_name():
with pytest.raises(RuntimeError, match="Unknown event name"):
logger.attach(trainer, log_handler=mock_log_handler, event_name="unknown")

events_list = EventsList()
events_list._events = ["unknown"]

with pytest.raises(RuntimeError, match="Unknown event name"):
logger.attach(trainer, log_handler=mock_log_handler, event_name=events_list)


def test_attach_on_custom_event():
n_epochs = 10
Expand Down

0 comments on commit 5ceacbb

Please sign in to comment.