Skip to content

Commit d1bb483

Browse files
committed
Add visualization.py
1 parent e005d47 commit d1bb483

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

src/visualization.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
from matplotlib import animation
4+
5+
6+
def animate_scan(scan, mask):
7+
fig = plt.figure(figsize=(16, 8))
8+
ax1 = fig.add_subplot(1,2,1)
9+
ax2 = fig.add_subplot(1,2,2)
10+
11+
myimages = []
12+
for i in range(scan.shape[0]):
13+
ax1.axis('off')
14+
ax1.set_title('Scan', fontsize='medium')
15+
ax2.axis('off')
16+
ax2.set_title('Mask', fontsize='medium')
17+
18+
myimages.append([ax1.imshow(scan[i], cmap='Greys_r'), ax2.imshow(mask[i], cmap='Greys_r')])
19+
20+
anim = animation.ArtistAnimation(fig, myimages, interval=1000, blit=True, repeat_delay=1000)
21+
return anim
22+
23+
24+
def show_class_frequency(classes_freq, classes2show, pixel2class):
25+
fig, ax = plt.subplots(figsize=(16,8))
26+
height = [v for k, v in classes_freq.items() if pixel2class[k] in classes2show]
27+
bars = [pixel2class[k] for k in classes_freq.keys() if pixel2class[k] in classes2show]
28+
plt.bar(bars, height, color=[(0.1, 0.1, 0.1, 0.1) for _ in range(len(bars))], edgecolor='blue')
29+
plt.title('Class frequency', fontsize=25)
30+
plt.xlabel('classes', fontsize=20)
31+
plt.ylabel('counts', fontsize=20)
32+
for i, v in enumerate(height):
33+
ax.text(i-len(height)*0.05, v, str(v), fontweight='bold', fontsize=15)
34+
plt.show()
35+
36+
37+
# Source https://gist.github.com/soply/f3eec2e79c165e39c9d540e916142ae1
38+
def show_images(images, cols = 1, scale=4, titles = None):
39+
"""Display a list of images in a single figure with matplotlib.
40+
41+
Parameters
42+
---------
43+
images: List of np.arrays compatible with plt.imshow.
44+
45+
cols (Default = 1): Number of columns in figure (number of rows is
46+
set to np.ceil(n_images/float(cols))).
47+
48+
titles: List of titles corresponding to each image. Must have
49+
the same length as titles.
50+
"""
51+
assert((titles is None)or (len(images) == len(titles)))
52+
n_images = len(images)
53+
if titles is None: titles = ['Image (%d)' % i for i in range(1,n_images + 1)]
54+
fig = plt.figure()
55+
fig.tight_layout()
56+
for n, (image, title) in enumerate(zip(images, titles)):
57+
a = fig.add_subplot(np.ceil(n_images/float(cols)), cols, n + 1)
58+
a.axis('off')
59+
if image.ndim == 2:
60+
plt.gray()
61+
plt.imshow(image)
62+
#a.set_title(title)
63+
fig.set_size_inches(np.array(fig.get_size_inches()) * n_images / scale)
64+
plt.show()

0 commit comments

Comments
 (0)