diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 5357834027..a5a65215f0 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -317,6 +317,11 @@ Vanilla Transforms
     :members:
     :special-members: __call__

+`KeepLargestConnectedComponent`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: KeepLargestConnectedComponent
+    :members:
+    :special-members: __call__

 Dictionary-based Transforms
 ---------------------------
@@ -579,6 +584,11 @@ Dictionary-based Transforms
     :members:
     :special-members: __call__

+`KeepLargestConnectedComponentd`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: KeepLargestConnectedComponentd
+    :members:
+    :special-members: __call__

 Transform Adaptors
 ------------------
diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index 92d171fdc9..ecf705fb1c 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -13,8 +13,11 @@
 https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design
 """

+import torch
+
 from monai.transforms.compose import Transform
 from monai.networks.utils import one_hot
+from monai.transforms.utils import get_largest_connected_component_mask


 class SplitChannel(Transform):
@@ -46,3 +49,94 @@ def __call__(self, img, to_onehot=None, num_classes=None):
             outputs.append(img[:, i : i + 1])

         return outputs
+
+
+class KeepLargestConnectedComponent(Transform):
+    """
+    Keeps only the largest connected component in the image.
+    This transform can be used as a post-processing step to clean up over-segmented areas in model output.
+    The input is assumed to be a PyTorch Tensor with shape (batch_size, 1, spatial_dim1[, spatial_dim2, ...]).
+
+    The input data should have only 1 channel and the values correspond to the expected labels.
+
+    For example:
+    Use KeepLargestConnectedComponent with applied_values=[1], connectivity=1
+
+       [1, 0, 0]    [0, 0, 0]
+       [0, 1, 1] => [0, 1, 1]
+       [0, 1, 1]    [0, 1, 1]
+
+    Use KeepLargestConnectedComponent with applied_values=[1, 2], independent=False, connectivity=1
+
+      [0, 0, 1, 0, 0]    [0, 0, 1, 0, 0]
+      [0, 2, 1, 1, 1]    [0, 2, 1, 1, 1]
+      [1, 2, 1, 0, 0] => [1, 2, 1, 0, 0]
+      [1, 2, 0, 1, 0]    [1, 2, 0, 0, 0]
+      [2, 2, 0, 0, 2]    [2, 2, 0, 0, 0]
+
+    Use KeepLargestConnectedComponent with applied_values=[1, 2], independent=True, connectivity=1
+
+      [0, 0, 1, 0, 0]    [0, 0, 1, 0, 0]
+      [0, 2, 1, 1, 1]    [0, 2, 1, 1, 1]
+      [1, 2, 1, 0, 0] => [0, 2, 1, 0, 0]
+      [1, 2, 0, 1, 0]    [0, 2, 0, 0, 0]
+      [2, 2, 0, 0, 2]    [2, 2, 0, 0, 0]
+
+    Use KeepLargestConnectedComponent with applied_values=[1, 2], independent=False, connectivity=2
+
+      [0, 0, 1, 0, 0]    [0, 0, 1, 0, 0]
+      [0, 2, 1, 1, 1]    [0, 2, 1, 1, 1]
+      [1, 2, 1, 0, 0] => [1, 2, 1, 0, 0]
+      [1, 2, 0, 1, 0]    [1, 2, 0, 1, 0]
+      [2, 2, 0, 0, 2]    [2, 2, 0, 0, 2]
+
+    """
+
+    def __init__(self, applied_values, independent=True, background=0, connectivity=None):
+        """
+        Args:
+            applied_values (list or tuple of int): labels on which to apply the connected component analysis.
+                Pixels whose value is not in this list remain unchanged.
+            independent (bool): whether to treat the labels in `applied_values` as independent objects (`True`,
+                the default) or as one combined foreground (`False`). For example, if label 1 is liver and
+                label 2 is liver tumor, the tumor is expected to stay attached to the liver, so use `independent=False`.
+            background (int): background pixel value. The over-segmented pixels will be set to this value.
+            connectivity (int): maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
+                Accepted values range from 1 to `input.ndim`. If ``None``, a full
+                connectivity of ``input.ndim`` is used.
+        """
+        super().__init__()
+        self.applied_values = applied_values
+        self.independent = independent
+        self.background = background
+        self.connectivity = connectivity
+        if background in applied_values:
+            raise ValueError("Background pixel can't be in applied_values.")
+
+    def __call__(self, img):
+        """
+        Args:
+            img: shape must be (batch_size, 1, spatial_dim1[, spatial_dim2, ...]).
+
+        Returns:
+            A PyTorch Tensor with shape (batch_size, 1, spatial_dim1[, spatial_dim2, ...]).
+        """
+        channel_dim = 1
+        if img.shape[channel_dim] == 1:
+            img = torch.squeeze(img, dim=channel_dim)
+        else:
+            raise ValueError("Input data has more than 1 channel.")
+
+        if self.independent:
+            for i in self.applied_values:
+                foreground = (img == i).type(torch.uint8)
+                mask = get_largest_connected_component_mask(foreground, self.connectivity)
+                img[foreground != mask] = self.background  # clear label i outside its largest component
+        else:
+            foreground = torch.zeros_like(img)
+            for i in self.applied_values:
+                foreground += (img == i).type(torch.uint8)  # merge all applied labels into one foreground
+            mask = get_largest_connected_component_mask(foreground, self.connectivity)
+            img[foreground != mask] = self.background
+
+        return torch.unsqueeze(img, dim=channel_dim)
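A quick way to sanity-check the new array transform is to reproduce the first docstring example. This is a sketch, assuming a checkout with this patch applied (`KeepLargestConnectedComponent` is re-exported from `monai.transforms`, which the unit tests below also rely on):

```python
import torch

from monai.transforms import KeepLargestConnectedComponent

# the isolated pixel at (0, 0) is not part of the largest connected
# component of label 1, so it is reset to the background value 0
img = torch.tensor([[[[1, 0, 0],
                      [0, 1, 1],
                      [0, 1, 1]]]])  # (batch, channel, H, W) = (1, 1, 3, 3)

keep = KeepLargestConnectedComponent(applied_values=[1], connectivity=1)
print(keep(img.clone())[0, 0])
# tensor([[0, 0, 0],
#         [0, 1, 1],
#         [0, 1, 1]])
```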
diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py
index 097e8b51a5..bec7dbb3b2 100644
--- a/monai/transforms/post/dictionary.py
+++ b/monai/transforms/post/dictionary.py
@@ -17,7 +17,7 @@

 from monai.utils.misc import ensure_tuple_rep
 from monai.transforms.compose import MapTransform
-from monai.transforms.post.array import SplitChannel
+from monai.transforms.post.array import SplitChannel, KeepLargestConnectedComponent


 class SplitChanneld(MapTransform):
@@ -56,4 +56,43 @@ def __call__(self, data):
         return d


+class KeepLargestConnectedComponentd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.post.array.KeepLargestConnectedComponent`.
+    """
+
+    def __init__(
+        self, keys, applied_values, independent=True, background=0, connectivity=None, output_postfix="largestcc",
+    ):
+        """
+        Args:
+            keys (hashable items): keys of the corresponding items to be transformed.
+                See also: :py:class:`monai.transforms.compose.MapTransform`
+            applied_values (list or tuple of int): labels on which to apply the connected component analysis.
+                Pixels whose value is not in this list remain unchanged.
+            independent (bool): whether to treat the labels in `applied_values` as independent objects (`True`,
+                the default) or as one combined foreground (`False`). For example, if label 1 is liver and
+                label 2 is liver tumor, the tumor is expected to stay attached to the liver, so use `independent=False`.
+            background (int): background pixel value. The over-segmented pixels will be set to this value.
+            connectivity (int): maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
+                Accepted values range from 1 to `input.ndim`. If ``None``, a full
+                connectivity of ``input.ndim`` is used.
+            output_postfix (str): the postfix string used to construct the keys that store the converted data.
+                For example: if the key of the input data is `label` and output_postfix is `largestcc`,
+                the output data key will be `label_largestcc`.
+        """
+        super().__init__(keys)
+        if not isinstance(output_postfix, str):
+            raise ValueError("output_postfix must be a string.")
+        self.output_postfix = output_postfix
+        self.converter = KeepLargestConnectedComponent(applied_values, independent, background, connectivity)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[f"{key}_{self.output_postfix}"] = self.converter(d[key])
+        return d
+
+
 SplitChannelD = SplitChannelDict = SplitChanneld
+KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd
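The dictionary wrapper can be sanity-checked the same way; the values below mirror the `independent_value_1_2_connect_1` test case further down. The key name `pred` is illustrative, and the patch must be applied:

```python
import torch

from monai.transforms import KeepLargestConnectedComponentd

data = {"pred": torch.tensor([[[[0, 0, 1, 0, 0],
                                [0, 2, 1, 1, 1],
                                [1, 2, 1, 0, 0],
                                [1, 2, 0, 1, 0],
                                [2, 2, 0, 0, 2]]]])}

keep = KeepLargestConnectedComponentd(keys=["pred"], applied_values=[1, 2], independent=True, connectivity=1)
out = keep(data)

# the cleaned map is stored under a key derived from output_postfix
print(out["pred_largestcc"][0, 0])
# tensor([[0, 0, 1, 0, 0],
#         [0, 2, 1, 1, 1],
#         [0, 2, 1, 0, 0],
#         [0, 2, 0, 0, 0],
#         [2, 2, 0, 0, 0]])
```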
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 09abdd70ad..9a45abb2b0 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -12,7 +12,9 @@

 import random
 import warnings

+import torch
 import numpy as np
+from skimage import measure

 from monai.utils.misc import ensure_tuple

@@ -395,3 +397,22 @@ def generate_spatial_bounding_box(img, select_fn=lambda x: x > 0, channel_indexe
         box_start.append(max(0, np.min(nonzero_idx[i]) - margin))
         box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin + 1))
     return box_start, box_end
+
+
+def get_largest_connected_component_mask(img, connectivity=None):
+    """
+    Gets the largest connected component mask of an image.
+
+    Args:
+        img: Image to get the largest connected component from. Shape is (batch_size, spatial_dim1[, spatial_dim2, ...]).
+        connectivity (int): maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
+            Accepted values range from 1 to `input.ndim`. If ``None``, a full
+            connectivity of ``input.ndim`` is used.
+    """
+    img_arr = img.detach().cpu().numpy()
+    largest_cc = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype)
+    for i, item in enumerate(img_arr):
+        item = measure.label(item, connectivity=connectivity)
+        if item.max() != 0:  # an all-background item keeps an all-zero mask
+            largest_cc[i, ...] = item == (np.argmax(np.bincount(item.flat)[1:]) + 1)
+    return torch.as_tensor(largest_cc, device=img.device)
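The `bincount` line is dense enough to deserve unpacking: `measure.label` gives every connected component a positive integer id, `np.bincount` counts pixels per id, the `[1:]` slice drops the background count, and `argmax(...) + 1` restores the offset to recover the id of the largest component (the `item.max() != 0` guard covers all-background batch items, exercised by `TEST_CASE_12` below). A standalone illustration of just that step, with made-up values independent of MONAI:

```python
import numpy as np
from skimage import measure

item = np.array([[1, 0, 0],
                 [0, 1, 1],
                 [0, 1, 1]])

labeled = measure.label(item, connectivity=1)  # [[1, 0, 0], [0, 2, 2], [0, 2, 2]]
counts = np.bincount(labeled.flat)             # [4, 1, 4]: background, comp 1, comp 2
largest = np.argmax(counts[1:]) + 1            # 2: skip background, restore the offset
print(labeled == largest)                      # mask of the 2x2 block only
```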
diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py
new file mode 100644
index 0000000000..3d2d1a2692
--- /dev/null
+++ b/tests/test_keep_largest_connected_component.py
@@ -0,0 +1,145 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import KeepLargestConnectedComponent
+
+grid_1 = torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]])
+grid_2 = torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]])
+
+TEST_CASE_1 = [
+    "value_1",
+    {"independent": False, "applied_values": [1]},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]),
+]
+
+TEST_CASE_2 = [
+    "value_2",
+    {"independent": False, "applied_values": [2]},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]),
+]
+
+TEST_CASE_3 = [
+    "independent_value_1_2",
+    {"independent": True, "applied_values": [1, 2]},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]),
+]
+
+TEST_CASE_4 = [
+    "dependent_value_1_2",
+    {"independent": False, "applied_values": [1, 2]},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]),
+]
+
+TEST_CASE_5 = [
+    "value_1",
+    {"independent": True, "applied_values": [1]},
+    grid_2,
+    torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]),
+]
+
+TEST_CASE_6 = [
+    "independent_value_1_2",
+    {"independent": True, "applied_values": [1, 2]},
+    grid_2,
+    torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]),
+]
+
+TEST_CASE_7 = [
+    "dependent_value_1_2",
+    {"independent": False, "applied_values": [1, 2]},
+    grid_2,
+    torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]]),
+]
+
+TEST_CASE_8 = [
+    "value_1_connect_1",
+    {"independent": False, "applied_values": [1], "connectivity": 1},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]]),
+]
+
+TEST_CASE_9 = [
+    "independent_value_1_2_connect_1",
+    {"independent": True, "applied_values": [1, 2], "connectivity": 1},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]),
+]
+
+TEST_CASE_10 = [
+    "dependent_value_1_2_connect_1",
+    {"independent": False, "applied_values": [1, 2], "connectivity": 1},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]),
+]
+
+TEST_CASE_11 = [
+    "value_0_background_3",
+    {"independent": False, "applied_values": [0], "background": 3},
+    grid_1,
+    torch.tensor([[[[3, 3, 1, 3, 3], [3, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]),
+]
+
+TEST_CASE_12 = [
+    "all_0_batch_2",
+    {"independent": False, "applied_values": [1], "background": 3},
+    torch.tensor(
+        [
+            [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]],
+            [[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]]],
+        ]
+    ),
+    torch.tensor(
+        [
+            [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]],
+            [[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 3, 3, 3], [0, 0, 3, 0, 0], [0, 0, 0, 0, 0]]],
+        ]
+    ),
+]
+
+VALID_CASES = [
+    TEST_CASE_1,
+    TEST_CASE_2,
+    TEST_CASE_3,
+    TEST_CASE_4,
+    TEST_CASE_5,
+    TEST_CASE_6,
+    TEST_CASE_7,
+    TEST_CASE_8,
+    TEST_CASE_9,
+    TEST_CASE_10,
+    TEST_CASE_11,
+    TEST_CASE_12,
+]
+
+
+class TestKeepLargestConnectedComponent(unittest.TestCase):
+    @parameterized.expand(VALID_CASES)
+    def test_correct_results(self, _, args, tensor, expected):
+        converter = KeepLargestConnectedComponent(**args)
+        if torch.cuda.is_available():
+            result = converter(tensor.clone().cuda())
+            assert torch.allclose(result, expected.cuda())
+        else:
+            result = converter(tensor.clone())
+            assert torch.allclose(result, expected)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_keep_largest_connected_componentd.py b/tests/test_keep_largest_connected_componentd.py
new file mode 100644
index 0000000000..32bc5ff201
--- /dev/null
+++ b/tests/test_keep_largest_connected_componentd.py
@@ -0,0 +1,154 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import KeepLargestConnectedComponentd
+
+grid_1 = {
+    "img": torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]])
+}
+grid_2 = {
+    "img": torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]])
+}
+
+TEST_CASE_1 = [
+    "value_1",
+    {"keys": ["img"], "independent": False, "applied_values": [1]},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]),
+]
+
+TEST_CASE_2 = [
+    "value_2",
+    {"keys": ["img"], "independent": False, "applied_values": [2]},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]),
+]
+
+TEST_CASE_3 = [
+    "independent_value_1_2",
+    {"keys": ["img"], "independent": True, "applied_values": [1, 2]},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]),
+]
+
+TEST_CASE_4 = [
+    "dependent_value_1_2",
+    {"keys": ["img"], "independent": False, "applied_values": [1, 2]},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]),
+]
+
+TEST_CASE_5 = [
+    "value_1",
+    {"keys": ["img"], "independent": True, "applied_values": [1]},
+    grid_2,
+    torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]),
+]
+
+TEST_CASE_6 = [
+    "independent_value_1_2",
+    {"keys": ["img"], "independent": True, "applied_values": [1, 2]},
+    grid_2,
+    torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]),
+]
+
+TEST_CASE_7 = [
+    "dependent_value_1_2",
+    {"keys": ["img"], "independent": False, "applied_values": [1, 2]},
+    grid_2,
+    torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]]),
+]
+
+TEST_CASE_8 = [
+    "value_1_connect_1",
+    {"keys": ["img"], "independent": False, "applied_values": [1], "connectivity": 1},
+    grid_1,
+    torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]]),
+]
+
"independent_value_1_2_connect_1", + {"keys": ["img"], "independent": True, "applied_values": [1, 2], "connectivity": 1}, + grid_1, + torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]), +] + +TEST_CASE_10 = [ + "dependent_value_1_2_connect_1", + {"keys": ["img"], "independent": False, "applied_values": [1, 2], "connectivity": 1}, + grid_1, + torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]), +] + +TEST_CASE_11 = [ + "value_0_background_3", + {"keys": ["img"], "independent": False, "applied_values": [0], "background": 3}, + grid_1, + torch.tensor([[[[3, 3, 1, 3, 3], [3, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]), +] + +TEST_CASE_12 = [ + "all_0_batch_2", + {"keys": ["img"], "independent": False, "applied_values": [1], "background": 3}, + { + "img": torch.tensor( + [ + [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], + [[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]]], + ] + ) + }, + torch.tensor( + [ + [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], + [[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 3, 3, 3], [0, 0, 3, 0, 0], [0, 0, 0, 0, 0]]], + ] + ), +] + +VALID_CASES = [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + TEST_CASE_11, + TEST_CASE_12, +] + + +class TestKeepLargestConnectedComponentd(unittest.TestCase): + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, args, input_dict, expected): + converter = KeepLargestConnectedComponentd(**args) + if torch.cuda.is_available(): + input_dict["img"] = input_dict["img"].cuda() + result = converter(input_dict) + torch.allclose(result["img_largestcc"], expected.cuda()) + else: + result = converter(input_dict) + torch.allclose(result["img_largestcc"], expected) + + +if __name__ == "__main__": + unittest.main()