-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_results.py
128 lines (98 loc) · 10.6 KB
/
graph_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import matplotlib.pyplot as plt
from matplotlib import ticker
import numpy as np
# Hard-coded per-epoch results, presumably pasted from training logs (TODO confirm).
# F1 lists hold one [col0, col1, col2] row per epoch; plot_f1 draws column 2 as
# positive, column 1 as none and column 0 as negative sentiment.
# The F1 and accuracy lists have epochs + 1 entries (index 0 presumably being an
# evaluation before the first training epoch -- confirm); the loss lists have
# `epochs` entries.

# Number of training epochs covered by the series below.
epochs = 20
# First set of results: per-epoch F1 rows for the train and dev sets.
trains = [[0.0, 0.3267372296364473, 0.0], [0.6190114068441065, 0.0, 0.006144393241167434], [0.6190657045195594, 0.0, 0.0], [0.6190657045195594, 0.0, 0.0], [0.3867219917012448, 0.0, 0.5578034682080925], [0.6295241968279789, 0.0, 0.2773722627737226], [0.6376811594202898, 0.0, 0.3441108545034642], [0.619718309859155, 0.0, 0.01529051987767584], [0.21649484536082472, 0.0, 0.5417568152315015], [0.6294416243654822, 0.0, 0.1527777777777778], [0.5680190930787589, 0.0, 0.5619937694704049], [0.6092715231788078, 0.0, 0.5677331518039483], [0.6256291134339915, 0.0, 0.10888252148997134], [0.6375770020533881, 0.0, 0.5686421605401349], [0.6530223702998572, 0.0, 0.5271186440677965], [0.6442748091603052, 0.0, 0.5592705167173252], [0.6606451612903226, 0.0, 0.4372384937238494], [0.6449348044132397, 0.0, 0.5563325563325562], [0.5250501002004009, 0.0, 0.5717488789237668], [0.4403948367501898, 0.0, 0.5621181262729125], [0.6351706036745407, 0.0, 0.5639534883720931]]
devs = [[0.0, 0.923249299719888, 0.07250755287009063], [0.07649122807017544, 0.0, 0.06879606879606881], [0.07702523240371847, 0.0, 0.008163265306122448], [0.07927677329624479, 0.0, 0.06299212598425197], [0.1277860326894502, 0.0, 0.1671826625386997], [0.08827856792545365, 0.0, 0.1592775041050903], [0.09072781655034896, 0.0, 0.15827338129496404], [0.0800561797752809, 0.0, 0.07334963325183375], [0.0736842105263158, 0.0, 0.1532442125855885], [0.08702791461412153, 0.0, 0.13641900121802678], [0.1251117068811439, 0.0, 0.1739943872778298], [0.11203633610900834, 0.0, 0.1828512396694215], [0.08883647798742139, 0.0, 0.14866760168302945], [0.10489098408956983, 0.0, 0.17307692307692307], [0.11472868217054265, 0.0, 0.16338880484114976], [0.127027027027027, 0.0, 0.1429375351716376], [0.10986874088478366, 0.0, 0.17], [0.12700534759358287, 0.0, 0.1419647927314026], [0.12832263978001834, 0.0, 0.14866112650046168], [0.11946446961894953, 0.0, 0.1452318460192476], [0.12717536813922356, 0.0, 0.1429381735677822]]
# First set: per-epoch accuracy on train and dev, then per-epoch loss values.
train_a = [0.19526952695269528, 0.44884488448844884, 0.4482948294829483, 0.4482948294829483, 0.44664466446644663, 0.4884488448844885, 0.5055005500550055, 0.4504950495049505, 0.4020902090209021, 0.4735973597359736, 0.5099009900990099, 0.533003300330033, 0.46534653465346537, 0.5500550055005501, 0.5484048404840484, 0.5506050605060506, 0.5374037403740374, 0.5506050605060506, 0.4966996699669967, 0.46314631463146316, 0.5462046204620462]
dev_a = [0.8559614059269469, 0.042384562370778776, 0.04031702274293591, 0.043418332184700204, 0.08924879393521709, 0.06443831840110269, 0.06547208821502412, 0.04445210199862164, 0.0833907649896623, 0.055823569951757405, 0.08821502412129566, 0.0864920744314266, 0.057201929703652656, 0.0771881461061337, 0.07546519641626465, 0.07615437629221226, 0.0740868366643694, 0.07580978635423846, 0.07960027567195038, 0.0771881461061337, 0.07615437629221226]
losses = [1.0564017711759923, 1.0511277383500404, 1.0486017573665787, 1.0466578674840403, 1.0429198971161475, 1.0407251944908729, 1.036260547218742, 1.0379675515405424, 1.034066311605684, 1.0335237848889696, 1.0309524827606076, 1.0257282440478985, 1.0195907806302165, 1.0134979375116118, 1.0074656563145774, 1.0042665060404892, 0.9934646202312721, 0.9911167893435929, 0.9890997553919698, 0.9927226537531548]
# Second set of results (suffix 2), same layout as the first set above.
trains2 = [[0.6187761307487647, 0.0, 0.0], [0.5779650812763396, 0.0, 0.5728395061728395], [0.6236559139784946, 0.0, 0.5944517833553501], [0.6782291191236879, 0.0, 0.5302752293577981], [0.6748466257668713, 0.0, 0.4804804804804805], [0.6852122986822841, 0.0, 0.6038961038961039], [0.6699453551912569, 0.005617977528089888, 0.6358620689655173], [0.7005270723526593, 0.0, 0.6113902847571189], [0.7014492753623187, 0.00558659217877095, 0.6324503311258278], [0.7156959526159921, 0.01652892561983471, 0.6575781876503609], [0.7147905098435134, 0.10416666666666667, 0.6750590086546027], [0.7139896373056994, 0.1941747572815534, 0.6970633693972179], [0.7339166237776634, 0.12371134020618556, 0.714176245210728], [0.7592689295039164, 0.30277185501066095, 0.7316293929712461], [0.7731778425655976, 0.3976377952755906, 0.7629157820240624], [0.6828908554572272, 0.5125748502994012, 0.7294117647058823], [0.7898089171974523, 0.5649546827794562, 0.7905982905982906], [0.755020080321285, 0.4749536178107607, 0.7610729881472239], [0.8333333333333333, 0.5747508305647842, 0.8238897396630934], [0.8497284248642125, 0.601823708206687, 0.8251324753974262], [0.8375959079283887, 0.6000000000000001, 0.829001367989056]]
devs2 = [[0.07761607761607761, 0.0, 0.0646900269541779], [0.11903012490815577, 0.0, 0.14029535864978904], [0.13114754098360656, 0.0, 0.15200478755236385], [0.12111468381564845, 0.0, 0.15384615384615385], [0.11126826968411127, 0.0, 0.1619718309859155], [0.1349009900990099, 0.0, 0.14868982327848876], [0.14304993252361672, 0.0, 0.14422535211267606], [0.1231578947368421, 0.0, 0.15622697126013266], [0.11468116658428076, 0.0007849293563579278, 0.15896188158961883], [0.12447257383966244, 0.007821666014861166, 0.15692079940784603], [0.11872146118721462, 0.0, 0.15552099533437017], [0.11653543307086614, 0.03775038520801233, 0.15656178050652342], [0.12083568605307735, 0.05494086226631057, 0.15722379603399433], [0.12159329140461216, 0.1226588321704003, 0.16368286445012786], [0.14407988587731813, 0.24019106107130675, 0.14547926580557444], [0.1743119266055046, 0.5863624303528787, 0.1723388848660391], [0.16384683882457704, 0.43791241751649673, 0.15738678544914622], [0.15739948674080412, 0.3107177974434611, 0.15151515151515152], [0.13993399339933993, 0.28495339547270304, 0.1540856031128405], [0.16732283464566927, 0.47305563646956017, 0.16826568265682657], [0.15689381933438984, 0.34371988435592676, 0.17774667599720081]]
train_a2 = [0.44774477447744776, 0.5192519251925193, 0.5506050605060506, 0.5676567656765676, 0.5555555555555556, 0.5907590759075908, 0.5913091309130913, 0.6028602860286029, 0.61001100110011, 0.6259625962596259, 0.6364136413641364, 0.6490649064906491, 0.6617161716171617, 0.6908690869086909, 0.7167216721672167, 0.6622662266226622, 0.7491749174917491, 0.7161716171617162, 0.7871287128712872, 0.7959295929592959, 0.7942794279427943]
dev_a2 = [0.042729152308752585, 0.0737422467263956, 0.07960027567195038, 0.07580978635423846, 0.07236388697450034, 0.07960027567195038, 0.08063404548587182, 0.0768435561681599, 0.0740868366643694, 0.08063404548587182, 0.07477601654031703, 0.09028256374913853, 0.09993108201240523, 0.13059958649207443, 0.19297036526533426, 0.4414197105444521, 0.31977946243969674, 0.23638869745003446, 0.21812543073742247, 0.3483804272915231, 0.2622329427980703]
losses2 = [1.0672731491235585, 1.0110030246304942, 0.9855081416093386, 0.9799992772904071, 0.9630664413446909, 0.9516144529148772, 0.9347212367005401, 0.9226315312988156, 0.8982541855874953, 0.8782883718773559, 0.8494735929992173, 0.8221145070843644, 0.7853949542392741, 0.767772784108644, 0.7223777454960477, 0.6994303440327173, 0.6593080117152288, 0.6521052077085108, 0.6017052645866687, 0.5581643641158774]
def graph_attention(model, word_to_ix, ix_to_word, batch, using_GPU):
    """Print one example's prediction and show its attention weights as a heat map.

    Runs ``model`` on ``batch`` and visualises the attention for the first
    example of the batch (index 0 along axis 1).

    Args:
        model: callable invoked as ``model(words, polarity, holder_target, lengths)``
            that returns ``(log_probs, attention)``.
        word_to_ix: not used by this function; kept for interface compatibility.
        ix_to_word: index-to-token mapping handed to ``decode``.
        batch: object exposing ``text`` (a ``(words, lengths)`` pair),
            ``polarity``, ``holder_target`` and ``label``.
        using_GPU: not consulted; the attention tensor is moved to the CPU
            unconditionally before NumPy conversion, which is a no-op for
            tensors already on the CPU.
    """
    (words, lengths), polarity, holder_target, label = (
        batch.text, batch.polarity, batch.holder_target, batch.label)
    log_probs, attention = model(words, polarity, holder_target, lengths)
    # NOTE(review): assumes axis 1 is the batch axis (words[:, 0] picks the first
    # example) and that decode returns a list of tokens -- confirm.
    input_sentence = decode(words[:, 0], ix_to_word)
    # Arg-max class of the log-probabilities followed by the gold label.
    print(str(log_probs.data.max(1)[1]) + str(label))
    print(input_sentence)
    # .numpy() raises on CUDA tensors, so copy to host memory first; .cpu()
    # returns the same tensor when it already lives on the CPU.
    instance_attention = attention[:, 0].data.cpu().numpy()
    print(instance_attention)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(instance_attention, cmap='bone')
    fig.colorbar(cax)
    # Set up axes. The leading '' presumably pads the label for the first
    # locator tick, which falls just outside the image -- confirm on your
    # matplotlib version.
    ax.set_yticklabels([''] + input_sentence)
    ax.set_xticklabels([''])
    # Show a label at every integer tick.
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()
def plot_f1(train_accs, dev_accs, test_accs=None):
    """Plot per-class F1 scores across epochs in three side-by-side panels.

    Each argument is a sequence of per-epoch rows ``[col0, col1, col2]``.
    Column 2 is drawn as positive, column 1 as none and column 0 as negative
    sentiment. Row ``i`` is plotted at epoch ``i``, so the x-axis adapts to
    the length of the data instead of a global epoch count.

    Args:
        train_accs: per-epoch F1 rows for the training set (blue line).
        dev_accs: per-epoch F1 rows for the dev set (red line).
        test_accs: optional per-epoch F1 rows for the test set (green line).
    """
    # (column index, y-axis label, panel title) for each panel, left to right.
    panels = (
        (2, 'Positive F1 Score', 'Positive Sentiment F1 Score vs. # of Epochs'),
        (1, 'None F1 Score', 'No Sentiment F1 Score vs. # of Epochs'),
        (0, 'Negative F1 Score', 'Negative Sentiment F1 Score vs. # of Epochs'),
    )
    # (scores, line colour, legend label), drawn in this order on every panel.
    series = [(np.asarray(train_accs), 'blue', 'Train Set'),
              (np.asarray(dev_accs), 'red', 'Dev Set')]
    if test_accs is not None:
        series.append((np.asarray(test_accs), 'green', 'Test Set'))
    # One x tick per epoch, spanning the longest series.
    n_epochs = max(len(scores) for scores, _, _ in series)

    fig, axes = plt.subplots(1, len(panels), sharey=True)
    for ax, (column, ylabel, title) in zip(axes, panels):
        for scores, color, label in series:
            ax.plot(range(len(scores)), scores[:, column], c=color, label=label)
        ax.set_xticks(range(n_epochs))
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
    fig.suptitle("F1 Scores")
    plt.show()
def plot_accross_epochs(title, train_accs, dev_accs=None, test_accs=None,
                        epoch_start=None):
    """Plot one per-epoch metric for the train set and optional dev/test sets.

    Args:
        title: metric name, used as the y-axis label and in the plot title.
        train_accs: per-epoch values for the training set (blue line).
        dev_accs: optional per-epoch values for the dev set (red line).
        test_accs: optional per-epoch values for the test set (green line).
        epoch_start: epoch number of each series' first value. When None,
            defaults to 0 for ``title == "Accuracy"`` and 1 otherwise, which
            was the previously hard-coded rule.
    """
    if epoch_start is None:
        epoch_start = 0 if title == "Accuracy" else 1

    # Start a fresh figure so the curves never land on axes left over from an
    # earlier plot (plt.show() can return without closing figures, e.g. on
    # non-interactive backends).
    plt.figure()
    # (values, line colour, legend label); None entries are skipped.
    series = ((train_accs, 'blue', 'Train Set'),
              (dev_accs, 'red', 'Dev Set'),
              (test_accs, 'green', 'Test Set'))
    longest = 0
    for values, color, label in series:
        if values is None:
            continue
        # x positions derive from each series' own length, not a global epoch count.
        plt.plot(range(epoch_start, epoch_start + len(values)), values,
                 c=color, label=label)
        longest = max(longest, len(values))

    plt.xlabel('Epochs')
    plt.ylabel(title)
    plt.title(title + ' vs. # of Epochs')
    plt.xticks(np.arange(epoch_start, epoch_start + longest, step=1))
    plt.legend()
    plt.show()
def main():
    """Show the F1, accuracy and loss plots for the second (suffix-2) results.

    All three plots read the ``*2`` series so that they describe the same set
    of results. Substitute the unsuffixed series (trains, devs, train_a, dev_a,
    losses) to plot the other set.
    """
    plot_f1(trains2, devs2)
    plot_accross_epochs("Accuracy", train_a2, dev_a2)
    plot_accross_epochs("NLLLoss", losses2)


if __name__ == "__main__":
    main()