@@ -170,3 +170,92 @@ def get_dataset_split(data_paths, train_ratio=0.7, seed=42, shuffle=True):
170
170
test_data_paths = test_data_paths [n_test :]
171
171
172
172
return train_data_paths , test_data_paths , val_data_paths
173
+
174
+
175
+ def stack_2D_2_3D (samples_sep , arr , dim , n_channels ):
176
+ stacked_pred = []
177
+ prev_idx = 0
178
+ for sep in samples_sep :
179
+ pred = arr [prev_idx :prev_idx + sep ]
180
+ if pred .shape [0 ] > dim [0 ]:
181
+ depth_shift = int ((pred .shape [0 ] - dim [0 ]) // 2 )+ 1
182
+ pred = pred [depth_shift :- depth_shift ,:,:,:]
183
+ pred = np .moveaxis (pred , - 1 , 0 )
184
+ _ , pred = pad (pred , pred , pred .shape [1 :], dim , n_channels , n_channels )
185
+ pred = np .moveaxis (pred , 0 , - 1 )
186
+ stacked_pred .append (pred )
187
+ prev_idx += sep + 1
188
+
189
+ return np .array (stacked_pred )
190
+
191
+
192
+ def load_sample (path , dim , scan_types , classes , merge_classes , mode = "3D" ):
193
+ dim_before_axes_swap = (dim [- 1 ], dim [1 ], dim [0 ])
194
+ if mode == "3D" :
195
+ masks = preprocess_label (np .asanyarray (nib .load (path ['seg' ]).dataobj ), output_classes = classes , merge_classes = merge_classes )
196
+ imgs = np .array ([np .asanyarray (nib .load (path [m ]).dataobj ) for m in scan_types ])
197
+ elif mode == "2D" :
198
+ masks = preprocess_label (np .load (path ['seg' ]), output_classes = classes , merge_classes = merge_classes )
199
+ imgs = np .array ([np .load (path [m ]) for m in scan_types ], dtype = np .float16 )
200
+ imgs = np .moveaxis (imgs , [0 , 1 , 2 , 3 ], [0 , 3 , 2 , 1 ])
201
+ masks = np .moveaxis (masks , [0 , 1 , 2 , 3 ], [0 , 3 , 2 , 1 ])
202
+
203
+ imgs , masks = crop (imgs , masks , depth = dim [0 ])
204
+ imgs , masks = pad (imgs , masks , masks .shape [1 :], dim_before_axes_swap , n_channels , n_classes )
205
+
206
+ imgs = change_orientation (imgs )
207
+ masks = change_orientation (masks )
208
+
209
+ return masks , imgs
210
+
211
+
212
+ def evaluate (data_paths , prediction , metric , dim , scan_types , classes , merge_classes , mode = "3D" ):
213
+ scores = {'class' : [], 'score' : []}
214
+ for path , pred in zip (data_paths , prediction ):
215
+ if merge_classes :
216
+ load = ['mask' ]
217
+ else :
218
+ load = classes
219
+ for cls in load :
220
+ if merge_classes :
221
+ cls = classes
222
+ cls_name = 'mask'
223
+ else :
224
+ cls_name = cls
225
+ cls = [cls ]
226
+ mask , _ = load_sample (path = path , dim = dim , scan_types = scan_types , classes = cls , merge_classes = merge_classes , mode = mode )
227
+ mask = np .array ([mask ])
228
+ pred = np .array ([pred ])
229
+ score = metric (mask , pred )
230
+ scores ['class' ] = scores ['class' ] + [cls_name ]
231
+ scores ['score' ] = scores ['score' ] + [score .numpy ()]
232
+ return scores
233
+
234
+
235
+ def get_data_paths4existing_slit (data_dir , splitted_data , mode = "2D" ):
236
+ data_paths = []
237
+ for modalities in splitted_data :
238
+ curr_case = {}
239
+ for name , path in modalities .items ():
240
+ path = path .split ('/' )[- 2 :]
241
+ if mode == "2D" :
242
+ ext = '.npy'
243
+ elif mode == "3D" :
244
+ ext = '.nii.gz'
245
+ path [- 1 ] = path [- 1 ].split ('.' )[0 ] + ext
246
+ curr_case [name ] = os .path .join (data_dir , * path )
247
+ data_paths .append (curr_case )
248
+ return data_paths
249
+
250
+
251
+ def get_2D_sep_slices (paths_unnpacked ):
252
+ samples_sep = []
253
+ prev_slice_idx = 0
254
+ for i , path in enumerate (paths_unnpacked ):
255
+ if path ['seg' ][1 ] == 0 :
256
+ samples_sep .append (prev_slice_idx )
257
+ prev_slice_idx = path ['seg' ][1 ]
258
+
259
+ samples_sep .append (prev_slice_idx )
260
+ samples_sep .pop (0 )
261
+ return samples_sep
0 commit comments