forked from usuyama/pytorch-unet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
helper.py
46 lines (32 loc) · 1.43 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import matplotlib.pyplot as plt
import numpy as np
def plot_img_array(img_array, ncol=3):
    """Draw a sequence of images on a grid of subplots, ``ncol`` images per row.

    Args:
        img_array: Sequence of images (anything ``imshow`` accepts); only
            ``len()`` and integer indexing are used.
        ncol: Number of grid columns. Must be >= 1.

    Raises:
        ValueError: If ``img_array`` is empty or ``ncol`` is less than 1.
    """
    n_images = len(img_array)
    if ncol < 1:
        raise ValueError('ncol must be >= 1, got {}'.format(ncol))
    if n_images == 0:
        raise ValueError('img_array must contain at least one image')

    # Ceil division so a partially filled last row is still drawn
    # (floor division silently dropped those images and then over-indexed).
    nrow = (n_images + ncol - 1) // ncol
    # squeeze=False keeps `plots` 2-D even when nrow or ncol is 1, so the
    # [row, col] indexing below is always valid.
    f, plots = plt.subplots(nrow, ncol, sharex='all', sharey='all',
                            figsize=(ncol * 4, nrow * 4), squeeze=False)
    for i, img in enumerate(img_array):
        plots[i // ncol, i % ncol].imshow(img)
    # Blank out the unused trailing cells of an incomplete last row.
    for j in range(n_images, nrow * ncol):
        plots[j // ncol, j % ncol].axis('off')
from functools import reduce
def plot_side_by_side(img_arrays):
    """Plot several image sequences as aligned columns of one grid.

    Row ``k`` of the grid shows the ``k``-th image of every sequence, one
    sequence per column (e.g. ``[inputs, targets, predictions]``).

    Args:
        img_arrays: Non-empty sequence of image sequences. ``zip`` is used to
            pair them up, so surplus images in longer sequences are not shown.

    Raises:
        ValueError: If ``img_arrays`` is empty.
    """
    import itertools  # local import keeps this helper self-contained

    if len(img_arrays) == 0:
        raise ValueError('img_arrays must contain at least one image sequence')
    # Interleave row by row: [a0, b0, c0, a1, b1, c1, ...]. chain.from_iterable
    # is linear, unlike repeated tuple concatenation via reduce(+).
    flatten_list = list(itertools.chain.from_iterable(zip(*img_arrays)))
    # Pass the plain list: plot_img_array only needs len() and indexing, and
    # this avoids np.array() failing when images have different shapes.
    plot_img_array(flatten_list, ncol=len(img_arrays))
import itertools
def plot_errors(results_dict, title):
    """Plot one dice-coefficient curve per entry of ``results_dict`` and show it.

    Args:
        results_dict: Mapping from curve label to a sequence of per-epoch
            values; curves are drawn in sorted label order.
        title: Figure title (formatted with ``'{}'.format``).
    """
    # Cycle through a few marker shapes so overlapping curves stay distinguishable.
    markers = itertools.cycle(('+', 'x', 'o'))

    plt.title('{}'.format(title))

    for label, result in sorted(results_dict.items()):
        plt.plot(result, marker=next(markers), label=label)

    # Axis labels and the legend describe the whole figure: set them once after
    # all curves are drawn instead of rebuilding them on every iteration.
    plt.ylabel('dice_coef')
    plt.xlabel('epoch')
    if results_dict:
        # loc=3 ('lower left') anchored at (1, 0) puts the legend just outside
        # the right edge of the axes, aligned with its bottom.
        plt.legend(loc=3, bbox_to_anchor=(1, 0))
    plt.show()
def masks_to_colorimg(masks, colors=None, threshold=0.5):
    """Render a stack of per-class masks as a single color image.

    Each pixel takes the mean color of every class whose mask value is strictly
    greater than ``threshold`` at that pixel; pixels with no active class stay
    white (255 in every channel).

    Args:
        masks: Array-like of shape ``(channels, height, width)``, one mask per
            class.
        colors: Optional sequence of per-class colors, one row per channel. At
            least ``channels`` rows are required; extra rows are ignored.
            Defaults to the built-in six-color RGB palette.
        threshold: Activation threshold applied to ``masks``. Defaults to 0.5.

    Returns:
        ``numpy.uint8`` array of shape ``(height, width, depth)``, where
        ``depth`` is the length of each color (3 for the default palette).
        Fractional averages are truncated, not rounded.

    Raises:
        ValueError: If ``masks`` has more channels than there are colors.
    """
    if colors is None:
        colors = [(201, 58, 64), (242, 207, 1), (0, 152, 75),
                  (101, 172, 228), (56, 34, 132), (160, 194, 56)]
    masks = np.asarray(masks)
    channels, height, width = masks.shape
    palette = np.asarray(colors, dtype=np.float64)
    if len(palette) < channels:
        raise ValueError('masks has {} channels but only {} colors were '
                         'provided'.format(channels, len(palette)))
    palette = palette[:channels]

    active = masks > threshold          # (channels, height, width), bool
    n_active = active.sum(axis=0)       # (height, width): active classes per pixel
    # Sum the colors of all active classes per pixel in one vectorized step
    # instead of a Python loop over every pixel. Integer colors sum exactly in
    # float64, so the mean below matches np.mean bit-for-bit.
    color_sums = np.tensordot(active.astype(np.float64), palette,
                              axes=([0], [0]))  # (height, width, depth)

    colorimg = np.full((height, width, palette.shape[1]), 255, dtype=np.float32)
    covered = n_active > 0
    colorimg[covered] = color_sums[covered] / n_active[covered][:, np.newaxis]
    return colorimg.astype(np.uint8)