@@ -109,6 +109,9 @@ def multi_dimension_extracting(self):
109
109
if self .multi_dimension_extracting_config .extract_masks :
110
110
self .sample .masks = (extraction_fuction , [max_dims ])
111
111
112
+ if self .multi_dimension_extracting_config .apply_to_output :
113
+ self .sample .output_channels = (extraction_fuction , [max_dims ])
114
+
112
115
@staticmethod
113
116
def _get_first_image_from_sequence (image : sitk .Image , max_dims : int ) -> sitk .Image :
114
117
"""
@@ -179,21 +182,29 @@ def masking(self):
179
182
self .masking_config .mask ,
180
183
self .masking_config .background_value ,
181
184
self .masking_config .process_masks ,
185
+ self .masking_config .apply_to_output
182
186
)
183
187
if self .masking_config .crop_to_mask :
184
- self .crop_to_mask (self .masking_config .mask , self .masking_config .process_masks )
188
+ self .crop_to_mask (self .masking_config .mask , self .masking_config .process_masks , self . masking_config . apply_to_output )
185
189
186
190
def mask_background (
187
191
self , ROI_mask : sitk .Image , background_value : float = 0.0 , process_masks : bool = True ,
192
+ apply_to_output : bool = False
188
193
):
189
194
mask_image_filter = sitk .MaskImageFilter ()
190
195
191
196
mask_image_filter .SetMaskingValue (0 )
192
197
if background_value == "min" :
193
198
self .sample .channels = (self .mask_background_to_min , [ROI_mask ])
199
+ if apply_to_output :
200
+ self .sample .output_channels = (self .mask_background_to_min , [ROI_mask ])
201
+
194
202
else :
195
203
mask_image_filter .SetOutsideValue (background_value )
196
204
self .sample .channels = (mask_image_filter .Execute , [ROI_mask ])
205
+ if apply_to_output :
206
+ self .sample .output_channels = (mask_image_filter .Execute , [ROI_mask ])
207
+
197
208
if process_masks :
198
209
# background_dtype = ImageSample.get_appropiate_dtype_from_scalar(background_value)
199
210
# if background_dtype != self.sample.get_example_mask().GetPixelID():
@@ -215,7 +226,7 @@ def mask_background_to_min(image, mask):
215
226
image = sitk .Mask (image , mask , img_min )
216
227
return image
217
228
218
- def crop_to_mask (self , ROI_mask : sitk .Image , process_masks : bool = True ):
229
+ def crop_to_mask (self , ROI_mask : sitk .Image , process_masks : bool = True , apply_to_output : bool = False ):
219
230
statics_image_filter = sitk .LabelShapeStatisticsImageFilter ()
220
231
statics_image_filter .Execute (ROI_mask )
221
232
@@ -253,6 +264,12 @@ def resampling(self):
253
264
[self .resampling_config .resample_size , mask_resampler ],
254
265
)
255
266
267
+ if self .resampling_config .apply_to_output :
268
+ self .sample .output_channels = (
269
+ self ._resample ,
270
+ [self .resampling_config .resample_size , channel_resampler ],
271
+ )
272
+
256
273
@staticmethod
257
274
def _resample (image , resample_size , resampler ):
258
275
original_size = np .asarray (image .GetSize ())
@@ -287,6 +304,15 @@ def normalizing(self):
287
304
self .normalizing_config .output_range ,
288
305
],
289
306
)
307
+ if self .normalizing_config .apply_to_output :
308
+ self .sample .output_channels = (
309
+ self ._rescale_image_intensity_range ,
310
+ [
311
+ self .normalizing_config .normalization_range ,
312
+ self .normalizing_config .output_range ,
313
+ ],
314
+ )
315
+
290
316
elif (
291
317
self .normalizing_config .normalization_method == "range"
292
318
and self .normalizing_config .mask is not None
@@ -299,11 +325,23 @@ def normalizing(self):
299
325
self .normalizing_config .output_range ,
300
326
],
301
327
)
328
+ if self .normalizing_config .apply_to_output :
329
+ self .sample .output_channels = (
330
+ self ._rescale_image_intensity_range_with_mask ,
331
+ [
332
+ self .normalizing_config .mask ,
333
+ self .normalizing_config .normalization_range ,
334
+ self .normalizing_config .output_range ,
335
+ ],
336
+ )
337
+
302
338
elif (
303
339
self .normalizing_config .normalization_method == "zscore"
304
340
and self .normalizing_config .mask is None
305
341
):
306
342
self .sample .channels = self ._zscore_image_intensity
343
+ if self .normalizing_config .apply_to_output :
344
+ self .sample .output_channels = self ._zscore_image_intensity
307
345
elif (
308
346
self .normalizing_config .normalization_method == "zscore"
309
347
and self .normalizing_config .mask is not None
@@ -312,6 +350,11 @@ def normalizing(self):
312
350
self ._zscore_image_intensity_with_mask ,
313
351
[self .normalizing_config .mask ],
314
352
)
353
+ if self .normalizing_config .apply_to_output :
354
+ self .sample .output_channels = (
355
+ self ._zscore_image_intensity_with_mask ,
356
+ [self .normalizing_config .mask ],
357
+ )
315
358
316
359
if self .normalizing_config .mask_normalization == "collapse" :
317
360
self .sample .masks = self ._collapse_mask
@@ -480,6 +523,12 @@ def patching(self) -> None:
480
523
[patch_parameters , 0 , self .patching_config .patch_size ],
481
524
)
482
525
526
+ if self .patching_config .apply_to_output :
527
+ self .sample .output_channels = (
528
+ self ._make_patches ,
529
+ [patch_parameters , self .patching_config .pad_constant , self .patching_config .patch_size ],
530
+ )
531
+
483
532
def _get_patch_parameters (self ) -> dict :
484
533
patch_parameters = {}
485
534
patch_parameters ["left_padding" ] = np .zeros (self .sample .number_of_dimensions )
@@ -669,6 +718,8 @@ def rejecting(self):
669
718
else :
670
719
self .sample .channels = (self ._get_accepted_patches , [rejection_status ])
671
720
self .sample .masks = (self ._get_accepted_patches , [rejection_status ])
721
+ if self .rejecting_config .apply_to_output :
722
+ self .sample .output_channels = (self ._get_accepted_patches , [rejection_status ])
672
723
673
724
return self .sample .number_of_patches > 0
674
725
@@ -710,6 +761,8 @@ def bias_field_correcting(self):
710
761
bias_field_corrector .SetUseMaskLabel (False )
711
762
args = []
712
763
self .sample .channels = (bias_field_corrector .Execute , args )
764
+ if self .bias_field_correcting_config .apply_to_output :
765
+ self .sample .output_channels = (bias_field_corrector .Execute , args )
713
766
714
767
# ===============================================================
715
768
# Saving
@@ -735,7 +788,7 @@ def _convert_sitk_arrays_to_numpy(images: list):
735
788
return np_array
736
789
737
790
def _patch_to_data_structure (
738
- self , patch_channels : list , patch_masks : list , patch_labels : list
791
+ self , patch_channels : list , patch_output_channels : list , patch_masks : list , patch_labels : list
739
792
) -> dict :
740
793
N_channels = len (patch_channels )
741
794
patch_channels = self ._convert_sitk_arrays_to_numpy (patch_channels )
@@ -745,11 +798,19 @@ def _patch_to_data_structure(
745
798
else :
746
799
N_masks = 0
747
800
801
+ if patch_output_channels is not None :
802
+ N_output_channels = len (patch_output_channels )
803
+ patch_output_channels = self ._convert_sitk_arrays_to_numpy (patch_output_channels )
804
+ else :
805
+ N_output_channels = 0
806
+
748
807
if self .saving_config .impute_missing_channels :
749
808
patch_channels = self .channel_imputation (patch_channels )
750
809
751
810
if self .saving_config .save_as_float16 :
752
811
patch_channels = self .channels_to_float16 (patch_channels )
812
+ if N_output_channels > 0 :
813
+ patch_output_channels = self .channels_to_float16 (patch_output_channels )
753
814
754
815
if self .saving_config .use_mask_as_channel and patch_masks is not None :
755
816
patch_names = self .sample .channel_names + self .sample .mask_names
@@ -787,7 +848,7 @@ def _patch_to_data_structure(
787
848
data_structure [self .saving_config .label_npz_keyword ] = dict (
788
849
zip (self .sample .mask_names , np .split (patch_masks , N_masks , axis = - 1 ),)
789
850
)
790
- elif len (patch_labels ) > 0 :
851
+ elif len (patch_labels ) > 0 or N_output_channels > 0 :
791
852
# If we have other labels as well we ensure that we
792
853
# Give a name to the mask, as we have multi-outputs
793
854
data_structure [self .saving_config .label_npz_keyword ] = {
@@ -818,6 +879,19 @@ def _patch_to_data_structure(
818
879
data_structure [self .saving_config .label_npz_keyword ] = {
819
880
self .saving_config .label_npz_keyword : patch_labels [label_keys [0 ]]
820
881
}
882
+
883
+ if N_output_channels > 0 :
884
+ output_channel_structure = dict (
885
+ zip (self .sample .output_channel_names , np .split (patch_output_channels , N_output_channels , axis = - 1 ),)
886
+ )
887
+ if len (patch_labels ) > 0 or self .saving_config .use_mask_as_label :
888
+ # If we have other labels as well we ensure that we
889
+ # Give a name to the mask, as we have multi-outputs
890
+ data_structure [self .saving_config .label_npz_keyword ].update (output_channel_structure )
891
+ else :
892
+ data_structure [self .saving_config .label_npz_keyword ] = output_channel_structure
893
+
894
+
821
895
return data_structure
822
896
823
897
def _get_number_of_classes (self , data_structure : dict ):
@@ -932,6 +1006,11 @@ def saving(self):
932
1006
# First convert to a dict for easy saving in npz format
933
1007
sample_channels = self .sample .get_grouped_channels ()
934
1008
1009
+ if self .sample .has_output_channels :
1010
+ sample_output_channels = self .sample .get_grouped_output_channels ()
1011
+ else :
1012
+ sample_output_channels = [None ] * len (sample_channels )
1013
+
935
1014
if self .sample .has_masks :
936
1015
sample_masks = self .sample .get_grouped_masks ()
937
1016
else :
@@ -941,11 +1020,11 @@ def saving(self):
941
1020
sample_labels = self .sample .labels
942
1021
943
1022
data_structure_patches = []
944
- for i_channel_patch , i_mask_patch , i_label_patch in zip (
945
- sample_channels , sample_masks , sample_labels
1023
+ for i_channel_patch , i_output_channel_patch , i_mask_patch , i_label_patch in zip (
1024
+ sample_channels , sample_output_channels , sample_masks , sample_labels
946
1025
):
947
1026
data_structure_patches .append (
948
- self ._patch_to_data_structure (i_channel_patch , i_mask_patch , i_label_patch )
1027
+ self ._patch_to_data_structure (i_channel_patch , i_output_channel_patch , i_mask_patch , i_label_patch )
949
1028
)
950
1029
951
1030
if data_structure_patches :
@@ -1003,7 +1082,7 @@ def __init__(self, samples_path: str, output_directory: str, config: dict):
1003
1082
1004
1083
def _run_single_sample (self , sample_directory : str ):
1005
1084
sample = self .sample_class (
1006
- root_path = sample_directory , mask_keyword = self .general_config .mask_keyword
1085
+ root_path = sample_directory , mask_keyword = self .general_config .mask_keyword , output_channel_names = self . general_config . output_channel_names
1007
1086
)
1008
1087
print (sample .sample_name )
1009
1088
0 commit comments