-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluator_helper.py
123 lines (85 loc) · 2.82 KB
/
evaluator_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from abc import ABC, abstractmethod
from enum import Enum
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
import torch
from torch import nn
"""
Evaluator Helper Type
"""
class EvaluatorHelperType(Enum):
    """Kinds of classification task an evaluator helper can be requested for."""

    # Two-class problems.
    BINARY = 0
    # Problems with more than two classes.
    MULTIPLE = 1
"""
Abstract Evaluator Helper
"""
class AbstractEvaluatorHelper(ABC):
    """Interface that every evaluator helper must implement.

    Each method receives the model output (``logits``) and the ground-truth
    targets (``y``) and returns one evaluation metric. Subclasses must
    override all four methods; the base implementations only raise.
    """

    @abstractmethod
    def mean_nll(self, logits, y):
        """Return the mean negative log-likelihood loss.

        Raises:
            NotImplementedError: always, when the base stub is invoked.
        """
        raise NotImplementedError("Abstract method should be implemented")

    @abstractmethod
    def mean_accuracy(self, logits, y):
        """Return the mean classification accuracy.

        Raises:
            NotImplementedError: always, when the base stub is invoked.
        """
        raise NotImplementedError("Abstract method should be implemented")

    @abstractmethod
    def mean_roc_auc(self, logits, y):
        """Return the area under the ROC curve.

        Raises:
            NotImplementedError: always, when the base stub is invoked.
        """
        # Same exception and message as the other stubs, so a stray
        # super() call is diagnosable instead of raising an empty Exception.
        raise NotImplementedError("Abstract method should be implemented")

    @abstractmethod
    def mean_pr_auc(self, logits, y):
        """Return the area under the precision-recall curve.

        Raises:
            NotImplementedError: always, when the base stub is invoked.
        """
        raise NotImplementedError("Abstract method should be implemented")
"""
Binary Classification
"""
class BinaryClassificationEvaluatorHelper(AbstractEvaluatorHelper):
    """Evaluation metrics for binary classification.

    NOTE(review): the first argument is named ``logits`` (kept for interface
    compatibility), but it is fed to ``nn.BCELoss`` and thresholded at 0.5,
    so it is presumably a tensor of probabilities (e.g. post-sigmoid) --
    confirm against callers.
    """

    @staticmethod
    def _to_numpy(tensor):
        """Detach a tensor from the graph and return it as a CPU NumPy array."""
        return tensor.detach().cpu().numpy()

    def mean_nll(self, logits, y):
        """Return the binary cross-entropy between ``logits`` and ``y``.

        Uses ``nn.BCELoss`` with its default settings.
        """
        criterion = nn.BCELoss()
        return criterion(logits, y)

    def mean_accuracy(self, logits, y):
        """Return the fraction of samples whose thresholded prediction matches ``y``.

        Predictions are 1.0 where ``logits > 0.5`` and 0.0 elsewhere; the
        result is a scalar tensor.
        """
        preds = (logits > 0.5).float()
        # Compare with a small tolerance rather than exact float equality.
        return ((preds - y).abs() < 1e-2).float().mean()

    def mean_roc_auc(self, logits, y):
        """Return the area under the ROC curve as computed by scikit-learn.

        The continuous scores are passed to ``roc_auc_score``. Thresholding
        them first would collapse the ROC curve to a single operating point
        and misreport the AUC.
        """
        targets = self._to_numpy(y)
        scores = self._to_numpy(logits)
        return roc_auc_score(targets, scores)

    def mean_pr_auc(self, logits, y):
        """Return the area under the precision-recall curve.

        The curve is built by ``precision_recall_curve`` from the continuous
        scores (not thresholded predictions) and integrated with ``auc``.
        """
        targets = self._to_numpy(y)
        scores = self._to_numpy(logits)
        precision, recall, _ = precision_recall_curve(targets, scores)
        return auc(recall, precision)
"""
Multi Classification
"""
class MultiClassificationEvaluatorHelper(AbstractEvaluatorHelper):
    """Evaluation metrics for multi-class classification.

    Only the loss and the accuracy are provided; the two AUC metrics
    always return ``None``.
    """

    def mean_nll(self, logits, y):
        """Return the cross-entropy loss of ``logits`` against ``y``.

        Uses ``nn.CrossEntropyLoss`` with its default settings.
        """
        return nn.CrossEntropyLoss()(logits, y)

    def mean_accuracy(self, logits, y):
        """Return the share of samples whose arg-max class equals ``y``.

        The predicted class is the index of the largest value along
        dimension 1 of ``logits``. The result is a plain Python number:
        matching samples divided by ``y.size(0)``.
        """
        predicted = torch.max(logits, 1).indices
        hits = (predicted == y).sum().item()
        return hits / y.size(0)

    def mean_roc_auc(self, logits, y):
        """Return ``None``; ROC AUC is not computed for the multi-class case."""
        return None

    def mean_pr_auc(self, logits, y):
        """Return ``None``; PR AUC is not computed for the multi-class case."""
        return None
"""
Evaluator Helper Factory
"""
class EvaluatorHelperFactory:
    """Hands out evaluator helpers, creating each kind at most once.

    The first request for a given type builds the helper and stores it on
    the class; later requests for the same type return that stored instance.
    """

    # Cached helper instances, filled in on first use.
    __binary = None
    __multi = None

    @staticmethod
    def get_evaluator(type):
        """Return the evaluator helper that matches ``type``.

        Args:
            type: an ``EvaluatorHelperType`` member.

        Returns:
            The cached ``BinaryClassificationEvaluatorHelper`` for ``BINARY``
            or ``MultiClassificationEvaluatorHelper`` for ``MULTIPLE``.

        Raises:
            Exception: if ``type`` is neither of the supported members.
        """
        if type == EvaluatorHelperType.BINARY:
            helper = EvaluatorHelperFactory.__binary
            if helper is None:
                helper = BinaryClassificationEvaluatorHelper()
                EvaluatorHelperFactory.__binary = helper
            return helper

        if type == EvaluatorHelperType.MULTIPLE:
            helper = EvaluatorHelperFactory.__multi
            if helper is None:
                helper = MultiClassificationEvaluatorHelper()
                EvaluatorHelperFactory.__multi = helper
            return helper

        message = "Unsupported evaluator helper type: {}".format(type)
        raise Exception(message)