From ebd1876a12ebe16403889e0ede6de61d84c1b44b Mon Sep 17 00:00:00 2001
From: vfdev
Date: Tue, 21 Jan 2020 23:39:52 +0100
Subject: [PATCH] Improves #694 (#712)

* Made all loggers public attributes - added a method to set up loggers

* Fixed flake8
---
 ignite/engine/engine.py             | 29 ++++++------
 ignite/handlers/early_stopping.py   |  7 ++-
 ignite/handlers/terminate_on_nan.py |  8 ++--
 ignite/utils.py                     | 69 ++++++++++++++++++++++++++++-
 tests/ignite/test_utils.py          | 47 +++++++++++++++++++-
 5 files changed, 135 insertions(+), 25 deletions(-)

diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py
index d5b47170c3e..52febd51230 100644
--- a/ignite/engine/engine.py
+++ b/ignite/engine/engine.py
@@ -355,8 +355,7 @@ def compute_mean_std(engine, batch):
 
     def __init__(self, process_function):
         self._event_handlers = defaultdict(list)
-        self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
-        self._logger.addHandler(logging.NullHandler())
+        self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
         self._process_function = process_function
         self.last_event_name = None
         self.should_terminate = False
@@ -489,14 +488,14 @@ def print_epoch(engine):
             handler = Engine._handler_wrapper(handler, event_name, event_filter)
 
         if event_name not in self._allowed_events:
-            self._logger.error("attempt to add event handler to an invalid event %s.", event_name)
+            self.logger.error("attempt to add event handler to an invalid event %s.", event_name)
             raise ValueError("Event {} is not a valid event for this Engine.".format(event_name))
 
         event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else ()
         Engine._check_signature(self, handler, 'handler', *(event_args + args), **kwargs)
 
         self._event_handlers[event_name].append((handler, args, kwargs))
-        self._logger.debug("added handler for event %s.", event_name)
+        self.logger.debug("added handler for event %s.", event_name)
 
         return RemovableEventHandle(event_name, handler, self)
 
@@ -601,7 +600,7 @@ def _fire_event(self, event_name, *event_args, **event_kwargs):
 
         """
         if event_name in self._allowed_events:
-            self._logger.debug("firing handlers for event %s ", event_name)
+            self.logger.debug("firing handlers for event %s ", event_name)
             self.last_event_name = event_name
             for func, args, kwargs in self._event_handlers[event_name]:
                 kwargs.update(event_kwargs)
@@ -633,14 +632,14 @@ def fire_event(self, event_name):
     def terminate(self):
         """Sends terminate signal to the engine, so that it terminates completely the run after the current iteration.
         """
-        self._logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
+        self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
         self.should_terminate = True
 
     def terminate_epoch(self):
         """Sends terminate signal to the engine, so that it terminates the current epoch after the current iteration.
         """
-        self._logger.info("Terminate current epoch is signaled. "
-                          "Current epoch iteration will stop after current iteration is finished.")
+        self.logger.info("Terminate current epoch is signaled. "
+                         "Current epoch iteration will stop after current iteration is finished.")
         self.should_terminate_single_epoch = True
 
     def _run_once_on_dataset(self):
@@ -702,7 +701,7 @@ def _run_once_on_dataset(self):
                     break
 
         except BaseException as e:
-            self._logger.error("Current run is terminating due to exception: %s.", str(e))
+            self.logger.error("Current run is terminating due to exception: %s.", str(e))
             self._handle_exception(e)
 
         time_taken = time.time() - start_time
@@ -835,7 +834,7 @@ def switch_batch(engine):
             else:
                 raise ValueError("Argument `epoch_length` should be defined if `data` is an iterator")
             self.state = State(seed=seed, iteration=0, epoch=0, max_epochs=max_epochs, epoch_length=epoch_length)
-            self._logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
+            self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
         else:
             # Keep actual state and override it if input args provided
             if max_epochs is not None:
@@ -844,8 +843,8 @@ def switch_batch(engine):
                 self.state.seed = seed
             if epoch_length is not None:
                 self.state.epoch_length = epoch_length
-            self._logger.info("Engine run resuming from iteration {}, epoch {} until {} epochs"
-                              .format(self.state.iteration, self.state.epoch, self.state.max_epochs))
+            self.logger.info("Engine run resuming from iteration {}, epoch {} until {} epochs"
+                             .format(self.state.iteration, self.state.epoch, self.state.max_epochs))
 
         self.state.dataloader = data
         return self._internal_run()
@@ -937,7 +936,7 @@ def _internal_run(self):
 
                 hours, mins, secs = self._run_once_on_dataset()
 
-                self._logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs)
+                self.logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs)
                 if self.should_terminate:
                     break
                 self._fire_event(Events.EPOCH_COMPLETED)
@@ -945,11 +944,11 @@ def _internal_run(self):
             self._fire_event(Events.COMPLETED)
             time_taken = time.time() - start_time
             hours, mins, secs = _to_hours_mins_secs(time_taken)
-            self._logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs))
+            self.logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs))
 
         except BaseException as e:
             self._dataloader_iter = self._dataloader_len = None
-            self._logger.error("Engine run is terminating due to exception: %s.", str(e))
+            self.logger.error("Engine run is terminating due to exception: %s.", str(e))
             self._handle_exception(e)
 
         self._dataloader_iter = self._dataloader_len = None
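
With `_logger` renamed to the public `logger` attribute, user code can configure an
engine's logging directly instead of reaching into a private field. A minimal sketch
(the handler, format and file name below are illustrative, not part of this patch):

.. code-block:: python

    import logging

    from ignite.engine import Engine

    trainer = Engine(lambda engine, batch: None)

    # Attach any standard `logging` handler to the now-public attribute.
    handler = logging.FileHandler("trainer.log")  # illustrative file name
    handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s"))
    trainer.logger.addHandler(handler)
    trainer.logger.setLevel(logging.INFO)
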
" + "Current epoch iteration will stop after current iteration is finished.") self.should_terminate_single_epoch = True def _run_once_on_dataset(self): @@ -702,7 +701,7 @@ def _run_once_on_dataset(self): break except BaseException as e: - self._logger.error("Current run is terminating due to exception: %s.", str(e)) + self.logger.error("Current run is terminating due to exception: %s.", str(e)) self._handle_exception(e) time_taken = time.time() - start_time @@ -835,7 +834,7 @@ def switch_batch(engine): else: raise ValueError("Argument `epoch_length` should be defined if `data` is an iterator") self.state = State(seed=seed, iteration=0, epoch=0, max_epochs=max_epochs, epoch_length=epoch_length) - self._logger.info("Engine run starting with max_epochs={}.".format(max_epochs)) + self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs)) else: # Keep actual state and override it if input args provided if max_epochs is not None: @@ -844,8 +843,8 @@ def switch_batch(engine): self.state.seed = seed if epoch_length is not None: self.state.epoch_length = epoch_length - self._logger.info("Engine run resuming from iteration {}, epoch {} until {} epochs" - .format(self.state.iteration, self.state.epoch, self.state.max_epochs)) + self.logger.info("Engine run resuming from iteration {}, epoch {} until {} epochs" + .format(self.state.iteration, self.state.epoch, self.state.max_epochs)) self.state.dataloader = data return self._internal_run() @@ -937,7 +936,7 @@ def _internal_run(self): hours, mins, secs = self._run_once_on_dataset() - self._logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs) + self.logger.info("Epoch[%s] Complete. Time taken: %02d:%02d:%02d", self.state.epoch, hours, mins, secs) if self.should_terminate: break self._fire_event(Events.EPOCH_COMPLETED) @@ -945,11 +944,11 @@ def _internal_run(self): self._fire_event(Events.COMPLETED) time_taken = time.time() - start_time hours, mins, secs = _to_hours_mins_secs(time_taken) - self._logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs)) + self.logger.info("Engine run complete. Time taken %02d:%02d:%02d" % (hours, mins, secs)) except BaseException as e: self._dataloader_iter = self._dataloader_len = None - self._logger.error("Engine run is terminating due to exception: %s.", str(e)) + self.logger.error("Engine run is terminating due to exception: %s.", str(e)) self._handle_exception(e) self._dataloader_iter = self._dataloader_len = None diff --git a/ignite/handlers/early_stopping.py b/ignite/handlers/early_stopping.py index ba9067eb0a3..e835a524823 100644 --- a/ignite/handlers/early_stopping.py +++ b/ignite/handlers/early_stopping.py @@ -59,8 +59,7 @@ def __init__(self, patience, score_function, trainer, min_delta=0., cumulative_d self.trainer = trainer self.counter = 0 self.best_score = None - self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__) - self._logger.addHandler(logging.NullHandler()) + self.logger = logging.getLogger(__name__ + "." 
diff --git a/ignite/utils.py b/ignite/utils.py
index d2f41c60474..d585a2149c2 100644
--- a/ignite/utils.py
+++ b/ignite/utils.py
@@ -1,5 +1,8 @@
-import torch
+import os
 import collections.abc as collections
+import logging
+
+import torch
 
 
 def convert_tensor(input_, device=None, non_blocking=False):
@@ -41,3 +44,67 @@ def to_onehot(indices, num_classes):
     onehot = torch.zeros(indices.shape[0], num_classes, *indices.shape[1:],
                          dtype=torch.uint8, device=indices.device)
     return onehot.scatter_(1, indices.unsqueeze(1), 1)
+
+
+def setup_logger(name, level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s",
+                 filepath=None, distributed_rank=0):
+    """Sets up a logger: name, level, format, etc.
+
+    Args:
+        name (str): new name for the logger.
+        level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
+        format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
+        filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
+        distributed_rank (int, optional): Optional, rank in a distributed configuration; on ranks other than 0 the logger is returned without handlers, keeping worker processes silent.
+
+    Returns:
+        logging.Logger
+
+    For example, to improve the readability of logs when training with a trainer and an evaluator:
+
+    .. code-block:: python
+
+        from ignite.utils import setup_logger
+
+        trainer = ...
+        evaluator = ...
+
+        trainer.logger = setup_logger("trainer")
+        evaluator.logger = setup_logger("evaluator")
+
+        trainer.run(data, max_epochs=5)
+
+        # Logs will look like
+        # 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
+        # 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:05:23
+        # 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
+        # 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
+        # ...
+
+    """
+    logger = logging.getLogger(name)
+
+    if distributed_rank > 0:
+        return logger
+
+    logger.setLevel(level)
+
+    # Remove previous handlers so that repeated calls do not duplicate output
+    if logger.hasHandlers():
+        for h in list(logger.handlers):
+            logger.removeHandler(h)
+
+    formatter = logging.Formatter(format)
+
+    ch = logging.StreamHandler()
+    ch.setLevel(level)
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+
+    if filepath is not None:
+        fh = logging.FileHandler(filepath)
+        fh.setLevel(level)
+        fh.setFormatter(formatter)
+        logger.addHandler(fh)
+
+    return logger
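
The `distributed_rank` argument returns the logger untouched on non-zero ranks, so
only the rank-0 process gets handlers and prints. A sketch of how this might be wired
up in a distributed script (reading the rank from the `RANK` environment variable is
an assumption about the launcher, not part of this patch):

.. code-block:: python

    import os
    import logging

    from ignite.utils import setup_logger

    rank = int(os.environ.get("RANK", "0"))  # assumed to be set by the launcher
    logger = setup_logger("trainer", level=logging.INFO,
                          filepath="train.log", distributed_rank=rank)
    logger.info("visible from rank 0 only")  # other ranks get a handler-less logger
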
diff --git a/tests/ignite/test_utils.py b/tests/ignite/test_utils.py
index 55088f79999..155f96fd7f6 100644
--- a/tests/ignite/test_utils.py
+++ b/tests/ignite/test_utils.py
@@ -1,6 +1,9 @@
+import os
+import logging
 import pytest
 import torch
-from ignite.utils import convert_tensor, to_onehot
+
+from ignite.utils import convert_tensor, to_onehot, setup_logger
 
 
 def test_convert_tensor():
@@ -54,3 +57,45 @@ def test_to_onehot():
     y_ohe = to_onehot(y, num_classes=21)
     y2 = torch.argmax(y_ohe, dim=1)
     assert y.equal(y2)
+
+
+def test_dist_setup_logger():
+
+    logger = setup_logger("trainer", level=logging.CRITICAL, distributed_rank=1)
+    assert logger.level != logging.CRITICAL
+
+
+def test_setup_logger(capsys, dirname):
+
+    from ignite.engine import Engine, Events
+
+    trainer = Engine(lambda e, b: None)
+    evaluator = Engine(lambda e, b: None)
+
+    fp = os.path.join(dirname, "log")
+    assert len(trainer.logger.handlers) == 0
+    trainer.logger.addHandler(logging.NullHandler())
+    trainer.logger.addHandler(logging.NullHandler())
+    trainer.logger.addHandler(logging.NullHandler())
+
+    trainer.logger = setup_logger("trainer", filepath=fp)
+    evaluator.logger = setup_logger("evaluator", filepath=fp)
+
+    assert len(trainer.logger.handlers) == 2
+    assert len(evaluator.logger.handlers) == 2
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def _(_):
+        evaluator.run([0, 1, 2])
+
+    trainer.run([0, 1, 2, 3, 4, 5], max_epochs=5)
+
+    captured = capsys.readouterr()
+    err = captured.err.split('\n')
+
+    with open(fp, "r") as h:
+        data = h.readlines()
+
+    for source in [err, data]:
+        assert "trainer INFO: Engine run starting with max_epochs=5." in source[0]
+        assert "evaluator INFO: Engine run starting with max_epochs=1." in source[2]
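
Because `setup_logger` removes any existing handlers on the named logger before
attaching new ones, calling it repeatedly with the same name replaces rather than
stacks handlers. A quick self-contained check of that behaviour (a sketch, not part
of the test suite above):

.. code-block:: python

    from ignite.utils import setup_logger

    logger = setup_logger("trainer")
    logger = setup_logger("trainer")  # second call replaces, not duplicates, handlers
    assert len(logger.handlers) == 1  # a single StreamHandler; no filepath was given
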