diff --git a/src/torchio/transforms/preprocessing/intensity/rescale.py b/src/torchio/transforms/preprocessing/intensity/rescale.py index 86e8d720..012cf22e 100644 --- a/src/torchio/transforms/preprocessing/intensity/rescale.py +++ b/src/torchio/transforms/preprocessing/intensity/rescale.py @@ -5,7 +5,7 @@ import torch from ....data.subject import Subject -from ....typing import TypeRangeFloat +from ....typing import TypeDoubleFloat from .normalization_transform import NormalizationTransform from .normalization_transform import TypeMaskingMethod @@ -45,10 +45,10 @@ class RescaleIntensity(NormalizationTransform): def __init__( self, - out_min_max: TypeRangeFloat = (0, 1), - percentiles: TypeRangeFloat = (0, 100), + out_min_max: TypeDoubleFloat = (0, 1), + percentiles: TypeDoubleFloat = (0, 100), masking_method: TypeMaskingMethod = None, - in_min_max: Optional[TypeRangeFloat] = None, + in_min_max: Optional[TypeDoubleFloat] = None, **kwargs, ): super().__init__(masking_method=masking_method, **kwargs) @@ -65,16 +65,11 @@ def __init__( max_constraint=100, ) - self.in_min: Optional[float] - self.in_max: Optional[float] if self.in_min_max is not None: - self.in_min, self.in_max = self._parse_range( + self.in_min_max = self._parse_range( self.in_min_max, 'in_min_max', ) - else: - self.in_min = None - self.in_max = None self.args_names = [ 'out_min_max', @@ -82,7 +77,6 @@ def __init__( 'masking_method', 'in_min_max', ] - self.invert_transform = False def apply_normalization( self, @@ -109,18 +103,16 @@ def rescale( ) warnings.warn(message, RuntimeWarning, stacklevel=2) return tensor + values = array[mask] cutoff = np.percentile(values, self.percentiles) np.clip(array, *cutoff, out=array) # type: ignore[call-overload] + if self.in_min_max is None: - self.in_min_max = self._parse_range( - (array.min(), array.max()), - 'in_min_max', - ) - self.in_min, self.in_max = self.in_min_max - assert self.in_min is not None - assert self.in_max is not None - in_range = self.in_max - self.in_min + in_min, in_max = array.min(), array.max() + else: + in_min, in_max = self.in_min_max + in_range = in_max - in_min if in_range == 0: # should this be compared using a tolerance? message = ( f'Rescaling image "{image_name}" not possible' @@ -128,15 +120,11 @@ def rescale( ) warnings.warn(message, RuntimeWarning, stacklevel=2) return tensor + out_range = self.out_max - self.out_min - if self.invert_transform: - array -= self.out_min - array /= out_range - array *= in_range - array += self.in_min - else: - array -= self.in_min - array /= in_range - array *= out_range - array += self.out_min + + array -= in_min + array /= in_range + array *= out_range + array += self.out_min return torch.as_tensor(array) diff --git a/src/torchio/typing.py b/src/torchio/typing.py index c1df6430..4a393db1 100644 --- a/src/torchio/typing.py +++ b/src/torchio/typing.py @@ -21,6 +21,7 @@ TypeQuartetInt = Tuple[int, int, int, int] TypeSextetInt = Tuple[int, int, int, int, int, int] +TypeDoubleFloat = Tuple[float, float] TypeTripletFloat = Tuple[float, float, float] TypeSextetFloat = Tuple[float, float, float, float, float, float] diff --git a/tests/transforms/preprocessing/test_rescale.py b/tests/transforms/preprocessing/test_rescale.py index 9fb1da87..b9d205a8 100644 --- a/tests/transforms/preprocessing/test_rescale.py +++ b/tests/transforms/preprocessing/test_rescale.py @@ -109,13 +109,17 @@ def test_empty_mask(self): with pytest.warns(RuntimeWarning): rescale(subject) - def test_invert_rescaling(self): - torch.manual_seed(0) - transform = tio.RescaleIntensity(out_min_max=(0, 1)) - data = torch.rand(1, 2, 3, 4).double() - subject = tio.Subject(t1=tio.ScalarImage(tensor=data)) - transformed = transform(subject) - assert transformed.t1.data.min() == 0 - assert transformed.t1.data.max() == 1 - inverted = transformed.apply_inverse_transform() - self.assert_tensor_almost_equal(inverted.t1.data, data) + def test_persistent_in_min_max(self): + # see https://github.com/fepegar/torchio/issues/1115 + img1 = torch.tensor([[[[0, 1]]]]) + img2 = torch.tensor([[[[0, 10]]]]) + + rescale = tio.RescaleIntensity(out_min_max=(0, 1)) + + assert rescale(img1).data.flatten().tolist() == [0, 1] + assert rescale(img2).data.flatten().tolist() == [0, 1] + + rescale = tio.RescaleIntensity(out_min_max=(0, 1)) + + assert rescale(img2).data.flatten().tolist() == [0, 1] + assert rescale(img1).data.flatten().tolist() == [0, 1]