Skip to content
This repository was archived by the owner on Sep 27, 2022. It is now read-only.

Commit 03015fa

Browse files
authored
Merge pull request #73 from Svdvoort/feature-image_output
feature-image_output
2 parents de8cc14 + cce7410 commit 03015fa

File tree

6 files changed

+429
-31
lines changed

6 files changed

+429
-31
lines changed

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@
1919
"python.venvFolders": [
2020
"${workspaceFolder}/.venv"
2121
],
22-
"python.pythonPath": ".venv/bin/python"
22+
"python.pythonPath": ".venv/bin/python",
23+
"restructuredtext.languageServer.disabled": true
2324
}

PrognosAIs/IO/Configs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(self, config_settings: dict):
6363

6464
self.mask_keyword = "mask"
6565
self.max_cpus = 999
66+
self.output_channel_names=[]
6667
super().__init__(general_config)
6768

6869

@@ -73,6 +74,7 @@ def __init__(self, config_settings: dict):
7374
self.perform_step_on_image = True
7475
self.perform_step_on_patch = False
7576
self.extract_masks = False
77+
self.apply_to_output = False
7678

7779
super().__init__(config_settings)
7880

@@ -83,6 +85,7 @@ def __init__(self, config_settings: dict):
8385
self.crop_to_mask = False
8486
self.background_value = 0.0
8587
self.process_masks = True
88+
self.apply_to_output = False
8689
self._mask_file = None
8790
self._mask = None
8891

@@ -116,6 +119,7 @@ def mask(self, mask_file: str):
116119
class resampling_config(config):
117120
def __init__(self, config_settings: dict):
118121
self.resample_size = [0, 0, 0]
122+
self.apply_to_output = False
119123
(
120124
self.perform_step_on_image,
121125
self.perform_step_on_patch,
@@ -134,6 +138,7 @@ def __init__(self, config_settings: dict):
134138
self.mask_normalization = None
135139
self.normalization_method = None
136140
self.mask_smoothing = False
141+
self.apply_to_output = False
137142

138143
(
139144
self.perform_step_on_image,
@@ -165,6 +170,7 @@ def mask(self, mask_file: str):
165170

166171
class bias_field_correcting_config(config):
167172
def __init__(self, config_settings: dict):
173+
self.apply_to_output = False
168174
self._mask_file = None
169175
self._mask = None
170176

@@ -208,6 +214,8 @@ def __init__(self, config_settings: dict):
208214
self.perform_step_on_image = True
209215
self.perform_step_on_patch = False
210216

217+
self.apply_to_output = False
218+
211219
super().__init__(config_settings)
212220

213221
@property
@@ -223,6 +231,7 @@ class rejecting_config(config):
223231
def __init__(self, config_settings: dict):
224232
self.rejection_limit = 0
225233
self.rejection_as_label = False
234+
self.apply_to_output = False
226235
self._mask_file = None
227236
self._mask = None
228237
(

PrognosAIs/Preprocessing/Preprocessors.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ def multi_dimension_extracting(self):
109109
if self.multi_dimension_extracting_config.extract_masks:
110110
self.sample.masks = (extraction_fuction, [max_dims])
111111

112+
if self.multi_dimension_extracting_config.apply_to_output:
113+
self.sample.output_channels = (extraction_fuction, [max_dims])
114+
112115
@staticmethod
113116
def _get_first_image_from_sequence(image: sitk.Image, max_dims: int) -> sitk.Image:
114117
"""
@@ -179,21 +182,29 @@ def masking(self):
179182
self.masking_config.mask,
180183
self.masking_config.background_value,
181184
self.masking_config.process_masks,
185+
self.masking_config.apply_to_output
182186
)
183187
if self.masking_config.crop_to_mask:
184-
self.crop_to_mask(self.masking_config.mask, self.masking_config.process_masks)
188+
self.crop_to_mask(self.masking_config.mask, self.masking_config.process_masks, self.masking_config.apply_to_output)
185189

186190
def mask_background(
187191
self, ROI_mask: sitk.Image, background_value: float = 0.0, process_masks: bool = True,
192+
apply_to_output: bool = False
188193
):
189194
mask_image_filter = sitk.MaskImageFilter()
190195

191196
mask_image_filter.SetMaskingValue(0)
192197
if background_value == "min":
193198
self.sample.channels = (self.mask_background_to_min, [ROI_mask])
199+
if apply_to_output:
200+
self.sample.output_channels = (self.mask_background_to_min, [ROI_mask])
201+
194202
else:
195203
mask_image_filter.SetOutsideValue(background_value)
196204
self.sample.channels = (mask_image_filter.Execute, [ROI_mask])
205+
if apply_to_output:
206+
self.sample.output_channels = (mask_image_filter.Execute, [ROI_mask])
207+
197208
if process_masks:
198209
# background_dtype = ImageSample.get_appropiate_dtype_from_scalar(background_value)
199210
# if background_dtype != self.sample.get_example_mask().GetPixelID():
@@ -215,7 +226,7 @@ def mask_background_to_min(image, mask):
215226
image = sitk.Mask(image, mask, img_min)
216227
return image
217228

218-
def crop_to_mask(self, ROI_mask: sitk.Image, process_masks: bool = True):
229+
def crop_to_mask(self, ROI_mask: sitk.Image, process_masks: bool = True, apply_to_output: bool = False):
219230
statics_image_filter = sitk.LabelShapeStatisticsImageFilter()
220231
statics_image_filter.Execute(ROI_mask)
221232

@@ -253,6 +264,12 @@ def resampling(self):
253264
[self.resampling_config.resample_size, mask_resampler],
254265
)
255266

267+
if self.resampling_config.apply_to_output:
268+
self.sample.output_channels = (
269+
self._resample,
270+
[self.resampling_config.resample_size, channel_resampler],
271+
)
272+
256273
@staticmethod
257274
def _resample(image, resample_size, resampler):
258275
original_size = np.asarray(image.GetSize())
@@ -287,6 +304,15 @@ def normalizing(self):
287304
self.normalizing_config.output_range,
288305
],
289306
)
307+
if self.normalizing_config.apply_to_output:
308+
self.sample.output_channels = (
309+
self._rescale_image_intensity_range,
310+
[
311+
self.normalizing_config.normalization_range,
312+
self.normalizing_config.output_range,
313+
],
314+
)
315+
290316
elif (
291317
self.normalizing_config.normalization_method == "range"
292318
and self.normalizing_config.mask is not None
@@ -299,11 +325,23 @@ def normalizing(self):
299325
self.normalizing_config.output_range,
300326
],
301327
)
328+
if self.normalizing_config.apply_to_output:
329+
self.sample.output_channels = (
330+
self._rescale_image_intensity_range_with_mask,
331+
[
332+
self.normalizing_config.mask,
333+
self.normalizing_config.normalization_range,
334+
self.normalizing_config.output_range,
335+
],
336+
)
337+
302338
elif (
303339
self.normalizing_config.normalization_method == "zscore"
304340
and self.normalizing_config.mask is None
305341
):
306342
self.sample.channels = self._zscore_image_intensity
343+
if self.normalizing_config.apply_to_output:
344+
self.sample.output_channels = self._zscore_image_intensity
307345
elif (
308346
self.normalizing_config.normalization_method == "zscore"
309347
and self.normalizing_config.mask is not None
@@ -312,6 +350,11 @@ def normalizing(self):
312350
self._zscore_image_intensity_with_mask,
313351
[self.normalizing_config.mask],
314352
)
353+
if self.normalizing_config.apply_to_output:
354+
self.sample.output_channels = (
355+
self._zscore_image_intensity_with_mask,
356+
[self.normalizing_config.mask],
357+
)
315358

316359
if self.normalizing_config.mask_normalization == "collapse":
317360
self.sample.masks = self._collapse_mask
@@ -480,6 +523,12 @@ def patching(self) -> None:
480523
[patch_parameters, 0, self.patching_config.patch_size],
481524
)
482525

526+
if self.patching_config.apply_to_output:
527+
self.sample.output_channels = (
528+
self._make_patches,
529+
[patch_parameters, self.patching_config.pad_constant, self.patching_config.patch_size],
530+
)
531+
483532
def _get_patch_parameters(self) -> dict:
484533
patch_parameters = {}
485534
patch_parameters["left_padding"] = np.zeros(self.sample.number_of_dimensions)
@@ -669,6 +718,8 @@ def rejecting(self):
669718
else:
670719
self.sample.channels = (self._get_accepted_patches, [rejection_status])
671720
self.sample.masks = (self._get_accepted_patches, [rejection_status])
721+
if self.rejecting_config.apply_to_output:
722+
self.sample.output_channels = (self._get_accepted_patches, [rejection_status])
672723

673724
return self.sample.number_of_patches > 0
674725

@@ -710,6 +761,8 @@ def bias_field_correcting(self):
710761
bias_field_corrector.SetUseMaskLabel(False)
711762
args = []
712763
self.sample.channels = (bias_field_corrector.Execute, args)
764+
if self.bias_field_correcting_config.apply_to_output:
765+
self.sample.output_channels = (bias_field_corrector.Execute, args)
713766

714767
# ===============================================================
715768
# Saving
@@ -735,7 +788,7 @@ def _convert_sitk_arrays_to_numpy(images: list):
735788
return np_array
736789

737790
def _patch_to_data_structure(
738-
self, patch_channels: list, patch_masks: list, patch_labels: list
791+
self, patch_channels: list, patch_output_channels: list, patch_masks: list, patch_labels: list
739792
) -> dict:
740793
N_channels = len(patch_channels)
741794
patch_channels = self._convert_sitk_arrays_to_numpy(patch_channels)
@@ -745,11 +798,19 @@ def _patch_to_data_structure(
745798
else:
746799
N_masks = 0
747800

801+
if patch_output_channels is not None:
802+
N_output_channels = len(patch_output_channels)
803+
patch_output_channels = self._convert_sitk_arrays_to_numpy(patch_output_channels)
804+
else:
805+
N_output_channels = 0
806+
748807
if self.saving_config.impute_missing_channels:
749808
patch_channels = self.channel_imputation(patch_channels)
750809

751810
if self.saving_config.save_as_float16:
752811
patch_channels = self.channels_to_float16(patch_channels)
812+
if N_output_channels > 0:
813+
patch_output_channels = self.channels_to_float16(patch_output_channels)
753814

754815
if self.saving_config.use_mask_as_channel and patch_masks is not None:
755816
patch_names = self.sample.channel_names + self.sample.mask_names
@@ -787,7 +848,7 @@ def _patch_to_data_structure(
787848
data_structure[self.saving_config.label_npz_keyword] = dict(
788849
zip(self.sample.mask_names, np.split(patch_masks, N_masks, axis=-1),)
789850
)
790-
elif len(patch_labels) > 0:
851+
elif len(patch_labels) > 0 or N_output_channels > 0:
791852
# If we have other labels as well we ensure that we
792853
# Give a name to the mask, as we have multi-outputs
793854
data_structure[self.saving_config.label_npz_keyword] = {
@@ -818,6 +879,19 @@ def _patch_to_data_structure(
818879
data_structure[self.saving_config.label_npz_keyword] = {
819880
self.saving_config.label_npz_keyword: patch_labels[label_keys[0]]
820881
}
882+
883+
if N_output_channels > 0:
884+
output_channel_structure = dict(
885+
zip(self.sample.output_channel_names, np.split(patch_output_channels, N_output_channels, axis=-1),)
886+
)
887+
if len(patch_labels) > 0 or self.saving_config.use_mask_as_label:
888+
# If we have other labels as well we ensure that we
889+
# Give a name to the mask, as we have multi-outputs
890+
data_structure[self.saving_config.label_npz_keyword].update(output_channel_structure)
891+
else:
892+
data_structure[self.saving_config.label_npz_keyword] = output_channel_structure
893+
894+
821895
return data_structure
822896

823897
def _get_number_of_classes(self, data_structure: dict):
@@ -932,6 +1006,11 @@ def saving(self):
9321006
# First convert to a dict for easy saving in npz format
9331007
sample_channels = self.sample.get_grouped_channels()
9341008

1009+
if self.sample.has_output_channels:
1010+
sample_output_channels = self.sample.get_grouped_output_channels()
1011+
else:
1012+
sample_output_channels = [None] * len(sample_channels)
1013+
9351014
if self.sample.has_masks:
9361015
sample_masks = self.sample.get_grouped_masks()
9371016
else:
@@ -941,11 +1020,11 @@ def saving(self):
9411020
sample_labels = self.sample.labels
9421021

9431022
data_structure_patches = []
944-
for i_channel_patch, i_mask_patch, i_label_patch in zip(
945-
sample_channels, sample_masks, sample_labels
1023+
for i_channel_patch, i_output_channel_patch, i_mask_patch, i_label_patch in zip(
1024+
sample_channels, sample_output_channels, sample_masks, sample_labels
9461025
):
9471026
data_structure_patches.append(
948-
self._patch_to_data_structure(i_channel_patch, i_mask_patch, i_label_patch)
1027+
self._patch_to_data_structure(i_channel_patch, i_output_channel_patch, i_mask_patch, i_label_patch)
9491028
)
9501029

9511030
if data_structure_patches:
@@ -1003,7 +1082,7 @@ def __init__(self, samples_path: str, output_directory: str, config: dict):
10031082

10041083
def _run_single_sample(self, sample_directory: str):
10051084
sample = self.sample_class(
1006-
root_path=sample_directory, mask_keyword=self.general_config.mask_keyword
1085+
root_path=sample_directory, mask_keyword=self.general_config.mask_keyword, output_channel_names=self.general_config.output_channel_names
10071086
)
10081087
print(sample.sample_name)
10091088

0 commit comments

Comments
 (0)