-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluator.py
51 lines (35 loc) · 1.49 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
class Evaluator:
    """Compute averaged test metrics for a model using an injected metrics helper.

    The helper object must provide ``mean_nll``, ``mean_accuracy``,
    ``mean_roc_auc`` and ``mean_pr_auc``; each is called as
    ``helper.<metric>(logits, labels)`` once per evaluated batch.
    """

    def __init__(self, helper):
        """Store the metrics helper used by :meth:`evaluate_model`.

        Args:
            helper: object exposing ``mean_nll``, ``mean_accuracy``,
                ``mean_roc_auc`` and ``mean_pr_auc``.
        """
        self.__helper = helper

    @staticmethod
    def _has_both_classes(labels):
        """Return True if ``labels`` contains both the values 0.0 and 1.0."""
        present = torch.unique(labels)  # computed once instead of twice per batch
        return 1.0 in present and 0.0 in present

    def evaluate_model(self, model, test_loader, test_batch_size):
        """Evaluate ``model`` on every batch of ``test_loader`` without gradients.

        The model is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored afterwards, so calling this
        from inside a training loop does not leave dropout/batch-norm frozen.

        Args:
            model: callable module whose call returns a 2-tuple; only the
                second element (used as logits) is consumed here.
            test_loader: iterable of ``(images, labels)`` batches. It does
                not need to implement ``__len__``; batches are counted as
                they are consumed.
            test_batch_size: expected size of a full batch. ROC/PR AUC are
                only computed on batches of exactly this length.

        Returns:
            ``(test_loss, test_acc, test_roc, test_pr)``:
            ``test_loss`` and ``test_acc`` are the mean over batches of the
            helper's per-batch values. ``test_roc`` and ``test_pr`` are the
            mean of the non-None per-batch values over qualifying batches,
            or ``None`` if no batch produced a value.

        Raises:
            ValueError: if ``test_loader`` yields no batches.
        """
        # getattr keeps duck-typed models without a ``training`` flag working.
        was_training = getattr(model, "training", None)
        model.eval()
        try:
            with torch.no_grad():
                final_loss = 0
                final_acc = 0
                final_roc = []
                final_pr = []
                num_batches = 0
                for images, labels in test_loader:
                    num_batches += 1
                    _, logits = model(images)
                    final_loss += self.__helper.mean_nll(logits, labels)
                    final_acc += self.__helper.mean_accuracy(logits, labels)
                    # AUC metrics are only taken on full-size batches that
                    # contain both label values (0.0 and 1.0); the helper
                    # may still decline by returning None.
                    if len(labels) == test_batch_size and self._has_both_classes(labels):
                        roc = self.__helper.mean_roc_auc(logits, labels)
                        pr = self.__helper.mean_pr_auc(logits, labels)
                        if roc is not None:
                            final_roc.append(roc)
                        if pr is not None:
                            final_pr.append(pr)
        finally:
            if was_training is not None:
                model.train(was_training)

        if num_batches == 0:
            raise ValueError("test_loader yielded no batches")

        test_loss = final_loss / num_batches
        test_acc = final_acc / num_batches
        test_roc = sum(final_roc) / len(final_roc) if final_roc else None
        test_pr = sum(final_pr) / len(final_pr) if final_pr else None
        return test_loss, test_acc, test_roc, test_pr