From 82ee0385e444861e8ec6b1a245dcd83d35d71c6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Wed, 1 Nov 2023 22:47:13 +0000 Subject: [PATCH] Make RescaleIntensity non-invertible again --- .../preprocessing/intensity/rescale.py | 27 ++++++++----------- .../transforms/preprocessing/test_rescale.py | 11 -------- 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/src/torchio/transforms/preprocessing/intensity/rescale.py b/src/torchio/transforms/preprocessing/intensity/rescale.py index 86e8d720..de121aef 100644 --- a/src/torchio/transforms/preprocessing/intensity/rescale.py +++ b/src/torchio/transforms/preprocessing/intensity/rescale.py @@ -65,16 +65,11 @@ def __init__( max_constraint=100, ) - self.in_min: Optional[float] - self.in_max: Optional[float] if self.in_min_max is not None: - self.in_min, self.in_max = self._parse_range( + self.in_min_max = self._parse_range( self.in_min_max, 'in_min_max', ) - else: - self.in_min = None - self.in_max = None self.args_names = [ 'out_min_max', @@ -109,18 +104,16 @@ def rescale( ) warnings.warn(message, RuntimeWarning, stacklevel=2) return tensor + values = array[mask] cutoff = np.percentile(values, self.percentiles) np.clip(array, *cutoff, out=array) # type: ignore[call-overload] + if self.in_min_max is None: - self.in_min_max = self._parse_range( - (array.min(), array.max()), - 'in_min_max', - ) - self.in_min, self.in_max = self.in_min_max - assert self.in_min is not None - assert self.in_max is not None - in_range = self.in_max - self.in_min + in_min, in_max = array.min(), array.max() + else: + in_min, in_max = self.in_min_max + in_range = in_max - in_min if in_range == 0: # should this be compared using a tolerance? message = ( f'Rescaling image "{image_name}" not possible' @@ -128,14 +121,16 @@ def rescale( ) warnings.warn(message, RuntimeWarning, stacklevel=2) return tensor + out_range = self.out_max - self.out_min + if self.invert_transform: array -= self.out_min array /= out_range array *= in_range - array += self.in_min + array += in_min else: - array -= self.in_min + array -= in_min array /= in_range array *= out_range array += self.out_min diff --git a/tests/transforms/preprocessing/test_rescale.py b/tests/transforms/preprocessing/test_rescale.py index 8cb00485..b9d205a8 100644 --- a/tests/transforms/preprocessing/test_rescale.py +++ b/tests/transforms/preprocessing/test_rescale.py @@ -109,17 +109,6 @@ def test_empty_mask(self): with pytest.warns(RuntimeWarning): rescale(subject) - def test_invert_rescaling(self): - torch.manual_seed(0) - transform = tio.RescaleIntensity(out_min_max=(0, 1)) - data = torch.rand(1, 2, 3, 4).double() - subject = tio.Subject(t1=tio.ScalarImage(tensor=data)) - transformed = transform(subject) - assert transformed.t1.data.min() == 0 - assert transformed.t1.data.max() == 1 - inverted = transformed.apply_inverse_transform() - self.assert_tensor_almost_equal(inverted.t1.data, data) - def test_persistent_in_min_max(self): # see https://github.com/fepegar/torchio/issues/1115 img1 = torch.tensor([[[[0, 1]]]])