-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
80 lines (70 loc) · 2.72 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from utils import create_directory, get_time_prefix
from trainer import forward_pass
import numpy as np
import torch
import torch.nn as nn
def write_score(fname, val):
    """Write the accuracy score to *fname*, overwriting any existing file.

    Args:
        fname: path of the report file to create.
        val: accuracy value; converted with str() for the report line.

    Uses a context manager so the handle is closed even if the write
    raises (the original open/write/close leaked the handle on error).
    """
    with open(fname, 'w') as f:
        f.write("Accuracy : " + str(val))
def get_accuracy(logits, targets):
    """Return the fraction of predictions equal to their targets.

    Args:
        logits: sequence of predicted class labels.
        targets: sequence of ground-truth labels, same length as logits.

    Returns:
        float accuracy in [0, 1]. Returns 0.0 for empty input (the
        original raised ZeroDivisionError in that case).
    """
    if not logits:
        return 0.0
    # Count element-wise matches in one pass instead of indexing.
    correct = sum(1 for pred, true in zip(logits, targets) if pred == true)
    return correct / len(logits)
def save_results(logits, targets, imgs, ques, corr_fname, incorr_fname):
    """Split per-sample predictions into correct/incorrect CSV files.

    Each output line is "<image_id>,<question>,<prediction>". Samples
    where logits[i] == targets[i] go to corr_fname; the rest go to
    incorr_fname.

    Args:
        logits: predicted class labels.
        targets: ground-truth labels (parallel to logits).
        imgs: image identifiers (parallel to logits).
        ques: original question strings (parallel to logits).
        corr_fname: path for the correct-predictions CSV.
        incorr_fname: path for the incorrect-predictions CSV.

    Context managers guarantee both files are closed even if a write
    raises (the original leaked both handles on error).
    """
    with open(corr_fname, 'w') as fcorr, open(incorr_fname, 'w') as fincorr:
        for pred, true, img, q in zip(logits, targets, imgs, ques):
            line = str(img) + ',' + q + ',' + str(pred) + '\n'
            # Route the row to whichever file matches its correctness.
            (fcorr if pred == true else fincorr).write(line)
def test(model, config, dataloader, vocab, test=True, gen=False):
    """Evaluate *model* over *dataloader* and write an accuracy report.

    Three modes, selected by the flags (``gen`` takes precedence):
      * gen=True  -- batches are (images, questions, answers, image_ids,
        questions_orig); accuracy is computed and per-sample correct /
        incorrect predictions are dumped to two CSV files.
      * test=True -- batches are (images, questions) only; the model is
        run but outputs are not scored or saved.
      * otherwise -- batches are (images, questions, answers); accuracy
        is computed, printed, and written to a text report.

    NOTE(review): the parameter ``test`` shadows this function's own
    name; renaming it would break keyword callers, so it is left as-is.

    Args:
        model: network to evaluate; moved to the configured device.
        config: dict with at least 'name' and 'GPU' (and 'GPU_Ids' when
            'GPU' is True).
        dataloader: iterable of batches; tuple shape depends on mode.
        vocab: vocabulary object passed through to forward_pass.
        test: when True (and gen is False), run the unlabeled test mode.
        gen: when True, also generate per-sample CSV result files.
    """
    eval_path = './evaluations/'
    create_directory(eval_path)
    model_name = config['name']
    ts = get_time_prefix()
    # Timestamped report path, e.g. ./evaluations/<time><name>.txt
    eval_fname = eval_path + ts + model_name + '.txt'
    device = torch.device("cuda:0" if config['GPU'] is True else "cpu")
    model.to(device)
    if config['GPU'] is True:
        # Wrap for multi-GPU inference across the configured device ids.
        model = nn.DataParallel(model, device_ids=config['GPU_Ids'])
    model.eval()
    # Freeze parameters for evaluation; gradients are never needed here.
    # NOTE(review): torch.no_grad() around the loops would also skip
    # graph construction for activations -- consider adding it.
    for param in model.parameters():
        param.requires_grad = False
    logits = []
    targets = []
    imgs = []
    ques = []
    if gen is True:
        # Generation mode: score AND dump per-sample CSV results.
        corr_fname = eval_path + ts + model_name + '_corr.csv'
        incorr_fname = eval_path + ts + model_name + '_incorr.csv'
        for images, questions, answers, image_ids, questions_orig in dataloader:
            out = forward_pass(model, config, images, questions, vocab, train=False)
            # Predicted class = argmax over the class dimension.
            preds = torch.argmax(out, dim=1).tolist()
            logits = logits + preds
            targets = targets + answers.tolist()
            imgs = imgs + image_ids.tolist()
            ques = ques + questions_orig
        save_results(logits, targets, imgs, ques, corr_fname, incorr_fname)
        score = get_accuracy(logits, targets)
        print("Accuracy : ", score)
        write_score(eval_fname, score)
    elif test is True:
        # Unlabeled test split: run the model; outputs are discarded.
        # NOTE(review): predictions are not collected or saved in this
        # branch -- confirm whether that is intentional.
        for images, questions in dataloader:
            out = forward_pass(model, config, images, questions, vocab, train=False)
    else:
        # Labeled validation: accumulate predictions, then score.
        for images, questions, answers in dataloader:
            out = forward_pass(model, config, images, questions, vocab, train=False)
            preds = torch.argmax(out, dim=1).tolist()
            logits = logits + preds
            targets = targets + answers.tolist()
        score = get_accuracy(logits, targets)
        print("Accuracy : ", score)
        write_score(eval_fname, score)