Skip to content

Commit 6df9a7f

Browse files
committed
Add utils for models exploration
1 parent 1147aa0 commit 6df9a7f

File tree

3 files changed

+124
-10
lines changed

3 files changed

+124
-10
lines changed

src/metrics.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,27 @@ def mean_iou(y_true, y_pred):
133133
y_pred = K.cast(K.greater(y_pred, t), dtype='float32')
134134
score = iou(y_true, y_pred)
135135
prec.append(score)
136+
137+
138+
def mean_dice(y_true, y_pred):
139+
"""Mean of dice score.
140+
Thresholds for mask are from 0.5, to 1.0 with a step 0.5.
141+
142+
Args:
143+
y_true (numpy.array/tf.tensor):
144+
b x X x Y( x Z...) x c One hot encoding of ground truth.
145+
y_pred (numpy.array/tf.tensor):
146+
b x X x Y( x Z...) x c Network output, must sum to 1 over c channel.
147+
Returns:
148+
(tf.tensor): Mean IoU coefficient.
149+
"""
150+
151+
prec = []
152+
for t in np.arange(0.5, 1.0, 0.05):
153+
y_pred = K.cast(K.greater(y_pred, t), dtype='float32')
154+
y_pred = K.constant(y_pred)
155+
y_true = K.constant(y_true)
156+
score = dice_coefficient(y_true, y_pred)
157+
prec.append(score)
158+
159+
return K.mean(K.stack(prec), axis=0)

src/utils.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,92 @@ def get_dataset_split(data_paths, train_ratio=0.7, seed=42, shuffle=True):
170170
test_data_paths = test_data_paths[n_test:]
171171

172172
return train_data_paths, test_data_paths, val_data_paths
173+
174+
175+
def stack_2D_2_3D(samples_sep, arr, dim, n_channels):
176+
stacked_pred = []
177+
prev_idx = 0
178+
for sep in samples_sep:
179+
pred = arr[prev_idx:prev_idx+sep]
180+
if pred.shape[0] > dim[0]:
181+
depth_shift = int((pred.shape[0] - dim[0]) // 2)+1
182+
pred = pred[depth_shift:-depth_shift,:,:,:]
183+
pred = np.moveaxis(pred, -1, 0)
184+
_, pred = pad(pred, pred, pred.shape[1:], dim, n_channels, n_channels)
185+
pred = np.moveaxis(pred, 0, -1)
186+
stacked_pred.append(pred)
187+
prev_idx += sep + 1
188+
189+
return np.array(stacked_pred)
190+
191+
192+
def load_sample(path, dim, scan_types, classes, merge_classes, mode="3D"):
193+
dim_before_axes_swap = (dim[-1], dim[1], dim[0])
194+
if mode=="3D":
195+
masks = preprocess_label(np.asanyarray(nib.load(path['seg']).dataobj), output_classes=classes, merge_classes=merge_classes)
196+
imgs = np.array([np.asanyarray(nib.load(path[m]).dataobj) for m in scan_types])
197+
elif mode=="2D":
198+
masks = preprocess_label(np.load(path['seg']), output_classes=classes, merge_classes=merge_classes)
199+
imgs = np.array([np.load(path[m]) for m in scan_types], dtype=np.float16)
200+
imgs = np.moveaxis(imgs, [0, 1, 2, 3], [0, 3, 2, 1])
201+
masks = np.moveaxis(masks, [0, 1, 2, 3], [0, 3, 2, 1])
202+
203+
imgs, masks = crop(imgs, masks, depth=dim[0])
204+
imgs, masks = pad(imgs, masks, masks.shape[1:], dim_before_axes_swap, n_channels, n_classes)
205+
206+
imgs = change_orientation(imgs)
207+
masks = change_orientation(masks)
208+
209+
return masks, imgs
210+
211+
212+
def evaluate(data_paths, prediction, metric, dim, scan_types, classes, merge_classes, mode="3D"):
213+
scores = {'class': [], 'score': []}
214+
for path, pred in zip(data_paths, prediction):
215+
if merge_classes:
216+
load = ['mask']
217+
else:
218+
load = classes
219+
for cls in load:
220+
if merge_classes:
221+
cls = classes
222+
cls_name = 'mask'
223+
else:
224+
cls_name = cls
225+
cls = [cls]
226+
mask, _ = load_sample(path=path, dim=dim, scan_types=scan_types, classes=cls, merge_classes=merge_classes, mode=mode)
227+
mask = np.array([mask])
228+
pred = np.array([pred])
229+
score = metric(mask, pred)
230+
scores['class'] = scores['class'] + [cls_name]
231+
scores['score'] = scores['score'] + [score.numpy()]
232+
return scores
233+
234+
235+
def get_data_paths4existing_slit(data_dir, splitted_data, mode="2D"):
236+
data_paths = []
237+
for modalities in splitted_data:
238+
curr_case = {}
239+
for name, path in modalities.items():
240+
path = path.split('/')[-2:]
241+
if mode=="2D":
242+
ext = '.npy'
243+
elif mode=="3D":
244+
ext = '.nii.gz'
245+
path[-1] = path[-1].split('.')[0] + ext
246+
curr_case[name] = os.path.join(data_dir, *path)
247+
data_paths.append(curr_case)
248+
return data_paths
249+
250+
251+
def get_2D_sep_slices(paths_unnpacked):
252+
samples_sep = []
253+
prev_slice_idx = 0
254+
for i, path in enumerate(paths_unnpacked):
255+
if path['seg'][1] == 0:
256+
samples_sep.append(prev_slice_idx)
257+
prev_slice_idx = path['seg'][1]
258+
259+
samples_sep.append(prev_slice_idx)
260+
samples_sep.pop(0)
261+
return samples_sep

src/visualization.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,20 @@
66
plt.style.use('seaborn-pastel')
77

88

9-
def animate_scan(scan, mask):
10-
fig = plt.figure(figsize=(16, 8))
11-
ax1 = fig.add_subplot(1,2,1)
12-
ax2 = fig.add_subplot(1,2,2)
9+
def animate_scan(images:list, titles:list, figsize=(16, 8)):
10+
fig = plt.figure(figsize=figsize)
11+
plots_num = len(titles)
12+
axis = [fig.add_subplot(1, plots_num, i) for i in range(1, plots_num+1)]
1313

1414
myimages = []
15-
for i in range(scan.shape[0]):
16-
ax1.axis('off')
17-
ax1.set_title('Scan', fontsize='medium')
18-
ax2.axis('off')
19-
ax2.set_title('Mask', fontsize='medium')
15+
for i in range(images[0].shape[0]):
16+
curr_slice = []
17+
for ax_idx, (img, title) in enumerate(zip(images, titles)):
18+
axis[ax_idx].axis('off')
19+
axis[ax_idx].set_title(title, fontsize=18)
20+
curr_slice.append(axis[ax_idx].imshow(img[i], cmap='Greys_r'))
2021

21-
myimages.append([ax1.imshow(scan[i], cmap='Greys_r'), ax2.imshow(mask[i], cmap='Greys_r')])
22+
myimages.append(curr_slice)
2223

2324
anim = animation.ArtistAnimation(fig, myimages, interval=1000, blit=True, repeat_delay=1000)
2425
return anim

0 commit comments

Comments
 (0)