-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathmain.py
155 lines (125 loc) · 5.34 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import json
import time
import argparse
import torch
from typing import Dict, List, Tuple, Set, Optional
from prefetch_generator import BackgroundGenerator
from tqdm import tqdm
from torch.optim import Adam, SGD
from pytorch_transformers import AdamW, WarmupLinearSchedule
from lib.preprocessings import Chinese_selection_preprocessing, Conll_selection_preprocessing, Conll_bert_preprocessing
from lib.dataloaders import Selection_Dataset, Selection_loader
from lib.metrics import F1_triplet, F1_ner
from lib.models import MultiHeadSelection
from lib.config import Hyper
# Command-line interface.
# NOTE: parse_args() with no arguments reads sys.argv at import time, so
# importing this module from another script also parses that script's argv.
parser = argparse.ArgumentParser()
parser.add_argument('--exp_name',
                    '-e',
                    type=str,
                    default='conll_bert_re',
                    help='experiments/exp_name.json')
# `choices` rejects an unknown mode immediately with a usage message, instead
# of only failing in Runner.run() after the experiment config has been loaded.
parser.add_argument('--mode',
                    '-m',
                    type=str,
                    default='preprocessing',
                    choices=('preprocessing', 'train', 'evaluation'),
                    help='preprocessing|train|evaluation')
args = parser.parse_args()
class Runner(object):
    """Drive one experiment: preprocessing, training, or dev-set evaluation.

    The experiment is configured by ``experiments/<exp_name>.json`` (loaded
    through ``Hyper``). Model checkpoints are written to and read from
    ``saved_models/<exp_name>_<epoch>``.
    """

    def __init__(self, exp_name: str):
        self.exp_name = exp_name
        self.model_dir = 'saved_models'
        self.hyper = Hyper(os.path.join('experiments',
                                        self.exp_name + '.json'))
        # GPU index the model is moved to in _init_model().
        self.gpu = self.hyper.gpu
        self.preprocessor = None
        self.triplet_metrics = F1_triplet()
        self.ner_metrics = F1_ner()
        # Filled in by run() depending on the selected mode.
        self.optimizer = None
        self.model = None

    def _optimizer(self, name, model):
        """Build the optimizer called *name* over ``model.parameters()``.

        Supported names: ``'adam'`` and ``'adamw'`` (library default
        hyper-parameters) and ``'sgd'`` (fixed ``lr=0.5``). Only the requested
        optimizer is constructed.

        Raises:
            KeyError: if *name* is not one of the supported names.
        """
        # Factories rather than instances, so unused optimizers are never built.
        factories = {
            'adam': lambda params: Adam(params),
            'sgd': lambda params: SGD(params, lr=0.5),
            'adamw': lambda params: AdamW(params),
        }
        try:
            factory = factories[name]
        except KeyError:
            raise KeyError('unknown optimizer %r; expected one of %s'
                           % (name, ', '.join(sorted(factories)))) from None
        return factory(model.parameters())

    def _init_model(self):
        """Instantiate ``MultiHeadSelection`` and move it to GPU ``self.gpu``."""
        self.model = MultiHeadSelection(self.hyper).cuda(self.gpu)

    def preprocessing(self):
        """Run the dataset-specific preprocessor selected by ``exp_name``.

        Generates the relation vocabulary, the processed data, the word
        vocabulary (``min_freq=1``) and the BIO tag vocabulary.

        Raises:
            ValueError: if ``exp_name`` has no associated preprocessor.
        """
        if self.exp_name == 'conll_selection_re':
            self.preprocessor = Conll_selection_preprocessing(self.hyper)
        elif self.exp_name == 'chinese_selection_re':
            self.preprocessor = Chinese_selection_preprocessing(self.hyper)
        elif self.exp_name == 'conll_bert_re':
            self.preprocessor = Conll_bert_preprocessing(self.hyper)
        else:
            # Previously fell through with preprocessor=None and crashed with
            # an opaque AttributeError on the next line.
            raise ValueError('no preprocessor for experiment %r' % self.exp_name)
        self.preprocessor.gen_relation_vocab()
        self.preprocessor.gen_all_data()
        self.preprocessor.gen_vocab(min_freq=1)
        # for ner only
        self.preprocessor.gen_bio_vocab()

    def run(self, mode: str):
        """Dispatch to preprocessing, training or evaluation.

        Raises:
            ValueError: if *mode* is not one of
                ``'preprocessing'``, ``'train'``, ``'evaluation'``.
        """
        if mode == 'preprocessing':
            self.preprocessing()
        elif mode == 'train':
            self._init_model()
            self.optimizer = self._optimizer(self.hyper.optimizer, self.model)
            self.train()
        elif mode == 'evaluation':
            self._init_model()
            self.load_model(epoch=self.hyper.evaluation_epoch)
            self.evaluation()
        else:
            raise ValueError('invalid mode')

    def _model_path(self, epoch: int) -> str:
        """Return the checkpoint path ``<model_dir>/<exp_name>_<epoch>``."""
        return os.path.join(self.model_dir, self.exp_name + '_' + str(epoch))

    def load_model(self, epoch: int):
        """Load the checkpoint saved for *epoch* into ``self.model``."""
        # Load tensors onto the CPU first: a checkpoint saved on a GPU index
        # that does not exist now would otherwise fail to deserialize.
        # load_state_dict then copies them onto the model's own device.
        state_dict = torch.load(self._model_path(epoch), map_location='cpu')
        self.model.load_state_dict(state_dict)

    def save_model(self, epoch: int):
        """Save ``self.model``'s state dict as the checkpoint for *epoch*."""
        # exist_ok avoids the exists()/mkdir() race and also creates parents.
        os.makedirs(self.model_dir, exist_ok=True)
        torch.save(self.model.state_dict(), self._model_path(epoch))

    @staticmethod
    def _format_metrics(result):
        """Render a metric dict as ``"<initial>: <value>"`` pairs.

        Only the first character of each metric name is shown; keys starting
        with ``_`` are omitted.
        """
        return ', '.join("%s: %.4f" % (name[0], value)
                         for name, value in result.items()
                         if not name.startswith("_"))

    def evaluation(self):
        """Evaluate ``self.model`` on the dev set and print triplet/NER scores."""
        dev_set = Selection_Dataset(self.hyper, self.hyper.dev)
        loader = Selection_loader(dev_set, batch_size=self.hyper.eval_batch, pin_memory=True)
        # Reset both metric accumulators: evaluation() is called repeatedly from
        # train(), and stale NER counts would otherwise leak between dev passes.
        # NOTE(review): assumes F1_ner exposes reset() like F1_triplet — confirm.
        self.triplet_metrics.reset()
        self.ner_metrics.reset()
        self.model.eval()

        pbar = tqdm(BackgroundGenerator(loader), total=len(loader))

        with torch.no_grad():
            for sample in pbar:
                output = self.model(sample, is_train=False)
                self.triplet_metrics(output['selection_triplets'], output['spo_gold'])
                # NOTE(review): gold is passed first here, unlike the triplet
                # call above — confirm against F1_ner's parameter order.
                self.ner_metrics(output['gold_tags'], output['decoded_tag'])

        triplet_result = self.triplet_metrics.get_metric()
        ner_result = self.ner_metrics.get_metric()
        print('Triplets-> ' + self._format_metrics(triplet_result)
              + ' ||' + 'NER->' + self._format_metrics(ner_result))

    def train(self):
        """Train for ``hyper.epoch_num`` epochs, checkpointing after each one."""
        train_set = Selection_Dataset(self.hyper, self.hyper.train)
        loader = Selection_loader(train_set, batch_size=self.hyper.train_batch, pin_memory=True)

        for epoch in range(self.hyper.epoch_num):
            # evaluation() puts the model in eval mode, so re-enter train mode
            # at the start of every epoch.
            self.model.train()
            pbar = tqdm(BackgroundGenerator(loader), total=len(loader))

            for sample in pbar:
                self.optimizer.zero_grad()
                output = self.model(sample, is_train=True)
                loss = output['loss']
                loss.backward()
                self.optimizer.step()

                pbar.set_description(output['description'](
                    epoch, self.hyper.epoch_num))

            self.save_model(epoch)

            # Evaluate on dev every `print_epoch` epochs, skipping epochs 0-3.
            if epoch % self.hyper.print_epoch == 0 and epoch > 3:
                self.evaluation()
if __name__ == "__main__":
    # Script entry point: build the experiment runner from the parsed CLI
    # arguments and dispatch the requested mode.
    runner = Runner(exp_name=args.exp_name)
    runner.run(mode=args.mode)