train_cluener.py
import torch
# transformers.AdamW is deprecated (and removed in newer releases); use the PyTorch optimizer instead
from torch.optim import AdamW
from transformers import BertTokenizerFast
from tqdm import tqdm

import utils.utils_cluener as C
import utils.utils_Generic as G
from utils.utils_training_testing import eval_cluener
from nets.Bert_BiLSTM_CRF_Combined import CombinedNER

if __name__ == '__main__':
    # ----------------------------------------------------#
    #   data_path:               Path to the training set
    #   test_path:               Path to the test set
    #   val_ratio:               Fraction of the training data held out for validation
    #   pretrained_model_name:   Which pretrained BERT model to use
    #   use_bilstm:              Whether to add a BiLSTM layer on top of BERT
    #   test_performance:        Whether to evaluate on the test set after training
    # ----------------------------------------------------#
    data_path = 'dataset/cluener/train.json'
    test_path = 'dataset/cluener/dev.json'
    val_ratio = 0.15
    pretrained_model_name = 'All_Bert_Pretrained_Models/bert-base-chinese'
    use_bilstm = True
    test_performance = False
    # ----------------------------------------------------#
    #   Training parameters
    #   epoch_num    Number of training epochs
    #   batch_size   Batch size
    #   Important: the pretrained BERT encoder and the newly initialized
    #   layers (BiLSTM / classifier / CRF) should use different learning rates.
    #   lr_bert      Learning rate for the BERT encoder
    #   lr_other     Learning rate for the other layers
    # ----------------------------------------------------#
    epoch_num = 1
    batch_size = 2
    lr_bert = 1e-5
    lr_other = 1e-4
    # ----------------------------------------------------#
    #   Read in the data and get the dataloaders
    # ----------------------------------------------------#
    train_data, test_data, val_data, label_list, categories = C.get_train_test_val(data_path, test_path, val_ratio)
    tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name)
    train_loader = G.get_dataloader(train_data, tokenizer, categories, mode='Train')
    test_loader = G.get_dataloader(test_data, tokenizer, categories, mode='Test')
    val_loader = G.get_dataloader(val_data, tokenizer, categories, mode='Val')
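    # label_list sizes the tag classifier below; categories is shared by the dataloaders and the evaluator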
    # ----------------------------------------------------#
    #   Build the model and move it to the GPU (falls back to CPU)
    # ----------------------------------------------------#
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CombinedNER(len(label_list), bert_type=pretrained_model_name, need_rnn=use_bilstm)
    model.to(device)
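    # CombinedNER wraps BERT, an optional BiLSTM, a linear tag classifier, and a CRF layer
    # (see the parameter groups below)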
    # ----------------------------------------------------#
    #   Set up the optimizer with per-group learning rates
    #   No separate loss function is needed: the CRF layer already
    #   computes the loss inside the model
    # ----------------------------------------------------#
    param_groups = [
        {'params': model.bert.parameters(), 'lr': lr_bert},
        {'params': model.classifier.parameters(), 'lr': lr_other},
        {'params': model.crf.parameters(), 'lr': lr_other},
    ]
    if model.need_rnn:
        param_groups.append({'params': model.rnn.parameters(), 'lr': lr_other, 'weight_decay': 1e-4})
    optimizer = AdamW(param_groups)
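    # Note: only the BiLSTM parameter group above uses weight decay; BERT, classifier and CRF do not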
    # ----------------------------------------------------#
    #   Start training
    # ----------------------------------------------------#
    print('\nStart training!!!\n')
    for epoch in range(epoch_num):
        total_batches = len(train_loader)
        with tqdm(total=total_batches, desc=f'Epoch {epoch + 1}/{epoch_num}', unit='batch') as pbar:
            for data in train_loader:
                model.train()
                tokenized_inputs, targets = data
                tokenized_inputs, targets = tokenized_inputs.to(device), targets.to(device)
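                # Forward pass returns the training loss computed by the CRF layer,
                # so no separate criterion is used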
                loss = model(tokenized_inputs, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update(1)
                pbar.set_postfix(loss=loss.item())
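        # Evaluate on the validation set at the end of each epoch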
        with torch.no_grad():
            model.eval()
            precision, recall, f1 = eval_cluener(val_loader, model, device, categories)
        print(f'Epoch: {epoch + 1:02d}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}')
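        # Save a checkpoint every 5 epochs, named after the validation F1 (scaled by 1000)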
        if (epoch + 1) % 5 == 0:
            sub_path = int(f1 * 1000)
            save_path = f'logs/model_f1_{sub_path}.pth'
            torch.save(model.state_dict(), save_path)
    print('\nFinished Training!!!\n')
    # ----------------------------------------------------#
    #   Optionally evaluate the trained model on the test set
    # ----------------------------------------------------#
    if test_performance:
        with torch.no_grad():
            model.eval()
            precision, recall, f1 = eval_cluener(test_loader, model, device, categories)
        print(f'On the test set:\n Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}')
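    # ----------------------------------------------------#
    #   Sketch (not part of the original script): reloading a saved
    #   checkpoint for later evaluation. The file name below is a
    #   placeholder; use the path written by torch.save(...) above.
    # ----------------------------------------------------#
    # model.load_state_dict(torch.load('logs/model_f1_xxx.pth', map_location=device))
    # model.eval()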