-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy path99_plot_errors_over_features.py
56 lines (46 loc) · 1.79 KB
/
99_plot_errors_over_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
Plot standardized errors over features for all components.
All features are standardized.
Observe the behavior of model and projection errors over the parameter space.
The errors are likely to be worse at the extremes of the space.
"""
# Author: Arturs Berzins <berzins@cats.rwth-aachen.de>
# License: BSD 3 clause
import config
import utils
from matplotlib import pyplot
dataset = 'test'
# Model keys select error-table columns named 'eps_pod{key}_sq' (key
# lower-cased). The empty key selects the plain POD projection error column.
model_keys = ['RBF',
              'GPR',
              'FNN',
              '',  # projection error
              ]

df = utils.load_error_table(dataset)
features = utils.load_features(dataset)

P = len(config.mu_names)
n_components = len(config.components)
# squeeze=False keeps `axes` 2-D even when there is only one component or one
# feature, so the axes[row, col] indexing below is always valid.
fig, axes = pyplot.subplots(n_components, P, sharex='col', sharey='row',
                            squeeze=False)
fig.suptitle('Standardized errors over standardized features')
cmap = pyplot.get_cmap('tab10')

for idx_ax, component in enumerate(config.components):
    # Keep only the rows of this component at its configured basis size L.
    L = config.num_basis[component]
    df_filtered = df.loc[(df['component'] == component) & (df['l'] == L)]
    # xs and ys are paired by position below.
    # NOTE(review): this assumes df_filtered rows are in the same sample order
    # as the rows of `features` -- confirm against utils.load_error_table.
    if len(df_filtered) != len(features):
        raise ValueError(
            f'{component}: {len(df_filtered)} error rows at l={L} but '
            f'{len(features)} feature rows; cannot pair errors with features')
    # The columns hold squared errors; plot their square root. The errors do
    # not depend on the feature column, so compute them once per component.
    errors = [df_filtered[f'eps_pod{key.lower()}_sq'].to_numpy() ** 0.5
              for key in model_keys]
    for p in range(P):
        xs = features[:, p]
        for i, (model_key, ys) in enumerate(zip(model_keys, errors)):
            # The empty key is the projection error; avoid a bare 'POD-' label.
            label = f'POD-{model_key}' if model_key else 'POD (projection)'
            # Plot transparent data due to bug in scatter limits
            axes[idx_ax, p].plot(xs, ys, alpha=0)
            axes[idx_ax, p].scatter(xs, ys, marker='o', s=2, color=cmap(i),
                                    label=label)

# Label x axes (bottom row only; columns share x).
for p, mu_name in enumerate(config.mu_names):
    axes[-1, p].set_xlabel(mu_name)
# Label y axes (first column only; rows share y).
for idx_ax, component in enumerate(config.components):
    axes[idx_ax, 0].set_ylabel(f'{component}')

# Every subplot carries the same series, so one subplot's handles suffice.
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=len(model_keys))
fig.set_size_inches(w=6.3, h=4.3)
fig.subplots_adjust(bottom=.2)
pyplot.show()