Skip to content

Commit

Permalink
433 add post_transform to engines (Project-MONAI#468)
Browse files (browse the repository at this point in the history)
* [DLMED] add post_transform to engines

* [MONAI] python code formatting

* [DLMED] update according to the comments

* [DLMED] update notebook

* [DLMED] update notebook

Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
Nic-Ma and monai-bot authored Jun 2, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent a78cd68 commit 7cd6561
Showing 5 changed files with 62 additions and 37 deletions.
37 changes: 7 additions & 30 deletions examples/notebooks/spleen_segmentation_3d.ipynb
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"MONAI version: 0.1.0+67.g6b04999.dirty\n",
"MONAI version: 0.1.0+94.gae28c16.dirty\n",
"Python version: 3.6.9 |Anaconda, Inc.| (default, Jul 30 2019, 19:07:31) [GCC 7.3.0]\n",
"Numpy version: 1.17.4\n",
"Pytorch version: 1.5.0\n",
@@ -65,13 +65,11 @@
"import glob\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import matplotlib.pyplot as plt\n",
"import monai\n",
"from monai.transforms import \\\n",
" Compose, LoadNiftid, AddChanneld, ScaleIntensityRanged, CropForegroundd, \\\n",
" RandCropByPosNegLabeld, RandAffined, Spacingd, Orientationd, ToTensord\n",
"from monai.data import list_data_collate\n",
"from monai.inferers import sliding_window_inference\n",
"from monai.networks.layers import Norm\n",
"from monai.metrics import compute_meandice\n",
@@ -116,26 +114,6 @@
"set_determinism(seed=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set different seed for workers of DataLoader\n",
"This is known issue of PyTorch: \n",
"https://discuss.pytorch.org/t/why-does-numpy-random-rand-produce-the-same-values-in-different-cores/12005"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def worker_init_fn(worker_id):\n",
" worker_info = torch.utils.data.get_worker_info()\n",
" worker_info.dataset.transform.set_random_state(worker_info.seed % (2 ** 32))"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -145,7 +123,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -185,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -210,7 +188,7 @@
],
"source": [
"check_ds = monai.data.Dataset(data=val_files, transform=val_transforms)\n",
"check_loader = DataLoader(check_ds, batch_size=1)\n",
"check_loader = monai.data.DataLoader(check_ds, batch_size=1)\n",
"check_data = monai.utils.misc.first(check_loader)\n",
"image, label = (check_data['image'][0][0], check_data['label'][0][0])\n",
"print('image shape: {}, label shape: {}'.format(image.shape, label.shape))\n",
@@ -239,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -261,14 +239,13 @@
"\n",
"# use batch_size=2 to load images and use RandCropByPosNegLabeld\n",
"# to generate 2 x 4 images for network training\n",
"train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate,\n",
" worker_init_fn=worker_init_fn)\n",
"train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)\n",
"\n",
"val_ds = monai.data.CacheDataset(\n",
" data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4\n",
")\n",
"# val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)\n",
"val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)"
"val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)"
]
},
{
30 changes: 25 additions & 5 deletions examples/workflows/unet_training_dict.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
import numpy as np
import torch
from torch.utils.data import DataLoader
from ignite.metrics import Accuracy

import monai
from monai.transforms import (
@@ -29,6 +30,9 @@
RandCropByPosNegLabeld,
RandRotate90d,
ToTensord,
Activationsd,
AsDiscreted,
KeepLargestConnectedComponentd,
)
from monai.handlers import StatsHandler, ValidationHandler, MeanDice
from monai.data import create_test_image_3d, list_data_collate
@@ -99,22 +103,37 @@ def main():
loss = monai.losses.DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), 1e-3)

val_post_transforms = Compose(
[
Activationsd(keys=Keys.PRED, output_postfix="act", sigmoid=True),
AsDiscreted(keys="pred_act", output_postfix="dis", threshold_values=True),
KeepLargestConnectedComponentd(keys="pred_act_dis", applied_values=[1], output_postfix=None),
]
)
val_handlers = [StatsHandler(output_transform=lambda x: None)]

evaluator = SupervisedEvaluator(
device=device,
val_data_loader=val_loader,
network=net,
inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
val_handlers=val_handlers,
post_transform=val_post_transforms,
key_val_metric={
"val_mean_dice": MeanDice(
include_background=True, add_sigmoid=True, output_transform=lambda x: (x[Keys.PRED], x[Keys.LABEL])
include_background=True, output_transform=lambda x: (x["pred_act_dis"], x[Keys.LABEL])
)
},
additional_metrics=None,
additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred_act_dis"], x[Keys.LABEL]))},
val_handlers=val_handlers,
)

train_post_transforms = Compose(
[
Activationsd(keys=Keys.PRED, output_postfix="act", sigmoid=True),
AsDiscreted(keys="pred_act", output_postfix="dis", threshold_values=True),
KeepLargestConnectedComponentd(keys="pred_act_dis", applied_values=[1], output_postfix=None),
]
)
train_handlers = [
ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
StatsHandler(tag_name="train_loss", output_transform=lambda x: x[Keys.INFO][Keys.LOSS]),
@@ -128,9 +147,10 @@ def main():
optimizer=opt,
loss_function=loss,
inferer=SimpleInferer(),
train_handlers=train_handlers,
amp=False,
key_train_metric=None,
post_transform=train_post_transforms,
key_train_metric={"train_acc": Accuracy(output_transform=lambda x: (x["pred_act_dis"], x[Keys.LABEL]))},
train_handlers=train_handlers,
)
trainer.run()

16 changes: 15 additions & 1 deletion monai/engines/evaluator.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,8 @@ class Evaluator(Workflow):
prepare_batch (Callable): function to parse image and label for current iteration.
iteration_update (Callable): the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
post_transform (Transform): execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_val_metric (ignite.metric): compute metric when every iteration completed, and save average value to
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
checkpoint into files.
@@ -45,6 +47,7 @@ def __init__(
val_data_loader,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
post_transform=None,
key_val_metric: Optional[Metric] = None,
additional_metrics=None,
val_handlers=None,
@@ -56,6 +59,7 @@ def __init__(
data_loader=val_data_loader,
prepare_batch=prepare_batch,
iteration_update=iteration_update,
post_transform=post_transform,
key_metric=key_val_metric,
additional_metrics=additional_metrics,
handlers=val_handlers,
@@ -94,6 +98,8 @@ class SupervisedEvaluator(Evaluator):
iteration_update (Callable): the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
inferer (Inferer): inference method that execute model forward on input data, like: SlidingWindow, etc.
post_transform (Transform): execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_val_metric (ignite.metric): compute metric when every iteration completed, and save average value to
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
checkpoint into files.
@@ -111,12 +117,20 @@ def __init__(
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
inferer=SimpleInferer(),
post_transform=None,
key_val_metric=None,
additional_metrics=None,
val_handlers=None,
):
super().__init__(
device, val_data_loader, prepare_batch, iteration_update, key_val_metric, additional_metrics, val_handlers
device=device,
val_data_loader=val_data_loader,
prepare_batch=prepare_batch,
iteration_update=iteration_update,
post_transform=post_transform,
key_val_metric=key_val_metric,
additional_metrics=additional_metrics,
val_handlers=val_handlers,
)

self.network = network
6 changes: 5 additions & 1 deletion monai/engines/trainer.py
Original file line number Diff line number Diff line change
@@ -56,7 +56,9 @@ class SupervisedTrainer(Trainer):
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
lr_scheduler (LR Scheduler): the lr scheduler associated to the optimizer.
inferer (Inferer): inference method that execute model forward on input data, like: SlidingWindow, etc.
amp (bool): whether to enable auto-mixed-precision training.
amp (bool): whether to enable auto-mixed-precision training, reserved.
post_transform (Transform): execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_train_metric (ignite.metric): compute metric when every iteration completed, and save average value to
engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
checkpoint into files.
@@ -79,6 +81,7 @@ def __init__(
lr_scheduler=None,
inferer=SimpleInferer(),
amp: bool = True,
post_transform=None,
key_train_metric: Optional[Metric] = None,
additional_metrics=None,
train_handlers=None,
@@ -94,6 +97,7 @@ def __init__(
key_metric=key_train_metric,
additional_metrics=additional_metrics,
handlers=train_handlers,
post_transform=post_transform,
)

self.network = network
10 changes: 10 additions & 0 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
import torch
from ignite.engine import Engine, State, Events
from .utils import default_prepare_batch
from monai.transforms import apply_transform


class Workflow(ABC, Engine):
@@ -33,6 +34,8 @@ class Workflow(ABC, Engine):
prepare_batch (Callable): function to parse image and label for every iteration.
iteration_update (Callable): the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
post_transform (Transform): execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_metric (ignite.metric): compute metric when every iteration completed, and save average value to
engine.state.metrics when epoch completed. key_metric is the main metric to compare and save the
checkpoint into files.
@@ -50,6 +53,7 @@ def __init__(
data_loader,
prepare_batch=default_prepare_batch,
iteration_update=None,
post_transform=None,
key_metric=None,
additional_metrics=None,
handlers=None,
@@ -87,6 +91,12 @@ def __init__(
self.data_loader = data_loader
self.prepare_batch = prepare_batch

if post_transform is not None:

@self.on(Events.ITERATION_COMPLETED)
def run_post_transform(engine):
engine.state.output = apply_transform(post_transform, engine.state.output)

metrics = None
if key_metric is not None:

0 comments on commit 7cd6561

Please sign in to comment.