-
Notifications
You must be signed in to change notification settings - Fork 139
/
main.py
187 lines (169 loc) · 7.43 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from __future__ import division
from __future__ import print_function
import os
import random
import logging
import torch
import torch.nn as nn
import torch.optim as optim
# IMPORT CONSTANTS
from treelstm import Constants
# NEURAL NETWORK MODULES/LAYERS
from treelstm import SimilarityTreeLSTM
# DATA HANDLING CLASSES
from treelstm import Vocab
# DATASET CLASS FOR SICK DATASET
from treelstm import SICKDataset
# METRICS CLASS FOR EVALUATION
from treelstm import Metrics
# UTILITY FUNCTIONS
from treelstm import utils
# TRAIN AND TEST HELPER FUNCTIONS
from treelstm import Trainer
# CONFIG PARSER
from config import parse_args
# MAIN BLOCK
def main():
    """Train a SimilarityTreeLSTM on the SICK splits and checkpoint the best epoch.

    Steps, in order:
      1. Parse command-line arguments and set up file + console logging.
      2. Validate arguments, select the device and seed the RNGs.
      3. Build (or reuse) the vocabulary file, the cached dataset splits and
         the cached embedding matrix under ``args.data``.
      4. Construct the model, loss and optimizer.
      5. For every epoch: train, evaluate on train/dev/test, log loss,
         Pearson and MSE, and save a checkpoint whenever the test Pearson
         improves.

    Raises:
        SystemExit: if ``args.sparse`` is set together with a non-zero
            ``args.wd``.
        ValueError: if ``args.optim`` is not one of ``adam``, ``adagrad``
            or ``sgd``.
    """
    # NOTE(review): `args` is published as a module-level global; kept as-is
    # in case other code reads it.
    global args
    args = parse_args()

    # The log file lives inside args.save, so the directory has to exist
    # before the FileHandler opens it.
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    logger = _configure_logger(os.path.join(args.save, args.expname) + '.log')

    # argument validation
    args.cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")
    if args.sparse and args.wd != 0:
        logger.error('Sparsity and weight decay are incompatible, pick one!')
        # Exit with a non-zero status so callers/shell scripts see the failure.
        raise SystemExit(1)
    logger.debug(args)

    # seed every RNG in use for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data, 'train/')
    dev_dir = os.path.join(args.data, 'dev/')
    test_dir = os.path.join(args.data, 'test/')

    # write unique words from all token files (only if not already written)
    sick_vocab_file = os.path.join(args.data, 'sick.vocab')
    if not os.path.isfile(sick_vocab_file):
        split_dirs = [train_dir, dev_dir, test_dir]
        # all 'a.toks' files first, then all 'b.toks' files
        token_files = [os.path.join(split, toks)
                       for toks in ('a.toks', 'b.toks')
                       for split in split_dirs]
        utils.build_vocab(token_files, sick_vocab_file)

    # get vocab object from vocab file previously written
    special_words = [Constants.PAD_WORD, Constants.UNK_WORD,
                     Constants.BOS_WORD, Constants.EOS_WORD]
    vocab = Vocab(filename=sick_vocab_file, data=special_words)
    logger.debug('==> SICK vocabulary size : %d ', vocab.size())

    # load SICK dataset splits (cached as .pth files next to the data)
    train_dataset = _load_or_build_dataset(
        os.path.join(args.data, 'sick_train.pth'), train_dir, vocab, args.num_classes)
    logger.debug('==> Size of train data : %d ', len(train_dataset))
    dev_dataset = _load_or_build_dataset(
        os.path.join(args.data, 'sick_dev.pth'), dev_dir, vocab, args.num_classes)
    logger.debug('==> Size of dev data : %d ', len(dev_dataset))
    test_dataset = _load_or_build_dataset(
        os.path.join(args.data, 'sick_test.pth'), test_dir, vocab, args.num_classes)
    logger.debug('==> Size of test data : %d ', len(test_dataset))

    # initialize model, criterion/loss_function, optimizer
    model = SimilarityTreeLSTM(
        vocab.size(),
        args.input_dim,
        args.mem_dim,
        args.hidden_dim,
        args.num_classes,
        args.sparse,
        args.freeze_embed)
    criterion = nn.KLDivLoss()

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb = _load_or_build_embeddings(
        os.path.join(args.data, 'sick_embed.pth'), args.glove,
        vocab, special_words, device, logger)
    # plug these into embedding matrix inside model
    model.emb.weight.data.copy_(emb)

    model.to(device)
    criterion.to(device)
    optimizer = _build_optimizer(args.optim, model, args.lr, args.wd)

    metrics = Metrics(args.num_classes)

    # create trainer object for training and testing
    trainer = Trainer(args, model, criterion, optimizer, device)

    best = -float('inf')
    for epoch in range(args.epochs):
        # The loss returned by train() is superseded by the loss measured
        # when the training split is re-evaluated just below.
        trainer.train(train_dataset)
        train_loss, train_pred = trainer.test(train_dataset)
        dev_loss, dev_pred = trainer.test(dev_dataset)
        test_loss, test_pred = trainer.test(test_dataset)

        train_pearson = metrics.pearson(train_pred, train_dataset.labels)
        train_mse = metrics.mse(train_pred, train_dataset.labels)
        logger.info('==> Epoch {}, Train \tLoss: {}\tPearson: {}\tMSE: {}'.format(
            epoch, train_loss, train_pearson, train_mse))
        dev_pearson = metrics.pearson(dev_pred, dev_dataset.labels)
        dev_mse = metrics.mse(dev_pred, dev_dataset.labels)
        logger.info('==> Epoch {}, Dev \tLoss: {}\tPearson: {}\tMSE: {}'.format(
            epoch, dev_loss, dev_pearson, dev_mse))
        test_pearson = metrics.pearson(test_pred, test_dataset.labels)
        test_mse = metrics.mse(test_pred, test_dataset.labels)
        logger.info('==> Epoch {}, Test \tLoss: {}\tPearson: {}\tMSE: {}'.format(
            epoch, test_loss, test_pearson, test_mse))

        # NOTE(review): the best checkpoint is selected on *test* Pearson
        # rather than dev Pearson -- confirm this is intended.
        if best < test_pearson:
            best = test_pearson
            checkpoint = {
                'model': trainer.model.state_dict(),
                # NOTE(review): stores the optimizer object itself, not its
                # state_dict(); left unchanged so existing checkpoint
                # consumers keep working.
                'optim': trainer.optimizer,
                'pearson': test_pearson, 'mse': test_mse,
                'args': args, 'epoch': epoch
            }
            logger.debug('==> New optimum found, checkpointing everything now...')
            torch.save(checkpoint, '%s.pt' % os.path.join(args.save, args.expname))


def _configure_logger(log_path):
    """Return this module's logger with a file handler (INFO) and a console handler (DEBUG)."""
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s")
    # file logger; mode='w' truncates the log of any previous run
    fh = logging.FileHandler(log_path, mode='w')
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    # console logger
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger


def _load_or_build_dataset(cache_file, split_dir, vocab, num_classes):
    """Return the dataset cached at cache_file, building and caching it from split_dir on a miss."""
    if os.path.isfile(cache_file):
        # NOTE(review): torch.load unpickles arbitrary objects; only point this
        # at cache files written by this script. Recent PyTorch releases may
        # default to weights_only=True, which would reject a pickled dataset
        # object -- confirm against the installed version.
        return torch.load(cache_file)
    dataset = SICKDataset(split_dir, vocab, num_classes)
    torch.save(dataset, cache_file)
    return dataset


def _load_or_build_embeddings(emb_file, glove_dir, vocab, special_words, device, logger):
    """Return the embedding matrix cached at emb_file, building it from GloVe on a miss."""
    if os.path.isfile(emb_file):
        # map_location lets a cache written on a CUDA machine load on CPU.
        return torch.load(emb_file, map_location=device)

    # load glove embeddings and vocab
    glove_vocab, glove_emb = utils.load_word_vectors(
        os.path.join(glove_dir, 'glove.840B.300d'))
    logger.debug('==> GLOVE vocabulary size: %d ', glove_vocab.size())
    emb = torch.zeros(vocab.size(), glove_emb.size(1), dtype=torch.float, device=device)
    emb.normal_(0, 0.05)
    # zero out the embeddings for padding and other special words
    # (presumably Vocab assigns them the first indices -- confirm against Vocab)
    for idx in range(len(special_words)):
        emb[idx].zero_()
    for word in vocab.labelToIdx.keys():
        glove_idx = glove_vocab.getIndex(word)
        # Compare against None rather than relying on truthiness: index 0 is
        # a valid GloVe row and must not be skipped.
        # NOTE(review): assumes getIndex returns None for unknown words.
        if glove_idx is not None:
            emb[vocab.getIndex(word)] = glove_emb[glove_idx]
    torch.save(emb, emb_file)
    return emb


def _build_optimizer(name, model, lr, weight_decay):
    """Return the optimizer called name over the model's trainable parameters."""
    optimizer_classes = {
        'adam': optim.Adam,
        'adagrad': optim.Adagrad,
        'sgd': optim.SGD,
    }
    if name not in optimizer_classes:
        # Fail here with a clear message instead of a NameError later on.
        raise ValueError('Unsupported optimizer: {!r} (expected one of {})'.format(
            name, ', '.join(sorted(optimizer_classes))))
    trainable = [p for p in model.parameters() if p.requires_grad]
    return optimizer_classes[name](trainable, lr=lr, weight_decay=weight_decay)
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()