-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathplot_information.py
32 lines (24 loc) · 1.05 KB
/
plot_information.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Global matplotlib setting: every figure created after this module is
# imported uses an 18 pt default font.
matplotlib.rcParams['font.size'] = 18
def plot_information_plane(IXT_array, ITY_array, num_epochs, every_n):
    """Plot I(T;Y) against I(X;T), one curve per row, coloured by epoch.

    Row ``i`` of the two arrays is drawn as a single curve whose colour is
    taken from the 'gnuplot' colormap at epoch ``i * every_n``. A tick-less
    colorbar annotated with ``0`` and ``num_epochs`` is attached and the
    figure is displayed with ``plt.show()``.

    Args:
        IXT_array: 2-D array-like of x values, I(X;T). Presumably one row per
            recorded epoch and one column per layer -- confirm with caller.
        ITY_array: 2-D array-like of y values, I(T;Y), with the same number
            of rows as ``IXT_array``.
        num_epochs: Total number of epochs; the colormap is sampled at
            ``num_epochs + 1`` evenly spaced points.
        every_n: Epoch stride between consecutive rows, used to map a row
            index onto an epoch colour.

    Raises:
        ValueError: If the two arrays have a different number of rows.
    """
    IXT_array = np.asarray(IXT_array)
    ITY_array = np.asarray(ITY_array)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if len(IXT_array) != len(ITY_array):
        raise ValueError(
            'IXT_array and ITY_array must have the same number of rows, '
            'got %d and %d' % (len(IXT_array), len(ITY_array))
        )

    plt.figure(figsize=(15, 9))
    plt.xlabel('$I(X;T)$')
    plt.ylabel('$I(T;Y)$')

    cmap = plt.get_cmap('gnuplot')
    # One colour per epoch in 0..num_epochs inclusive.
    colors = [cmap(frac) for frac in np.linspace(0, 1, num_epochs + 1)]

    for i, (IXT, ITY) in enumerate(zip(IXT_array, ITY_array)):
        # Clamp so that rows recorded past ``num_epochs`` reuse the final
        # colour instead of raising IndexError.
        color_index = min(i * every_n, num_epochs)
        # NOTE(review): linestyle=None selects matplotlib's default line
        # style (normally solid), so markers are joined by a line; pass the
        # string 'None' instead if markers only were intended.
        plt.plot(IXT, ITY, marker='o', markersize=15, markeredgewidth=0.04,
                 linestyle=None, linewidth=1, color=colors[color_index],
                 zorder=10)

    # Stand-alone mappable: it only exists to feed the colorbar.
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    # The mappable is not attached to any Axes, so name the Axes explicitly;
    # recent matplotlib versions refuse to guess it.
    cbar = plt.colorbar(sm, ax=plt.gca(), ticks=[])
    cbar.set_label('Num epochs')
    # Manual end labels replace the suppressed ticks.
    cbar.ax.text(0.5, -0.01, '0', transform=cbar.ax.transAxes,
                 va='top', ha='center')
    cbar.ax.text(0.5, 1.0, str(num_epochs), transform=cbar.ax.transAxes,
                 va='bottom', ha='center')
    plt.show()