Skip to content

Commit

Permalink
457 enhance handlers to support attach and logger (Project-MONAI#458)
Browse files Browse the repository at this point in the history
* [DLMED] fix flake8 error

* [DLMED] revert

* [DLMED] enhance all handlers

* [DLMED] fix windows test error

* [DLMED] fix typo

* [DLMED] update according to the comments
  • Loading branch information
Nic-Ma authored May 31, 2020
1 parent f696736 commit a683c4e
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 8 deletions.
5 changes: 3 additions & 2 deletions monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
See also:
:py:meth:`monai.metrics.meandice.compute_meandice`
"""
super(MeanDice, self).__init__(output_transform, device=device)
super().__init__(output_transform, device=device)
self.include_background = include_background
self.to_onehot_y = to_onehot_y
self.mutually_exclusive = mutually_exclusive
Expand All @@ -67,7 +67,8 @@ def reset(self):

@reinit__is_reduced
def update(self, output: Sequence[Union[torch.Tensor, dict]]):
assert len(output) == 2, "MeanDice metric can only support y_pred and y."
if not len(output) == 2:
raise ValueError("MeanDice metric can only support y_pred and y.")
y_pred, y = output
scores = compute_meandice(
y_pred,
Expand Down
14 changes: 11 additions & 3 deletions monai/handlers/roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence
from typing import Optional, Sequence, Union

import torch
from ignite.metrics import Metric
Expand Down Expand Up @@ -39,14 +39,22 @@ class ROCAUC(Metric):
:class:`~ignite.engine.Engine` `process_function` output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (torch.device): device specification in case of distributed computation usage.
Note:
ROCAUC expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence values.
"""

def __init__(self, to_onehot_y=False, add_softmax=False, average="macro", output_transform=lambda x: x):
super().__init__(output_transform=output_transform)
def __init__(
self,
to_onehot_y=False,
add_softmax=False,
average="macro",
output_transform=lambda x: x,
device: Optional[Union[str, torch.device]] = None,
):
super().__init__(output_transform, device=device)
self.to_onehot_y = to_onehot_y
self.add_softmax = add_softmax
self.average = average
Expand Down
12 changes: 11 additions & 1 deletion monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import logging
import torch
from ignite.engine import Engine, Events
from monai.utils.misc import is_scalar
from monai.utils import is_scalar

DEFAULT_KEY_VAL_FORMAT = "{}: {:.4f} "
DEFAULT_TAG = "Loss"
Expand All @@ -40,6 +40,7 @@ def __init__(
name=None,
tag_name=DEFAULT_TAG,
key_var_format=DEFAULT_KEY_VAL_FORMAT,
logger_handler=None,
):
"""
Expand All @@ -59,6 +60,8 @@ def __init__(
tag_name (string): when iteration output is a scalar, tag_name is used to print
tag_name: scalar_value to logger. Defaults to ``'Loss'``.
key_var_format (string): a formatting string to control the output string format of key: value.
logger_handler (logging.handler): add additional handler to handle the stats data: save to file, etc.
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
"""

self.epoch_print_logger = epoch_print_logger
Expand All @@ -69,6 +72,8 @@ def __init__(

self.tag_name = tag_name
self.key_var_format = key_var_format
if logger_handler is not None:
self.logger.addHandler(logger_handler)

def attach(self, engine: Engine):
"""Register a set of Ignite Event-Handlers to a specified Ignite engine.
Expand Down Expand Up @@ -142,7 +147,12 @@ def _default_epoch_print(self, engine: Engine):
for name in sorted(prints_dict):
value = prints_dict[name]
out_str += self.key_var_format.format(name, value)
self.logger.info(out_str)

if hasattr(engine.state, "key_metric_name"):
if hasattr(engine.state, "best_metric") and hasattr(engine.state, "best_metric_epoch"):
out_str = f"Key metric: {engine.state.key_metric_name} "
out_str += f"best value: {engine.state.best_metric} at epoch: {engine.state.best_metric_epoch}"
self.logger.info(out_str)

def _default_iteration_print(self, engine: Engine):
Expand Down
15 changes: 14 additions & 1 deletion monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def __init__(
self,
summary_writer=None,
log_dir="./runs",
interval=1,
epoch_level=True,
batch_transform=lambda x: x,
output_transform=lambda x: x,
global_iter_transform=lambda x: x,
Expand All @@ -194,6 +196,9 @@ def __init__(
summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter,
default to create a new writer.
log_dir (str): if using default SummaryWriter, write logs to this directory, default is `./runs`.
interval (int): plot content from engine.state every N epochs or every N iterations, default is 1.
epoch_level (bool): plot content from engine.state every N epochs or N iterations. `True` is epoch level,
`False` is iteration level.
batch_transform (Callable): a callable that is used to transform the
``ignite.engine.batch`` into expected format to extract several label data.
output_transform (Callable): a callable that is used to transform the
Expand All @@ -205,15 +210,23 @@ def __init__(
max_frames (int): number of frames for 2D-t plot.
"""
self._writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer
self.interval = interval
self.epoch_level = epoch_level
self.batch_transform = batch_transform
self.output_transform = output_transform
self.global_iter_transform = global_iter_transform
self.index = index
self.max_frames = max_frames
self.max_channels = max_channels

def attach(self, engine):
if self.epoch_level:
engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self)
else:
engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self)

def __call__(self, engine):
step = self.global_iter_transform(engine.state.iteration)
step = self.global_iter_transform(engine.state.epoch if self.epoch_level else engine.state.iteration)

show_images = self.batch_transform(engine.state.batch)[0]
if torch.is_tensor(show_images):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_value(self, input_param, input_data, expected_print):
@parameterized.expand([TEST_CASE_6])
def test_file(self, input_data, expected_print):
tempdir = tempfile.mkdtemp()
filename = os.path.join(tempdir, "test_stats.log")
filename = os.path.join(tempdir, "test_data_stats.log")
handler = logging.FileHandler(filename, mode="w")
input_param = {
"prefix": "test data",
Expand Down
35 changes: 35 additions & 0 deletions tests/test_handler_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
# limitations under the License.

import torch
import os
import shutil
import logging
import tempfile
import re
import unittest
from io import StringIO
Expand Down Expand Up @@ -108,6 +111,38 @@ def _train_func(engine, batch):
if idx in [1, 2, 3, 6, 7, 8]:
self.assertTrue(has_key_word.match(line))

def test_loss_file(self):
logging.basicConfig(level=logging.INFO)
key_to_handler = "test_logging"
key_to_print = "myLoss"

tempdir = tempfile.mkdtemp()
filename = os.path.join(tempdir, "test_loss_stats.log")
handler = logging.FileHandler(filename, mode="w")

# set up engine
def _train_func(engine, batch):
return torch.tensor(0.0)

engine = Engine(_train_func)

# set up testing handler
stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler)
stats_handler.attach(engine)

engine.run(range(3), max_epochs=2)
handler.stream.close()
stats_handler.logger.removeHandler(handler)
with open(filename, "r") as f:
output_str = f.read()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [1, 2, 3, 6, 7, 8]:
self.assertTrue(has_key_word.match(line))
shutil.rmtree(tempdir)


if __name__ == "__main__":
unittest.main()

0 comments on commit a683c4e

Please sign in to comment.