-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdecode.py
136 lines (106 loc) · 4.34 KB
/
decode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
pytorch-dl
Created by raj at 6:59 AM, 7/30/20
"""
import os
import sys
import time
from math import inf
from torch import nn
from dataset.iwslt_data import rebatch, rebatch_onmt, SimpleLossCompute, NoamOpt, LabelSmoothing
from models.decoding import beam_search, batched_beam_search, greedy_decode
from models.transformer import TransformerEncoderDecoder
from models.utils.model_utils import load_model_state, save_state, get_perplexity
"""Train models."""
import torch
import onmt.opts as opts
from onmt.utils.misc import set_random_seed
from onmt.utils.parse import ArgumentParser
from onmt.inputters.inputter import build_dataset_iter
def _resolve_start_symbol(trg_vocab, tgt_lang_id):
    """Return the target-vocab index used to seed decoding.

    Uses ``<s>`` when the target vocab has it; otherwise (e.g. an mBART
    fine-tuned model) uses the language tag ``<tgt_lang_id>``.

    Raises:
        AssertionError: if ``<s>`` is missing and ``tgt_lang_id`` is not set,
            or if the requested language tag is not in the target vocab.
    """
    # Compare against the *target* vocab's <unk>: stoi presumably maps missing
    # tokens to <unk>, and the source vocab's <unk> index need not match.
    unk_idx = trg_vocab.stoi["<unk>"]
    start_symbol = trg_vocab.stoi["<s>"]
    if start_symbol != unk_idx:
        return start_symbol
    if not tgt_lang_id:
        raise AssertionError("For mBart fine-tuned model, --tgt_lang_id is necessary to set. eg DE EN etc.")
    lang_tag = "<" + tgt_lang_id + ">"
    lang_symbol = trg_vocab.stoi[lang_tag]
    if lang_symbol == unk_idx:
        # Without this check decoding would silently start from <unk>.
        raise AssertionError("Language tag {} not found in the target vocabulary.".format(lang_tag))
    return lang_symbol


def _ids_to_sentence(ids, itos, start=0, eos="</s>"):
    """Join ``itos[id]`` for ``ids[start:]`` with spaces, stopping before ``eos``."""
    tokens = []
    for tok_id in ids[start:]:
        sym = itos[tok_id]
        if sym == eos:
            break
        tokens.append(sym)
    return ' '.join(tokens)


def decode(opt, hyp_path='valid-beam-decode-test.de-en.en',
           ref_path='valid-ref.de-en.en'):
    """Decode the validation split with the best checkpoint and write results.

    Loads ``checkpoints_best.pt`` from ``opt.save_model`` and runs
    ``batched_beam_search`` on each batch of the "valid" split. It then writes
    the space-joined hypotheses and references, one sentence per line, to
    ``hyp_path`` and ``ref_path``.

    Args:
        opt: parsed command-line options (see ``_get_parser``). Side effect:
            ``opt.valid_batch_size`` is overwritten with 1.
        hyp_path: output file for the model translations.
        ref_path: output file for the reference target sentences.

    Raises:
        AssertionError: if no usable start symbol can be determined
            (see ``_resolve_start_symbol``).
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)
    set_random_seed(opt.seed, False)
    # For decoding we don't have batched beam search over several sentences yet.
    opt.valid_batch_size = 1

    model_dir = opt.save_model
    os.makedirs(model_dir, exist_ok=True)
    # Pass the parsed options ``opt`` -- not the ``onmt.opts`` module.
    _, model, fields = load_model_state(os.path.join(model_dir, 'checkpoints_best.pt'), opt,
                                        data_parallel=False)
    model.eval()

    src_vocab = fields['src'].base_field.vocab
    trg_vocab = fields['tgt'].base_field.vocab
    # NOTE(review): pad_idx comes from the source vocab but is also passed as the
    # target-side pad_symbol; this assumes both vocabs index <blank> alike -- confirm.
    pad_idx = src_vocab.stoi["<blank>"]
    start_symbol = _resolve_start_symbol(trg_vocab, getattr(opt, 'tgt_lang_id', None))

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)
    cuda_condition = torch.cuda.is_available() and opt.gpu_ranks
    device = torch.device("cuda:0" if cuda_condition else "cpu")
    if cuda_condition:
        model.cuda()

    translated = []
    reference = []
    start = time.time()
    with torch.no_grad():
        batches = (rebatch_onmt(pad_idx, b, device=device) for b in valid_iter)
        for k, batch in enumerate(batches):
            print('Processing: {0}'.format(k))
            # NOTE(review): the length cap uses batch.ntokens, which presumably
            # counts reference target tokens -- confirm this is intended.
            out = batched_beam_search(model, batch.src, batch.src_mask,
                                      start_symbol=start_symbol, pad_symbol=pad_idx,
                                      max=batch.ntokens + 10)
            # Beam-search output is read from position 0 (greedy_decode would need 1
            # to drop its start symbol); the reference skips position 0
            # (presumably its <s> token).
            translated.append(_ids_to_sentence(out[0].tolist(), trg_vocab.itos, start=0))
            reference.append(_ids_to_sentence(batch.trg[0].tolist(), trg_vocab.itos, start=1))

    with open(hyp_path, 'w', encoding='utf-8') as outfile:
        outfile.write('\n'.join(translated))
    with open(ref_path, 'w', encoding='utf-8') as outfile:
        outfile.write('\n'.join(reference))
    print('Time elapsed:{}'.format(time.time() - start))
def _get_parser():
    """Build the command-line parser for this decoding script.

    Registers OpenNMT's config, model and train option groups; ``decode``
    validates the parsed result with the train/model validators.
    """
    # The description previously said 'train.py' (copied from the training script).
    parser = ArgumentParser(description='decode.py')
    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    return parser
def main():
    """Command-line entry point: parse the options, then run ``decode``."""
    decode(_get_parser().parse_args())


if __name__ == "__main__":
    # Only run when executed as a script, not when imported.
    main()