06_baselines.py (101 lines, 87 loc, 3.91 KB)
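"""Figure 1: compare mean F1 scores of our classifier against three baselines
(most frequent class, random guess, label permutation) for each personality trait."""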
import numpy as np
import matplotlib.pyplot as plt
from config import conf
import os, sys
import pandas as pns
from config import names as gs
import getopt
import matplotlib.gridspec as gridspec
from sklearn.metrics import f1_score, accuracy_score
import seaborn as sns
sns.set(style='whitegrid', color_codes=True)
sns.set_context('poster')
dark_color = sns.xkcd_rgb['charcoal grey']
light_color = sns.xkcd_rgb['cloudy blue']
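# experiment-wide constants from config: feature/iteration limits, feature labels, participant ids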
max_n_feat = conf.max_n_feat
m_iter = conf.max_n_iter
featurelabels = gs.full_long_label_list
participant_ids = np.arange(0, conf.n_participants)
def plot_overview():
    """Plot mean F1 scores of our classifier and all baselines per trait (Figure 1) and save the underlying means as csv."""
    all_baselines.groupby(by=['trait', 'clf_name'])['F1'].mean().to_csv(conf.figure_folder +
                                                                        '/figure1.csv')
    print('figure1.csv written')
    sns.set(font_scale=2.1)
    plt.figure(figsize=(20, 10))
    ax = plt.subplot(1, 1, 1)
    sns.barplot(x='trait', y='F1', hue='clf_name', data=all_baselines, capsize=.05, errwidth=3,
                linewidth=3, estimator=np.mean, edgecolor=dark_color,
                palette={'our classifier': sns.xkcd_rgb['windows blue'],
                         'most frequent class': sns.xkcd_rgb['faded green'],
                         'random guess': sns.xkcd_rgb['greyish brown'],
                         'label permutation': sns.xkcd_rgb['dusky pink']
                         }
                )
    plt.plot([-0.5, 6.5], [0.33, 0.33], c=dark_color, linestyle='--', linewidth=3,
             label='theoretical chance level')
    handles, labels = ax.get_legend_handles_labels()
    ax.legend([handles[1], handles[2], handles[3], handles[4], handles[0]],
              [labels[1], labels[2], labels[3], labels[4], labels[0]], fontsize=20)
    plt.xlabel('')
    plt.ylabel('F1 score', fontsize=20)
    plt.ylim((0, 0.55))
    filename = conf.figure_folder + '/figure1.pdf'
    plt.savefig(filename, bbox_inches='tight')
    plt.close()
    print('wrote', filename.split('/')[-1])
if __name__ == "__main__":
    # collect F1 scores of our classifier on all data, written by evaluation_single_context.py
    datapath = conf.get_result_folder(conf.annotation_all) + '/f1s.csv'
    if not os.path.exists(datapath):
        print('could not find', datapath)
        print('consider (re-)running evaluation_single_context.py')
        sys.exit(1)
    our_classifier = pns.read_csv(datapath)
    our_classifier['clf_name'] = 'our classifier'

    # baseline 1: always guess the most frequent class of each training set (written by train_baseline.py)
    datapath = conf.result_folder + '/most_frequ_class_baseline.csv'
    if not os.path.exists(datapath):
        print('could not find', datapath)
        print('consider (re-)running train_baseline.py')
        sys.exit(1)
    most_frequent_class_df = pns.read_csv(datapath)
    most_frequent_class_df['clf_name'] = 'most frequent class'
    # compute all other baselines ad hoc
    collection = []
    for trait in range(0, conf.n_traits):
        # ground-truth (binned) personality labels for this trait
        truth = np.genfromtxt(conf.binned_personality_file, skip_header=1, usecols=(trait + 1,),
                              delimiter=',')

        # baseline 2: random guess
        for i in range(0, 100):
            rand_guess = np.random.randint(1, 4, conf.n_participants)
            f1 = f1_score(truth, rand_guess, average='macro')
            collection.append([f1, conf.medium_traitlabels[trait], i, 'random guess'])

        # baseline 3: label permutation test
        # computed by label_permutation_test.sh and written into the results folder, i.e. only loaded here
        for si in range(0, m_iter):
            filename_rand = conf.get_result_filename(conf.annotation_all, trait, True, si, add_suffix=True)
            if os.path.exists(filename_rand):
                data = np.load(filename_rand)
                pr = data['predictions']
                dt = truth[pr > 0]
                pr = pr[pr > 0]
                f1 = f1_score(dt, pr, average='macro')
                collection.append([f1, conf.medium_traitlabels[trait], si, 'label permutation'])
            else:
                print('did not find', filename_rand)
                print('consider (re-)running label_permutation_test.sh')
                sys.exit(1)
    collectiondf = pns.DataFrame(data=collection, columns=['F1', 'trait', 'iteration', 'clf_name'])

    all_baselines = pns.concat([our_classifier, most_frequent_class_df, collectiondf])
    plot_overview()  # Figure 1
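    # Expected usage (assuming config.py and the prerequisite result files exist; see the checks above):
    #   python 06_baselines.py
    # writes figure1.csv and figure1.pdf to conf.figure_folder.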