# eval.py -- sentence probability evaluation (from a fork of salesforce/awd-lstm-lm)
import argparse
import pprint

import torch
from torch.autograd import Variable

import data

parser = argparse.ArgumentParser(description='Sentence Probability Evaluation')
parser.add_argument('--data', type=str, default='./data/penn',
                    help='location of the data corpus')
parser.add_argument('--checkpoint', type=str, default='./model.pt',
                    help='model checkpoint to use')
parser.add_argument('--model', type=str, default='LSTM',
                    help='type of recurrent net (LSTM, QRNN)')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--inf', type=str, default='./eval.txt',
                    help='file to evaluate')
args = parser.parse_args()
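
# awd-lstm-lm checkpoints store (model, criterion, optimizer); only the model is needed.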
with open(args.checkpoint, 'rb') as f:
    model = torch.load(f, map_location=lambda storage, loc: storage)[0]
model.eval()
if args.model == 'QRNN':
    model.reset()
if args.cuda:
    model.cuda()
else:
    model.cpu()
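
# Rebuild the corpus dictionary so tokens map to the same ids used in training
# (assumes --data points at the corpus the model was trained on).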
corpus = data.Corpus(args.data)
ntokens = len(corpus.dictionary)
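
# Score with batch size 1, feeding one token at a time. The (1, 1) input tensor
# is refilled in place with each token id; volatile=True is the pre-0.4 PyTorch
# idiom for running inference without autograd.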
hidden = model.init_hidden(1)
input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True)
if args.cuda:
    input.data = input.data.cuda()

all_words = []
# Tokenize file content
with open(args.inf, 'r') as f:
    # TODO - maybe we should count the words first so this can be dynamically
    # allocated; currently this limits the file length to 4096 words.
    ids = torch.LongTensor(4096)
    token = 0
    for line in f:
        words = line.split() + ['<eos>']
        for word in words:
            # Out-of-vocabulary words raise a KeyError here; fall back to
            # '<unk>' (assumes the corpus vocabulary defines it, as PTB does).
            ids[token] = corpus.dictionary.word2idx.get(
                word, corpus.dictionary.word2idx['<unk>'])
            token += 1
            all_words.append(word)
probs = []
for i in range(len(all_words) - 1):
    input.data.fill_(ids[i])
    output, hidden = model(input, hidden)
    word_weights = model.decoder(output).squeeze().cpu()
    softmax_output_flat = torch.nn.functional.softmax(word_weights, dim=0)
    # After feeding word i the model's output distribution is over word i + 1,
    # so index the word that actually comes next; this lines up with the
    # all_words[1:] labels printed below.
    probs.append(softmax_output_flat[ids[i + 1]])
pprint.pprint(list(zip(["%.4f" % float(p) for p in probs], all_words[1:])))
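
# The file-level probability is the product of these next-word probabilities
# (equivalently, the sum of their logs); the first word is unscored because it
# has no preceding context.
#
# Example usage (paths are illustrative; defaults match the flags above):
#   python eval.py --data ./data/penn --checkpoint ./model.pt --inf ./eval.txt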