5
5
import torch
6
6
7
7
from ....data .subject import Subject
8
- from ....typing import TypeRangeFloat
8
+ from ....typing import TypeDoubleFloat
9
9
from .normalization_transform import NormalizationTransform
10
10
from .normalization_transform import TypeMaskingMethod
11
11
@@ -45,10 +45,10 @@ class RescaleIntensity(NormalizationTransform):
45
45
46
46
def __init__ (
47
47
self ,
48
- out_min_max : TypeRangeFloat = (0 , 1 ),
49
- percentiles : TypeRangeFloat = (0 , 100 ),
48
+ out_min_max : TypeDoubleFloat = (0 , 1 ),
49
+ percentiles : TypeDoubleFloat = (0 , 100 ),
50
50
masking_method : TypeMaskingMethod = None ,
51
- in_min_max : Optional [TypeRangeFloat ] = None ,
51
+ in_min_max : Optional [TypeDoubleFloat ] = None ,
52
52
** kwargs ,
53
53
):
54
54
super ().__init__ (masking_method = masking_method , ** kwargs )
@@ -65,24 +65,18 @@ def __init__(
65
65
max_constraint = 100 ,
66
66
)
67
67
68
- self .in_min : Optional [float ]
69
- self .in_max : Optional [float ]
70
68
if self .in_min_max is not None :
71
- self .in_min , self . in_max = self ._parse_range (
69
+ self .in_min_max = self ._parse_range (
72
70
self .in_min_max ,
73
71
'in_min_max' ,
74
72
)
75
- else :
76
- self .in_min = None
77
- self .in_max = None
78
73
79
74
self .args_names = [
80
75
'out_min_max' ,
81
76
'percentiles' ,
82
77
'masking_method' ,
83
78
'in_min_max' ,
84
79
]
85
- self .invert_transform = False
86
80
87
81
def apply_normalization (
88
82
self ,
@@ -109,34 +103,28 @@ def rescale(
109
103
)
110
104
warnings .warn (message , RuntimeWarning , stacklevel = 2 )
111
105
return tensor
106
+
112
107
values = array [mask ]
113
108
cutoff = np .percentile (values , self .percentiles )
114
109
np .clip (array , * cutoff , out = array ) # type: ignore[call-overload]
110
+
115
111
if self .in_min_max is None :
116
- self .in_min_max = self ._parse_range (
117
- (array .min (), array .max ()),
118
- 'in_min_max' ,
119
- )
120
- self .in_min , self .in_max = self .in_min_max
121
- assert self .in_min is not None
122
- assert self .in_max is not None
123
- in_range = self .in_max - self .in_min
112
+ in_min , in_max = array .min (), array .max ()
113
+ else :
114
+ in_min , in_max = self .in_min_max
115
+ in_range = in_max - in_min
124
116
if in_range == 0 : # should this be compared using a tolerance?
125
117
message = (
126
118
f'Rescaling image "{ image_name } " not possible'
127
119
' because all the intensity values are the same'
128
120
)
129
121
warnings .warn (message , RuntimeWarning , stacklevel = 2 )
130
122
return tensor
123
+
131
124
out_range = self .out_max - self .out_min
132
- if self .invert_transform :
133
- array -= self .out_min
134
- array /= out_range
135
- array *= in_range
136
- array += self .in_min
137
- else :
138
- array -= self .in_min
139
- array /= in_range
140
- array *= out_range
141
- array += self .out_min
125
+
126
+ array -= in_min
127
+ array /= in_range
128
+ array *= out_range
129
+ array += self .out_min
142
130
return torch .as_tensor (array )
0 commit comments