Skip to content

Commit dfd52bb

Browse files
authored
Revert invertibility of RescaleIntensity (#1116)
* Add failing test * Revert invertibility of `RescaleIntensity` (#1120) * Make RescaleIntensity non-invertible again * Disable support to invert RescaleIntensity
1 parent 80d71e0 commit dfd52bb

File tree

3 files changed

+32
-39
lines changed

3 files changed

+32
-39
lines changed

src/torchio/transforms/preprocessing/intensity/rescale.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66

77
from ....data.subject import Subject
8-
from ....typing import TypeRangeFloat
8+
from ....typing import TypeDoubleFloat
99
from .normalization_transform import NormalizationTransform
1010
from .normalization_transform import TypeMaskingMethod
1111

@@ -45,10 +45,10 @@ class RescaleIntensity(NormalizationTransform):
4545

4646
def __init__(
4747
self,
48-
out_min_max: TypeRangeFloat = (0, 1),
49-
percentiles: TypeRangeFloat = (0, 100),
48+
out_min_max: TypeDoubleFloat = (0, 1),
49+
percentiles: TypeDoubleFloat = (0, 100),
5050
masking_method: TypeMaskingMethod = None,
51-
in_min_max: Optional[TypeRangeFloat] = None,
51+
in_min_max: Optional[TypeDoubleFloat] = None,
5252
**kwargs,
5353
):
5454
super().__init__(masking_method=masking_method, **kwargs)
@@ -65,24 +65,18 @@ def __init__(
6565
max_constraint=100,
6666
)
6767

68-
self.in_min: Optional[float]
69-
self.in_max: Optional[float]
7068
if self.in_min_max is not None:
71-
self.in_min, self.in_max = self._parse_range(
69+
self.in_min_max = self._parse_range(
7270
self.in_min_max,
7371
'in_min_max',
7472
)
75-
else:
76-
self.in_min = None
77-
self.in_max = None
7873

7974
self.args_names = [
8075
'out_min_max',
8176
'percentiles',
8277
'masking_method',
8378
'in_min_max',
8479
]
85-
self.invert_transform = False
8680

8781
def apply_normalization(
8882
self,
@@ -109,34 +103,28 @@ def rescale(
109103
)
110104
warnings.warn(message, RuntimeWarning, stacklevel=2)
111105
return tensor
106+
112107
values = array[mask]
113108
cutoff = np.percentile(values, self.percentiles)
114109
np.clip(array, *cutoff, out=array) # type: ignore[call-overload]
110+
115111
if self.in_min_max is None:
116-
self.in_min_max = self._parse_range(
117-
(array.min(), array.max()),
118-
'in_min_max',
119-
)
120-
self.in_min, self.in_max = self.in_min_max
121-
assert self.in_min is not None
122-
assert self.in_max is not None
123-
in_range = self.in_max - self.in_min
112+
in_min, in_max = array.min(), array.max()
113+
else:
114+
in_min, in_max = self.in_min_max
115+
in_range = in_max - in_min
124116
if in_range == 0: # should this be compared using a tolerance?
125117
message = (
126118
f'Rescaling image "{image_name}" not possible'
127119
' because all the intensity values are the same'
128120
)
129121
warnings.warn(message, RuntimeWarning, stacklevel=2)
130122
return tensor
123+
131124
out_range = self.out_max - self.out_min
132-
if self.invert_transform:
133-
array -= self.out_min
134-
array /= out_range
135-
array *= in_range
136-
array += self.in_min
137-
else:
138-
array -= self.in_min
139-
array /= in_range
140-
array *= out_range
141-
array += self.out_min
125+
126+
array -= in_min
127+
array /= in_range
128+
array *= out_range
129+
array += self.out_min
142130
return torch.as_tensor(array)

src/torchio/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TypeQuartetInt = Tuple[int, int, int, int]
2222
TypeSextetInt = Tuple[int, int, int, int, int, int]
2323

24+
TypeDoubleFloat = Tuple[float, float]
2425
TypeTripletFloat = Tuple[float, float, float]
2526
TypeSextetFloat = Tuple[float, float, float, float, float, float]
2627

tests/transforms/preprocessing/test_rescale.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,17 @@ def test_empty_mask(self):
109109
with pytest.warns(RuntimeWarning):
110110
rescale(subject)
111111

112-
def test_invert_rescaling(self):
113-
torch.manual_seed(0)
114-
transform = tio.RescaleIntensity(out_min_max=(0, 1))
115-
data = torch.rand(1, 2, 3, 4).double()
116-
subject = tio.Subject(t1=tio.ScalarImage(tensor=data))
117-
transformed = transform(subject)
118-
assert transformed.t1.data.min() == 0
119-
assert transformed.t1.data.max() == 1
120-
inverted = transformed.apply_inverse_transform()
121-
self.assert_tensor_almost_equal(inverted.t1.data, data)
112+
def test_persistent_in_min_max(self):
113+
# see https://github.com/fepegar/torchio/issues/1115
114+
img1 = torch.tensor([[[[0, 1]]]])
115+
img2 = torch.tensor([[[[0, 10]]]])
116+
117+
rescale = tio.RescaleIntensity(out_min_max=(0, 1))
118+
119+
assert rescale(img1).data.flatten().tolist() == [0, 1]
120+
assert rescale(img2).data.flatten().tolist() == [0, 1]
121+
122+
rescale = tio.RescaleIntensity(out_min_max=(0, 1))
123+
124+
assert rescale(img2).data.flatten().tolist() == [0, 1]
125+
assert rescale(img1).data.flatten().tolist() == [0, 1]

0 commit comments

Comments
 (0)