Skip to content

Commit

Permalink
393 add Keep Largest Connected Component transform (Project-MONAI#410)
Browse files Browse the repository at this point in the history
* 393 add Keep Largest Connected Component transform

* [MONAI] python code formatting

* 1. Use PyTorch tensor instead of numpy array

2. Move to post folder and add docs

3. Add more tests

* [MONAI] python code formatting

* Update according to review feedback

* [MONAI] python code formatting

* Fix tests with GPU unavailable

* Update based on reviews

Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
  • Loading branch information
YuanTingHsieh and monai-bot authored May 27, 2020
1 parent 1d73f65 commit eb10439
Show file tree
Hide file tree
Showing 6 changed files with 462 additions and 1 deletion.
10 changes: 10 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,11 @@ Vanilla Transforms
:members:
:special-members: __call__

`KeepLargestConnectedComponent`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: KeepLargestConnectedComponent
:members:
:special-members: __call__

Dictionary-based Transforms
---------------------------
Expand Down Expand Up @@ -579,6 +584,11 @@ Dictionary-based Transforms
:members:
:special-members: __call__

`KeepLargestConnectedComponentd`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: KeepLargestConnectedComponentd
:members:
:special-members: __call__

Transform Adaptors
------------------
Expand Down
94 changes: 94 additions & 0 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design
"""

import torch

from monai.transforms.compose import Transform
from monai.networks.utils import one_hot
from monai.transforms.utils import get_largest_connected_component_mask


class SplitChannel(Transform):
Expand Down Expand Up @@ -46,3 +49,94 @@ def __call__(self, img, to_onehot=None, num_classes=None):
outputs.append(img[:, i : i + 1])

return outputs


class KeepLargestConnectedComponent(Transform):
    """
    Keeps only the largest connected component in the image.
    This transform can be used as a post-processing step to clean up over-segmented areas in model output.
    The input is assumed to be a PyTorch Tensor with shape (batch_size, 1, spatial_dim1[, spatial_dim2, ...])
    Expected input data should have only 1 channel and the values correspond to expected labels.
    The input tensor itself is never modified; a new tensor is returned.

    For example:
    Use KeepLargestConnectedComponent with applied_values=[1], connectivity=1

        [1, 0, 0]         [0, 0, 0]
        [0, 1, 1]    =>   [0, 1 ,1]
        [0, 1, 1]         [0, 1, 1]

    Use KeepLargestConnectedComponent with applied_values=[1, 2], independent=False, connectivity=1

        [0, 0, 1, 0 ,0]           [0, 0, 1, 0 ,0]
        [0, 2, 1, 1 ,1]           [0, 2, 1, 1 ,1]
        [1, 2, 1, 0 ,0]    =>     [1, 2, 1, 0 ,0]
        [1, 2, 0, 1 ,0]           [1, 2, 0, 0 ,0]
        [2, 2, 0, 0 ,2]           [2, 2, 0, 0 ,0]

    Use KeepLargestConnectedComponent with applied_values=[1, 2], independent=True, connectivity=1

        [0, 0, 1, 0 ,0]           [0, 0, 1, 0 ,0]
        [0, 2, 1, 1 ,1]           [0, 2, 1, 1 ,1]
        [1, 2, 1, 0 ,0]    =>     [0, 2, 1, 0 ,0]
        [1, 2, 0, 1 ,0]           [0, 2, 0, 0 ,0]
        [2, 2, 0, 0 ,2]           [2, 2, 0, 0 ,0]

    Use KeepLargestConnectedComponent with applied_values=[1, 2], independent=False, connectivity=2

        [0, 0, 1, 0 ,0]           [0, 0, 1, 0 ,0]
        [0, 2, 1, 1 ,1]           [0, 2, 1, 1 ,1]
        [1, 2, 1, 0 ,0]    =>     [1, 2, 1, 0 ,0]
        [1, 2, 0, 1 ,0]           [1, 2, 0, 1 ,0]
        [2, 2, 0, 0 ,2]           [2, 2, 0, 0 ,2]
    """

    def __init__(self, applied_values, independent=True, background=0, connectivity=None):
        """
        Args:
            applied_values (list or tuple of int): number list for applying the connected component on.
                The pixel whose value is not in this list will remain unchanged.
            independent (bool): consider several labels as a whole or independent, default is `True`.
                Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case
                you want this "independent" to be specified as False.
            background (int): Background pixel value. The over-segmented pixels will be set as this value.
            connectivity (int): Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
                Accepted values are ranging from 1 to input.ndim. If ``None``, a full
                connectivity of ``input.ndim`` is used.

        Raises:
            ValueError: if ``background`` is one of ``applied_values``.
        """
        super().__init__()
        # The fill value must differ from every processed label, otherwise the
        # removed pixels would be indistinguishable from the kept ones.
        if background in applied_values:
            raise ValueError("Background pixel can't be in applied_values.")
        self.applied_values = applied_values
        self.independent = independent
        self.background = background
        self.connectivity = connectivity

    def __call__(self, img):
        """
        Args:
            img: shape must be (batch_size, 1, spatial_dim1[, spatial_dim2, ...]).

        Returns:
            A new PyTorch Tensor with shape (batch_size, 1, spatial_dim1[, spatial_dim2, ...]).

        Raises:
            ValueError: if the channel dimension of ``img`` is not of size 1.
        """
        channel_dim = 1
        if img.shape[channel_dim] != 1:
            raise ValueError("Input data have more than 1 channel.")

        # torch.squeeze returns a view that shares storage with ``img``; clone it so
        # the masked assignments below do not silently overwrite the caller's tensor.
        output = torch.squeeze(img, dim=channel_dim).clone()

        if self.independent:
            # Each label keeps its own largest component.
            for value in self.applied_values:
                foreground = (output == value).type(torch.uint8)
                mask = get_largest_connected_component_mask(foreground, self.connectivity)
                output[foreground != mask] = self.background
        else:
            # All labels are merged into one binary foreground. A bitwise OR (rather
            # than a sum) keeps the foreground strictly 0/1 even when
            # ``applied_values`` contains the same value more than once.
            foreground = torch.zeros_like(output, dtype=torch.uint8)
            for value in self.applied_values:
                foreground |= (output == value).type(torch.uint8)
            mask = get_largest_connected_component_mask(foreground, self.connectivity)
            output[foreground != mask] = self.background

        return torch.unsqueeze(output, dim=channel_dim)
41 changes: 40 additions & 1 deletion monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from monai.utils.misc import ensure_tuple_rep
from monai.transforms.compose import MapTransform
from monai.transforms.post.array import SplitChannel
from monai.transforms.post.array import SplitChannel, KeepLargestConnectedComponent


class SplitChanneld(MapTransform):
Expand Down Expand Up @@ -56,4 +56,43 @@ def __call__(self, data):
return d


class KeepLargestConnectedComponentd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.post.array.KeepLargestConnectedComponent`.

    For every configured key the converted data is stored under a new key built as
    ``{key}_{output_postfix}``; the entry under the original key stays in the dictionary.
    """

    def __init__(
        self, keys, applied_values, independent=True, background=0, connectivity=None, output_postfix="largestcc",
    ):
        """
        Args:
            keys (hashable items): keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            applied_values (list or tuple of int): number list for applying the connected component on.
                The pixel whose value is not in this list will remain unchanged.
            independent (bool): consider several labels as a whole or independent, default is `True`.
                Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case
                you want this "independent" to be specified as False.
            background (int): Background pixel value. The over-segmented pixels will be set as this value.
            connectivity (int): Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
                Accepted values are ranging from 1 to input.ndim. If ``None``, a full
                connectivity of ``input.ndim`` is used.
            output_postfix (str): the postfix string to construct keys to store converted data.
                for example: if the keys of input data is `label`, output_postfix is `largestcc`,
                the output data keys will be: `label_largestcc`.

        Raises:
            ValueError: if ``output_postfix`` is not a string.
        """
        super().__init__(keys)
        if not isinstance(output_postfix, str):
            raise ValueError("output_postfix must be a string.")
        self.output_postfix = output_postfix
        self.converter = KeepLargestConnectedComponent(applied_values, independent, background, connectivity)

    def __call__(self, data):
        """
        Args:
            data: mapping holding an item for every key in ``self.keys``.

        Returns:
            A shallow copy of ``data`` as a ``dict`` with one extra
            ``{key}_{output_postfix}`` entry per configured key.
        """
        d = dict(data)
        for key in self.keys:
            d[f"{key}_{self.output_postfix}"] = self.converter(d[key])
        return d


# Alternative spellings of the dictionary-based transforms: each class is also
# exposed under a ``...D`` and a ``...Dict`` suffix.
SplitChannelD = SplitChannelDict = SplitChanneld
KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd
21 changes: 21 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import random
import warnings

import torch
import numpy as np
from skimage import measure

from monai.utils.misc import ensure_tuple

Expand Down Expand Up @@ -395,3 +397,22 @@ def generate_spatial_bounding_box(img, select_fn=lambda x: x > 0, channel_indexe
box_start.append(max(0, np.min(nonzero_idx[i]) - margin))
box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin + 1))
return box_start, box_end


def get_largest_connected_component_mask(img, connectivity=None):
    """
    Gets the largest connected component mask of an image.

    Every item along the first (batch) dimension is labelled separately, so the
    largest component is selected per item rather than across the whole batch.

    Args:
        img: Image to get largest connected component from. Shape is (batch_size, spatial_dim1 [, spatial_dim2, ...])
        connectivity (int): Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
            Accepted values are ranging from 1 to input.ndim. If ``None``, a full
            connectivity of ``input.ndim`` is used.

    Returns:
        A tensor on the same device as ``img``, with the same shape and element type,
        holding 1 inside the largest component of each batch item and 0 elsewhere.
        Items without any labelled component stay all zeros.
    """
    batch = img.detach().cpu().numpy()
    masks = np.zeros(shape=batch.shape, dtype=batch.dtype)

    for index, sample in enumerate(batch):
        labelled = measure.label(sample, connectivity=connectivity)
        if labelled.max() == 0:
            # Nothing but background in this item: leave its mask empty.
            continue
        # Drop the count of label 0 (background) before picking the most frequent
        # label; the +1 restores the offset introduced by that slice.
        sizes = np.bincount(labelled.flat)
        biggest = np.argmax(sizes[1:]) + 1
        masks[index, ...] = labelled == biggest

    return torch.as_tensor(masks, device=img.device)
145 changes: 145 additions & 0 deletions tests/test_keep_largest_connected_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms import KeepLargestConnectedComponent

# Shared label maps of shape (1, 1, 5, 5): one batch item, one channel, 5x5 pixels.
grid_1 = torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]])
grid_2 = torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]])

# Every test case is a list of four items, in the order consumed by
# ``test_correct_results``: [case name, constructor kwargs, input tensor, expected output].
# NOTE: some case names repeat between the grid_1 and grid_2 cases
# (e.g. "value_1" in TEST_CASE_1 and TEST_CASE_5).

# grid_1, label 1 only, default connectivity.
TEST_CASE_1 = [
    "value_1",
    {"independent": False, "applied_values": [1]},
    grid_1,
    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]),
]

# grid_1, label 2 only, default connectivity.
TEST_CASE_2 = [
    "value_2",
    {"independent": False, "applied_values": [2]},
    grid_1,
    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]),
]

# grid_1, labels 1 and 2 processed separately.
TEST_CASE_3 = [
    "independent_value_1_2",
    {"independent": True, "applied_values": [1, 2]},
    grid_1,
    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]),
]

# grid_1, labels 1 and 2 processed as one merged foreground; the input is expected back unchanged.
TEST_CASE_4 = [
    "dependent_value_1_2",
    {"independent": False, "applied_values": [1, 2]},
    grid_1,
    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]),
]

# grid_2, label 1 only.
TEST_CASE_5 = [
    "value_1",
    {"independent": True, "applied_values": [1]},
    grid_2,
    torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]),
]

# grid_2, labels 1 and 2 processed separately.
TEST_CASE_6 = [
    "independent_value_1_2",
    {"independent": True, "applied_values": [1, 2]},
    grid_2,
    torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]),
]

# grid_2, labels 1 and 2 processed as one merged foreground.
TEST_CASE_7 = [
    "dependent_value_1_2",
    {"independent": False, "applied_values": [1, 2]},
    grid_2,
    torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]]),
]

# grid_1, label 1 only, connectivity restricted to 1.
TEST_CASE_8 = [
    "value_1_connect_1",
    {"independent": False, "applied_values": [1], "connectivity": 1},
    grid_1,
    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]]),
]

# grid_1, labels 1 and 2 processed separately, connectivity restricted to 1.
TEST_CASE_9 = [
    "independent_value_1_2_connect_1",
    {"independent": True, "applied_values": [1, 2], "connectivity": 1},
    grid_1,
    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]),
]

# grid_1, labels 1 and 2 merged, connectivity restricted to 1.
TEST_CASE_10 = [
    "dependent_value_1_2_connect_1",
    {"independent": False, "applied_values": [1, 2], "connectivity": 1},
    grid_1,
    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]),
]

# grid_1, value 0 is the processed label and 3 is used as the fill ("background") value.
TEST_CASE_11 = [
    "value_0_background_3",
    {"independent": False, "applied_values": [0], "background": 3},
    grid_1,
    torch.tensor([[[[3, 3, 1, 3, 3], [3, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]),
]

# Batch of two items; the first one is all zeros and is expected back unchanged.
TEST_CASE_12 = [
    "all_0_batch_2",
    {"independent": False, "applied_values": [1], "background": 3},
    torch.tensor(
        [
            [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]],
            [[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]]],
        ]
    ),
    torch.tensor(
        [
            [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]],
            [[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 3, 3, 3], [0, 0, 3, 0, 0], [0, 0, 0, 0, 0]]],
        ]
    ),
]

# All cases fed to ``parameterized.expand`` in the test class.
VALID_CASES = [
    TEST_CASE_1,
    TEST_CASE_2,
    TEST_CASE_3,
    TEST_CASE_4,
    TEST_CASE_5,
    TEST_CASE_6,
    TEST_CASE_7,
    TEST_CASE_8,
    TEST_CASE_9,
    TEST_CASE_10,
    TEST_CASE_11,
    TEST_CASE_12,
]


class TestKeepLargestConnectedComponent(unittest.TestCase):
    """Checks KeepLargestConnectedComponent against every entry of VALID_CASES."""

    @parameterized.expand(VALID_CASES)
    def test_correct_results(self, _, args, tensor, expected):
        """Run one case on the GPU when CUDA is available, otherwise on the CPU."""
        converter = KeepLargestConnectedComponent(**args)
        # Work on a copy so the module-level fixture tensors are never handed
        # to the transform directly.
        inputs = tensor.clone()
        if torch.cuda.is_available():
            inputs = inputs.cuda()
            expected = expected.cuda()
        result = converter(inputs)
        assert torch.allclose(result, expected)


# Run the test cases when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading

0 comments on commit eb10439

Please sign in to comment.