-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_ori.py
100 lines (73 loc) · 2.92 KB
/
test_ori.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import torch
from torch import nn
from torch.autograd import Variable
from model.BIMPM_new import BIMPM
from model.utils import SNLI, Quora
def test(model, args, data, mode='test'):
    """Evaluate ``model`` on the dev or test split of ``data``.

    Args:
        model: Module called as ``model(p=..., h=...)`` (plus ``char_p`` /
            ``char_h`` when ``args.use_char_emb`` is true) that returns one
            row of class scores per example; the arg-max over dim 1 is taken
            as the prediction.
        args: Namespace read for ``data_type`` ('SNLI' selects the
            ``premise``/``hypothesis`` batch fields, anything else selects
            ``q1``/``q2``), ``use_char_emb``, and ``gpu`` (> -1 moves the
            character tensors to that CUDA device).
        data: Dataset wrapper providing ``dev_iter`` / ``test_iter`` and,
            when character embeddings are used, ``characterize``.
        mode: 'dev' evaluates ``data.dev_iter``; any other value evaluates
            ``data.test_iter``.

    Returns:
        ``(loss, acc)`` where ``loss`` is the SUM over batches of each
        batch's mean cross-entropy (it is not divided by the number of
        batches) and ``acc`` is the fraction of examples whose prediction
        equals ``batch.label``.

    Raises:
        ValueError: if the selected split yields no examples.

    The model is left in eval mode; callers that continue training must
    call ``model.train()`` themselves.
    """
    iterator = data.dev_iter if mode == 'dev' else data.test_iter

    # The batch field names depend only on the dataset, so choose them once.
    if args.data_type == 'SNLI':
        p_field, h_field = 'premise', 'hypothesis'
    else:
        p_field, h_field = 'q1', 'q2'

    criterion = nn.CrossEntropyLoss()
    model.eval()
    loss, correct, size = 0.0, 0, 0
    # eval() only changes dropout/batch-norm behaviour; no_grad() is what stops
    # autograd from recording (and holding memory for) every forward pass.
    with torch.no_grad():
        for batch in iterator:
            s1, s2 = getattr(batch, p_field), getattr(batch, h_field)
            kwargs = {'p': s1, 'h': s2}
            if args.use_char_emb:
                char_p = torch.LongTensor(data.characterize(s1))
                char_h = torch.LongTensor(data.characterize(s2))
                if args.gpu > -1:
                    char_p = char_p.cuda(args.gpu)
                    char_h = char_h.cuda(args.gpu)
                kwargs['char_p'] = char_p
                kwargs['char_h'] = char_h
            logits = model(**kwargs)
            loss += criterion(logits, batch.label).item()
            _, pred = logits.max(dim=1)
            correct += (pred == batch.label).sum().item()
            size += len(pred)

    if size == 0:
        raise ValueError(
            f'evaluation split is empty (mode={mode!r}); cannot compute accuracy')
    return loss, correct / size


# This file's name matches pytest's ``test_*.py`` discovery pattern and the
# function name matches ``test*``; without this flag pytest collects it as a
# test and errors on the missing ``model``/``args``/``data`` fixtures.
test.__test__ = False
def load_model(args, data):
    """Build a BIMPM model and restore its weights from ``args.model_path``.

    Args:
        args: Namespace passed to ``BIMPM`` and read for ``model_path`` and
            ``gpu`` (> -1 moves the model to that CUDA device).
        data: Dataset wrapper passed through to ``BIMPM``.

    Returns:
        The model with the saved state dict loaded; moved to CUDA device
        ``args.gpu`` when ``args.gpu > -1``, otherwise left on whatever
        device ``BIMPM`` built it on.
    """
    model = BIMPM(args, data)
    # map_location='cpu' lets a checkpoint saved on any GPU be restored on a
    # CPU-only machine or under a different device index. load_state_dict
    # copies the values into the model's own parameters, which are moved
    # below when a GPU is requested.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    if args.gpu > -1:
        model.cuda(args.gpu)
    return model
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', default=64, type=int)
parser.add_argument('--char-dim', default=20, type=int)
parser.add_argument('--char-hidden-size', default=50, type=int)
parser.add_argument('--dropout', default=0.1, type=float)
parser.add_argument('--data-type', default='SNLI', help='available: SNLI or Quora')
parser.add_argument('--epoch', default=10, type=int)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--hidden-size', default=100, type=int)
parser.add_argument('--learning-rate', default=0.001, type=float)
parser.add_argument('--num-perspective', default=20, type=int)
parser.add_argument('--use-char-emb', default=True, action='store_true')
parser.add_argument('--word-dim', default=300, type=int)
parser.add_argument('--model-path', required=True)
args = parser.parse_args()
if args.data_type == 'SNLI':
print('loading SNLI data...')
data = SNLI(args)
elif args.data_type == 'Quora':
print('loading Quora data...')
data = Quora(args)
setattr(args, 'char_vocab_size', len(data.char_vocab))
setattr(args, 'word_vocab_size', len(data.TEXT.vocab))
setattr(args, 'class_size', len(data.LABEL.vocab))
setattr(args, 'max_word_len', data.max_word_len)
print('loading model...')
model = load_model(args, data)
_, acc = test(model, args, data)
print(f'test acc: {acc:.3f}')