train_t5.py
import time
import os
import math
import torch
import pickle
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from dataloader import MainDataLoader
from nlgevaluation import compute_bleu
from transformers.optimization import AdamW, get_scheduler
from transformers.trainer_pt_utils import get_parameter_names
from models import ERGModel, ERGMainModel

def configure_dataloaders(batch_size):
    "Prepare dataloaders"
    train_loader = MainDataLoader("data/empathetic_dialogues/train_dpr.csv", batch_size, shuffle=True)
    valid_loader = MainDataLoader("data/empathetic_dialogues/valid_dpr.csv", batch_size, shuffle=False)
    test_loader = MainDataLoader("data/empathetic_dialogues/test_dpr.csv", batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader
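
# Note: MainDataLoader is expected to yield batches of
# (conv_id, emotion, context, response, exemplars, empathy1_labels,
#  empathy2_labels, empathy3_labels, sentiment), matching the unpacking
# in train_or_eval_model and test_model below.
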
def configure_optimizer(model, args):
    "Prepare optimizer"
    decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
            "weight_decay": 0.0,
        },
    ]
    optimizer_kwargs = {
        "betas": (args.adam_beta1, args.adam_beta2),
        "eps": args.adam_epsilon,
        "lr": args.lr,
    }
    optimizer = AdamW(optimizer_grouped_parameters, **optimizer_kwargs)
    return optimizer

def configure_scheduler(optimizer, num_training_steps, args):
    "Prepare scheduler"
    warmup_steps = (
        args.warmup_steps
        if args.warmup_steps > 0
        else math.ceil(num_training_steps * args.warmup_ratio)
    )
    lr_scheduler = get_scheduler(
        args.lr_scheduler_type,
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
    )
    return lr_scheduler
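
# Note: configure_scheduler is defined here but not invoked in the main block
# below. A minimal sketch of wiring it in (an assumption, not part of the
# original training flow) would be:
#   num_training_steps = n_epochs * len(train_loader)  # assumes MainDataLoader supports len()
#   lr_scheduler = configure_scheduler(optimizer, num_training_steps, args)
#   ... then call lr_scheduler.step() after each optimizer.step().
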
def train_or_eval_model(model, dataloader, optimizer=None, train=False, main_loss_w=1, empathy_loss_w=1, sentiment_loss_w=1):
    losses = []
    assert not train or optimizer is not None
    if train:
        model.train()
    else:
        model.eval()

    for conv_id, emotion, context, response, exemplars, empathy1_labels, empathy2_labels, empathy3_labels, sentiment in tqdm(dataloader, leave=False):
        if train:
            optimizer.zero_grad()

        out, empathy1_preds, empathy2_preds, empathy3_preds, sentiment_preds = model(context, response, exemplars=exemplars)

        empathy1_labels = torch.tensor(empathy1_labels).cuda()
        empathy2_labels = torch.tensor(empathy2_labels).cuda()
        empathy3_labels = torch.tensor(empathy3_labels).cuda()
        sentiment_labels = torch.tensor(sentiment).cuda()

        loss = out.loss
        empathy1_loss = empathy_loss_function(empathy1_preds, empathy1_labels)
        empathy2_loss = empathy_loss_function(empathy2_preds, empathy2_labels)
        empathy3_loss = empathy_loss_function(empathy3_preds, empathy3_labels)
        total_empathy_loss = empathy1_loss + empathy2_loss + empathy3_loss
        sentiment_loss = sentiment_loss_function(sentiment_preds, sentiment_labels)
        total_loss = main_loss_w * loss + empathy_loss_w * total_empathy_loss + sentiment_loss_w * sentiment_loss

        if train:
            total_loss.backward()
            optimizer.step()

        # losses.append(total_loss.item())
        losses.append(loss.item())  # Return the generative seq2seq loss.

    avg_loss = round(np.mean(losses), 4)
    return avg_loss
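
# The combined training objective is a weighted sum:
#   total_loss = main_loss_w * seq2seq_loss
#              + empathy_loss_w * (empathy1 + empathy2 + empathy3 cross-entropy)
#              + sentiment_loss_w * sentiment_mse
# Only the generative seq2seq loss is averaged into the returned value, so the
# reported train/valid losses (and the test perplexity) track generation alone.
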
def test_model(model, dataloader, mode):
    references, hypothesis, utt_ids = [], [], []
    for conv_id, emotion, context, response, exemplars, empathy1_labels, empathy2_labels, empathy3_labels, sentiment in tqdm(dataloader, leave=False):
        ref = [[item] for item in response]
        hyp = model.erg_model.generate(context, exemplars=exemplars, mode=mode)
        references += ref
        hypothesis += hyp
        utt_ids += [[conv, len(con) + 1, "\n".join(con)] for conv, con in zip(conv_id, context)]

    scores = compute_bleu(references, hypothesis)
    bleu1 = round(100 * scores["Bleu_1"], 2)
    bleu2 = round(100 * scores["Bleu_2"], 2)
    return bleu1, bleu2, references, hypothesis, utt_ids
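
# compute_bleu is called with references as a list of single-reference lists
# ([[response], ...]) aligned with the generated hypotheses; BLEU-1/2 are
# reported as percentages.
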
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--weight-decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam-epsilon", default=1e-8, type=float, help="Epsilon for AdamW optimizer.")
    parser.add_argument("--adam-beta1", default=0.9, type=float, help="beta1 for AdamW optimizer.")
    parser.add_argument("--adam-beta2", default=0.999, type=float, help="beta2 for AdamW optimizer.")
    parser.add_argument("--lr-scheduler-type", default="linear")
    parser.add_argument("--warmup-steps", type=int, default=0, help="Steps used for a linear warmup from 0 to lr.")
    parser.add_argument("--warmup-ratio", type=float, default=0.0, help="Ratio of total training steps used for a linear warmup from 0 to lr.")
    parser.add_argument("--src-len", type=int, default=200, help="Max source length.")
    parser.add_argument("--tgt-len", type=int, default=50, help="Max target length.")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size.")
    parser.add_argument("--epochs", type=int, default=12, help="Number of epochs.")
    parser.add_argument("--model", default="t5-small", help="Which seq2seq model.")
    parser.add_argument("--add-exemplars", action="store_true", default=True, help="Whether to add exemplars.")
    parser.add_argument("--max-exemplars", type=int, default=10, help="Number of exemplars.")
    parser.add_argument("--decode", default="topk", help="topk or beam search decoding strategy.")
    parser.add_argument("--inference", default=None, help="run_ID")
    parser.add_argument("--main-loss-w", type=float, default=1.0, help="Loss weight for the generative loss.")
    parser.add_argument("--empathy-loss-w", type=float, default=1.0, help="Loss weight for the empathy loss.")
    parser.add_argument("--sentiment-loss-w", type=float, default=1.0, help="Loss weight for the sentiment loss.")
    parser.add_argument("--strategy", type=int, default=0, help="Logits to probability strategy for empathy/sentiment prediction model inputs.")
    parser.add_argument("--fixed", action="store_true", default=False, help="Whether to keep empathy and sentiment prediction models fixed.")
    args = parser.parse_args()
    print(args)
    global max_source_length
    global max_target_length

    if args.inference is None:
        run_ID = int(time.time())
        print(f"run ID: {run_ID}")

    max_source_length, max_target_length = args.src_len, args.tgt_len
    batch_size = args.batch_size
    n_epochs = args.epochs
    model_name = args.model
    strategy = args.strategy
    main_loss_w = args.main_loss_w
    empathy_loss_w = args.empathy_loss_w
    sentiment_loss_w = args.sentiment_loss_w

    train_loader, valid_loader, test_loader = configure_dataloaders(batch_size)
    empathy_loss_function = torch.nn.CrossEntropyLoss().cuda()
    sentiment_loss_function = torch.nn.MSELoss().cuda()

    if args.inference is None:
        # Training: build the model, train for n_epochs, and checkpoint on the best validation loss.
        model = ERGMainModel(model_name, max_source_length, max_target_length, strategy, args.add_exemplars, args.max_exemplars, args.fixed).cuda()
        print("Num trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
        optimizer = configure_optimizer(model, args)

        best_loss = None
        for e in range(n_epochs):
            train_loss = train_or_eval_model(model, train_loader, optimizer, True, main_loss_w=main_loss_w, empathy_loss_w=empathy_loss_w, sentiment_loss_w=sentiment_loss_w)
            valid_loss = train_or_eval_model(model, valid_loader, main_loss_w=main_loss_w, empathy_loss_w=empathy_loss_w, sentiment_loss_w=sentiment_loss_w)
            print("Epoch {}: train loss: {}, valid loss: {}".format(e + 1, train_loss, valid_loss))
            if best_loss is None or best_loss > valid_loss:
                if not os.path.isdir(f"saved/{run_ID}"):
                    os.mkdir(f"saved/{run_ID}")
                torch.save(model.state_dict(), f"saved/{run_ID}/model.pt")
                best_loss = valid_loss
    else:
        # Inference: reuse the run identified by --inference.
        run_ID = args.inference
        model = ERGMainModel(model_name, max_source_length, max_target_length, strategy, args.add_exemplars, args.max_exemplars).cuda()

    model.load_state_dict(torch.load(f"saved/{run_ID}/model.pt"))
    model.eval()

    test_loss = train_or_eval_model(model, test_loader, main_loss_w=main_loss_w, empathy_loss_w=empathy_loss_w, sentiment_loss_w=sentiment_loss_w)
    ppl = round(np.exp(test_loss), 4)
    bleu1, bleu2, references, hypothesis, utt_ids = test_model(model, test_loader, mode=args.decode)

    content = "Test BLEU1: {}; BLEU2: {}; Loss: {}; Perplexity: {}; Run ID: {}; Args: {}".format(bleu1, bleu2, round(test_loss, 4), ppl, run_ID, str(args))
    print(content)

    with open("results/results.txt", "a") as f:
        f.write(content + "\n")

    with open(f"output/{run_ID}_output.tsv", "w") as f:
        pd.DataFrame({
            "Conv ID": [x[0] for x in utt_ids],
            "Utterance Index": [x[1] for x in utt_ids],
            "Context": [x[2] for x in utt_ids],
            "Reference": [x[0] for x in references],
            "Generated": hypothesis
        }).to_csv(f, sep="\t", index=False)
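
# Example invocations (assumed usage based on the argparse flags above; the
# exact value accepted for beam-search decoding depends on the generate()
# implementation in models.py):
#   python train_t5.py --model t5-small --batch-size 8 --epochs 12 --decode topk
#   python train_t5.py --inference <run_ID> --decode topk   # evaluate a saved checkpoint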