eval_metrics.py
# Evaluate the best-performing model on the validation or test set
from sklearn.metrics import (accuracy_score, roc_auc_score, average_precision_score,
                             roc_curve, precision_recall_curve)
from scipy.special import softmax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def get_metrics(raw_pred, y_true, output_dir, suffix='', plot=True):
    # Calculate performance metrics
    y_pred = np.argmax(raw_pred, axis=1)                # predicted class per example
    n_pred_pos = int(np.sum(y_pred))                    # number of predicted positives
    actual_pos = int(np.sum(y_true))                    # number of actual positives
    ac = accuracy_score(y_true, y_pred)                 # accuracy
    sm_scores = softmax(raw_pred, axis=1)               # row-wise softmax to normalize raw predictions
    y_score = sm_scores[:, 1]                           # probability of the positive class
    au_roc = roc_auc_score(y_true, y_score)             # AU-ROC
    au_prc = average_precision_score(y_true, y_score)   # AU-PRC (average precision)

    # Save raw and normalized scores
    raw_df = pd.DataFrame({'raw_pred': raw_pred[:, 1],
                           'transformed_pred': y_score,
                           'y_pred': y_pred,
                           'y_true': y_true})
    raw_df.to_csv(output_dir + 'raw_predictions_' + suffix + '.csv', index=False)

    # Save final evaluation data
    eval_df = pd.DataFrame({'Accuracy': [ac],
                            'AUROC': [au_roc],
                            'AUPRC': [au_prc],
                            'n_pred_pos': [n_pred_pos],
                            'actual_pos': [actual_pos]})
    eval_df.to_csv(output_dir + 'eval_metrics_' + suffix + '.csv', index=False)

    # Generate performance plots
    if plot:
        # ROC curve (reuse au_roc rather than recomputing it)
        fpr, tpr, _ = roc_curve(y_true, y_score)
        plt.title('ROC Curve')
        plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % au_roc)
        plt.legend(loc='lower right')
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.ylabel('True Positive Rate')
        plt.xlabel('False Positive Rate')
        plt.savefig(output_dir + 'roc_curve_' + suffix + '.png', dpi=1200, facecolor='w')
        plt.close()

        # Precision-recall curve (average precision, reused from au_prc above)
        p, r, _ = precision_recall_curve(y_true, y_score)
        plt.title('Precision-Recall Curve')
        plt.ylabel('Precision')
        plt.xlabel('Recall')
        plt.plot(r, p, 'b', label='AP = %0.2f' % au_prc)
        plt.legend(loc='lower right')
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.savefig(output_dir + 'pr_curve_' + suffix + '.png', bbox_inches='tight', dpi=1200, facecolor='w')
        plt.close()

    return ac, au_roc, au_prc, n_pred_pos, actual_pos
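
A minimal usage sketch, not part of the original file: the synthetic labels and scores below are assumptions for illustration, and any (n_samples, 2) array of raw per-class scores works. Note that get_metrics builds output paths by string concatenation, so output_dir should end with a path separator and the directory must already exist.

import os
import numpy as np
from eval_metrics import get_metrics

# Hypothetical data: 200 binary labels and noisy two-column scores that favor the true class
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)
raw_pred = rng.normal(size=(200, 2)) + np.c_[1 - y_true, y_true]

# Output directory must exist, and the trailing '/' matters for path concatenation
os.makedirs('./results', exist_ok=True)
ac, au_roc, au_prc, n_pred_pos, actual_pos = get_metrics(
    raw_pred, y_true, output_dir='./results/', suffix='val')
print('Accuracy=%.3f, AUROC=%.3f, AUPRC=%.3f' % (ac, au_roc, au_prc))

This writes raw_predictions_val.csv, eval_metrics_val.csv, and the two curve PNGs into ./results/, in addition to returning the summary metrics.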