-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodhyp.py
86 lines (71 loc) · 6.74 KB
/
modhyp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from itertools import cycle
import matplotlib.pylab as pylab
# Global matplotlib text sizing: legend entries, axis labels and axes titles
# are all rendered at 'x-large' for every figure created after this point.
params = dict.fromkeys(
    ('legend.fontsize', 'axes.labelsize', 'axes.titlesize'),
    'x-large',
)
pylab.rcParams.update(params)

# Class count constant; it is not referenced by name anywhere else in the
# visible part of this script.
n_classes = 3
# Hard-coded training histories, keyed by run name.  Each value is a list of
# floats that historyplotter() draws as 'Categorical Accuracy' against the
# list index ('Training Epoch').  The same run names are used by rocplotter()
# to locate the saved ROC data files for that run.
historylog={}
# Shown in the figure with the label 'LR=0.002'.
historylog['diffuserun_allparams_lr0.002']=[0.37494173645973206, 0.38103339076042175, 0.38815373182296753, 0.3940505385398865, 0.3996098041534424, 0.4048088490962982, 0.4101080298423767, 0.41432443261146545, 0.4178105294704437, 0.4210216701030731, 0.4231182634830475, 0.425518274307251, 0.428148478269577, 0.42988744378089905, 0.4315766394138336, 0.43384242057800293, 0.43563228845596313, 0.43779465556144714, 0.43894514441490173, 0.4414037764072418, 0.44246217608451843, 0.44402045011520386, 0.4453774392604828, 0.44647836685180664, 0.4476519227027893, 0.4491664171218872, 0.44991055130958557, 0.45121487975120544, 0.4521588385105133, 0.45339465141296387]
# Shown in the figure with the label 'LR=0.0005'.
historylog['diffuserun_allparams_lr0.0005']=[0.36500972509384155, 0.3743436932563782, 0.3768846094608307, 0.37887123227119446, 0.38142451643943787, 0.3836422562599182, 0.3861031234264374, 0.38834869861602783, 0.38990092277526855, 0.39175689220428467, 0.3947017788887024, 0.3968322277069092, 0.3983408808708191, 0.3996129035949707, 0.401181161403656, 0.4028477966785431, 0.40458494424819946, 0.4061996638774872, 0.40754660964012146, 0.4086054563522339, 0.4100553095340729, 0.4107266664505005, 0.4125533699989319, 0.4131374955177307, 0.4138016402721405, 0.41511014103889465, 0.41564035415649414, 0.4170118570327759, 0.41759851574897766, 0.41884592175483704]
# Shown in the figure with the label 'KS=2'.
historylog['diffuserun_allparams_ks2']=[0.365664005279541, 0.375478595495224, 0.37708839774131775, 0.3785148561000824, 0.37985095381736755, 0.3808770477771759, 0.38267725706100464, 0.38313472270965576, 0.38457798957824707, 0.3856242895126343, 0.3867718279361725, 0.38914987444877625, 0.3897300660610199, 0.39126068353652954, 0.3934861123561859, 0.3947852551937103, 0.39624035358428955, 0.39816176891326904, 0.3991550803184509, 0.401467889547348, 0.40235769748687744, 0.40346279740333557, 0.40421393513679504, 0.4057549238204956, 0.40749087929725647, 0.4079909324645996, 0.40876784920692444, 0.40887436270713806, 0.41037940979003906, 0.4110073745250702]
# Shown in the figure with the label 'KS=4'.
historylog['diffuserun_allparams_ks4']=[0.36926159262657166, 0.3772403299808502, 0.38443177938461304, 0.3916371762752533, 0.3977274000644684, 0.40360260009765625, 0.40713995695114136, 0.41112077236175537, 0.41380712389945984, 0.4162972569465637, 0.41832536458969116, 0.42112016677856445, 0.42285841703414917, 0.42488160729408264, 0.42720505595207214, 0.42873328924179077, 0.4307609498500824, 0.43188685178756714, 0.4333096742630005, 0.4344727694988251, 0.43570300936698914, 0.43725642561912537, 0.43811723589897156, 0.4397290349006653, 0.44016873836517334, 0.44135913252830505, 0.4422951340675354, 0.4430161714553833, 0.4445742070674896, 0.44488996267318726]
# Shown in the figure with the label 'Filters=20'.
historylog['diffuserun_allparams_filters20']=[0.36666733026504517, 0.375679612159729, 0.3791179656982422, 0.38178539276123047, 0.384350448846817, 0.38596057891845703, 0.3882560729980469, 0.3901415169239044, 0.3916947543621063, 0.3937133550643921,
0.3953779935836792, 0.3973731994628906, 0.3986068069934845, 0.4002516567707062, 0.4025326669216156, 0.40446707606315613, 0.4055681824684143, 0.40746983885765076, 0.4089438319206238, 0.4108460247516632,
0.4113154113292694, 0.41374772787094116,0.4143563210964203, 0.4152584373950958, 0.41626253724098206, 0.41791632771492004, 0.41902223229408264, 0.4197980761528015, 0.42071643471717834, 0.422069787979126]
# Shown in the figure with the label 'Default'.
# NOTE(review): these values are noticeably higher than the other runs and
# carry more digits -- presumably logged by a different path; confirm they
# are the same metric before comparing curves.
historylog['diffuserun_allparams']=[0.45948414836122115, 0.49084795078568977, 0.49763756260295194, 0.501176807780401, 0.5037291105385632, 0.5063421063660433, 0.5085519747034374, 0.509614771218029, 0.5117061503111298, 0.5128692673666255, 0.5135091864653715, 0.514661503643354, 0.5156502634930964, 0.5164280085368113, 0.5174994115213315, 0.5184042522148856, 0.5204329852318189, 0.5201267582523612, 0.5201338563483302, 0.5206700423797673, 0.5208075217959409, 0.5216498750265207, 0.5215222791328528, 0.5223640772371219, 0.5224544862580758, 0.5228850377963935, 0.5236136039782157, 0.5235770108245466, 0.5237585324573578, 0.524286328218445]
def historyplotter(history, axes, runname, label):
    """Add one run's accuracy curve to ``axes``.

    Args:
        history: Mapping from run name to a sequence of per-epoch values.
        axes: Matplotlib-style axes object to draw on.
        runname: Key selecting which entry of ``history`` to plot.
        label: Legend text for the curve; converted with ``str()``.

    Raises:
        KeyError: If ``runname`` is not a key of ``history``.
    """
    accuracies = history[runname]
    # The x axis is simply the 0-based position of each value in the list.
    epochs = np.arange(len(accuracies))
    axes.plot(epochs, accuracies, label=str(label), linewidth=4)
    # Re-issue the legend after every curve so it includes the new entry.
    axes.legend(loc="lower right")
    axes.set_xlabel('Training Epoch')
    axes.set_ylabel('Categorical Accuracy')
def rocplotter(axes, runname, label, datadir='/users/exet4487/confmatdata/'):
    """Plot the macro-average ROC curve of one run on ``axes``.

    Two files are read from ``datadir``: ``<runname>_tp.npy`` and
    ``<runname>_fp.npy``.  Each is expected to hold a single pickled dict
    (loaded as a 0-d object array and unwrapped with ``.item()``) whose
    ``"macro"`` entry is the array of true-positive / false-positive rates.

    Args:
        axes: Matplotlib axes to draw on.
        runname: Run identifier used to build the two file names.
        label: Legend prefix for the curve; converted with ``str()``.
        datadir: Directory containing the ``.npy`` files.  Defaults to the
            location the script has always used.

    Raises:
        KeyError: If either file's dict has no ``"macro"`` entry.
        OSError: If a data file cannot be opened.
    """
    import os

    def _load_rates(suffix):
        path = os.path.join(datadir, str(runname) + suffix)
        # NOTE(security): allow_pickle=True executes pickle on load; it is
        # needed because the file stores a dict, but only point this at
        # files you produced yourself.
        return np.load(path, allow_pickle=True).item()

    tpr = _load_rates('_tp.npy')
    fpr = _load_rates('_fp.npy')

    # Only the macro-average curve is drawn, so only its AUC is computed.
    macro_fpr = fpr["macro"]
    macro_tpr = tpr["macro"]
    macro_auc = auc(macro_fpr, macro_tpr)

    axes.plot(macro_fpr, macro_tpr,
              label=str(label) + ' Macro Average (AUC = {0:0.2f})'.format(macro_auc),
              linewidth=4)
    # Re-issue the legend after every curve so it includes the new entry.
    axes.legend(loc="lower right")
    axes.set_xlabel('FPR')
    axes.set_ylabel('TPR')
from matplotlib.transforms import offset_copy
# Two-panel figure: training curves on the left, ROC curves on the right.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15,7))

# Column headings are drawn as annotations anchored to the top-centre of each
# panel and lifted by `pad` points.
pad = 5 # in points
xtitles = ['Training History','Final Performance']
for ax, heading in zip(axes, xtitles):
    ax.annotate(heading, xy=(0.5, 1), xytext=(0, pad),
                xycoords='axes fraction', textcoords='offset points',
                size='x-large', ha='center', va='baseline')

# (run key, legend label) pairs, in plotting order.  The order determines
# which colour of the default cycle each run receives on both panels.
run_table = [
    ('diffuserun_allparams_lr0.002', 'LR=0.002'),
    ('diffuserun_allparams_lr0.0005', 'LR=0.0005'),
    ('diffuserun_allparams_ks4', 'KS=4'),
    ('diffuserun_allparams_ks2', 'KS=2'),
    ('diffuserun_allparams_filters20', 'Filters=20'),
    ('diffuserun_allparams', 'Default'),
]
for run_key, run_label in run_table:
    historyplotter(historylog, axes[0], run_key, run_label)
    rocplotter(axes[1], run_key, run_label)

fig.tight_layout()
# NOTE(review): no margin adjustment (e.g. subplots_adjust) follows
# tight_layout(); confirm the annotated column headings are not clipped in
# the saved image.
plt.savefig('/users/exet4487/Figures/modhyp2.png')
# Swap in plt.show() here for interactive viewing instead of saving.