forked from OpenNMT/OpenNMT-py
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_biLSTM.py
105 lines (84 loc) · 3.05 KB
/
train_biLSTM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import argparse
import torch
import torch.nn as nn
from torch import cuda
import onmt
import onmt.Models
import onmt.ModelConstructor
import onmt.modules
from onmt.Utils import aeq, use_gpu
import opts
# Command-line setup: register the model and training option groups on one
# parser and parse them into the module-level ``opt`` namespace that the
# functions below read.
parser = argparse.ArgumentParser(description='train_biLSTM.py')
opts.model_opts(parser)
opts.train_opts(parser)
opt = parser.parse_args()

if opt.gpuid:
    # GPU ids were given: make the first one the current CUDA device and,
    # when a positive seed was requested, seed the CUDA RNG with it.
    cuda.set_device(opt.gpuid[0])
    if opt.seed > 0:
        cuda.manual_seed(opt.seed)
elif cuda.is_available():
    # No GPU ids were given although CUDA is usable: tell the user.
    print("WARNING: You have a CUDA device, should run with -gpuid 0")
def load_fields(train, checkpoint):
    """Attach vocabulary fields to ``train`` and return the fields to use.

    The fields stored at ``opt.data + '.vocab.pt'`` are loaded, restricted to
    the names that appear as attributes of the first training example, and
    assigned to ``train.fields``.  When ``opt.train_from`` is set, the value
    returned is instead rebuilt from ``checkpoint['vocab']``; ``train.fields``
    keeps the fields loaded from the data file.

    Args:
        train: training dataset; its ``examples[0]`` is inspected and its
            ``fields`` attribute is overwritten.
        checkpoint: loaded checkpoint; only indexed (``['vocab']``) when
            ``opt.train_from`` is set.

    Returns:
        The filtered data-file fields, or the checkpoint fields when
        resuming from a checkpoint.
    """
    saved_vocab = torch.load(opt.data + '.vocab.pt')
    all_fields = onmt.IO.ONMTDataset.load_fields(saved_vocab)

    # Keep only the fields that the training examples actually carry.
    fields = {
        name: field
        for name, field in all_fields.items()
        if name in train.examples[0].__dict__
    }
    train.fields = fields

    if not opt.train_from:
        return fields

    print('Loading vocab from checkpoint at %s.' % opt.train_from)
    return onmt.IO.ONMTDataset.load_fields(checkpoint['vocab'])
def build_model(model_opt, opt, fields, checkpoint):
    """Build the model, optionally wrap it for multi-GPU use, and print it.

    Note that the ``opt`` parameter shadows the module-level ``opt`` inside
    this function.

    Args:
        model_opt: options forwarded to
            ``onmt.ModelConstructor.make_base_model``.
        opt: run options; ``opt.gpuid`` is read here and ``opt`` is passed
            to ``use_gpu``.
        fields: fields forwarded to the model constructor.
        checkpoint: checkpoint (or ``None``) forwarded to the model
            constructor.

    Returns:
        The constructed model, wrapped in ``nn.DataParallel`` when more than
        one GPU id is configured.
    """
    print('Building model...')
    on_gpu = use_gpu(opt)
    model = onmt.ModelConstructor.make_base_model(
        model_opt, fields, on_gpu, checkpoint)

    gpu_ids = opt.gpuid
    if len(gpu_ids) > 1:
        print('Multi gpu training ', gpu_ids)
        # Inputs are scattered along dimension 1 (presumably the batch
        # axis -- TODO confirm against the model's input layout).
        model = nn.DataParallel(model, device_ids=gpu_ids, dim=1)

    print(model)
    return model
def build_optim(model, checkpoint):
    """Return the optimizer wrapper for ``model``.

    Without ``opt.train_from`` a new ``onmt.Optim`` is created from the
    command-line options.  With it, the wrapper stored under
    ``checkpoint['optim']`` is reused and its inner optimizer's state dict is
    loaded back into that inner optimizer.  In both cases the wrapper is then
    given ``model.parameters()``.

    Args:
        model: model whose parameters the optimizer will manage.
        checkpoint: loaded checkpoint; only indexed (``['optim']``) when
            ``opt.train_from`` is set.

    Returns:
        The optimizer wrapper with parameters set.
    """
    if not opt.train_from:
        # what members of opt does Optim need?
        optim = onmt.Optim(
            opt.optim,
            opt.learning_rate,
            opt.max_grad_norm,
            lr_decay=opt.learning_rate_decay,
            start_decay_at=opt.start_decay_at,
            opt=opt,
        )
    else:
        print('Loading optimizer from checkpoint.')
        optim = checkpoint['optim']
        saved_state = checkpoint['optim'].optimizer.state_dict()
        optim.optimizer.load_state_dict(saved_state)

    optim.set_parameters(model.parameters())
    return optim
def main():
    """Load the training data, build the model and optimizer, and train.

    Steps, in order: load ``opt.data + '.train.pt'``; when ``opt.train_from``
    is set, load that checkpoint, take the model options from it and set
    ``opt.start_epoch`` to the checkpoint epoch plus one; load the fields;
    build the model; build the optimizer; start training.

    Reads the module-level ``opt`` namespace and mutates ``opt.start_epoch``
    when resuming.
    """
    # Load Monolingual Training Data
    # The format operator must be applied to the string, not to the value
    # returned by print() (which is None on Python 3 and raises TypeError).
    print("Loading train data from '%s'" % opt.data)
    train = torch.load(opt.data + '.train.pt')
    print(' * number of training sentences: %d' % len(train))
    print(' * maximum batch size: %d' % opt.batch_size)

    if opt.train_from:
        print('Loading checkpoint from %s' % opt.train_from)
        # map_location hands each storage back unchanged instead of moving
        # it to the device recorded in the checkpoint.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        # I don't like reassigning attributes of opt: it's not clear
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    # To Add Field to Training data
    fields = load_fields(train, checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    # NOTE(review): `tally_parameters`, `check_save_model_path`,
    # `train_model` and `valid` are not defined in the visible part of this
    # file -- confirm they are provided elsewhere before running.
    tally_parameters(model)
    check_save_model_path()

    # Build optimizer.
    optim = build_optim(model, checkpoint)

    # Do training.
    train_model(model, train, valid, fields, optim)
# Start training only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()