-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathtrain.py
355 lines (297 loc) · 13.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import os
import torch
import torchvision
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.distributed import all_reduce, all_gather, ReduceOp
from utils.func import *
from modules.loss import *
from modules.scheduler import *
def train(cfg, model, train_dataset, val_dataset, estimator, logger=None):
    """Train ``model`` for ``cfg.train.epochs`` epochs, validating and checkpointing along the way.

    The optimizer, samplers, LR/warmup schedulers, loss and data loaders are built
    from ``cfg`` with the ``initialize_*`` helpers of this module. Each epoch:

    1. advances the sampler (``set_epoch`` when distributed, ``step`` otherwise),
       the loss-weight scheduler and the warmup scheduler;
    2. trains over ``train_loader`` and reports the training metrics/loss;
    3. every ``cfg.train.eval_interval`` epochs evaluates on the validation loader and
       saves ``best_validation_weights.pt`` when ``cfg.train.indicator`` improves;
    4. every ``cfg.train.save_interval`` epochs saves ``epoch_{N}.pt``;
    5. steps the LR scheduler once warmup (if any) has finished.

    ``final_weights.pt`` is written at the end. Logging and checkpointing are
    restricted to the main process (``is_main(cfg)``).

    Args:
        cfg: experiment configuration (attribute-style namespaces).
        model: network to train; updated in place.
        train_dataset: training dataset.
        val_dataset: validation dataset.
        estimator: metric accumulator exposing ``reset``/``update``/``get_scores``.
        logger: optional writer exposing ``add_scalar``/``add_image``/``close``
            (presumably a TensorBoard ``SummaryWriter``); nothing is logged when None.
    """
    if cfg.base.HPO:
        from nni import report_intermediate_result, report_final_result

    device = cfg.base.device
    optimizer = initialize_optimizer(cfg, model)
    train_sampler, val_sampler = initialize_sampler(cfg, train_dataset, val_dataset)
    lr_scheduler, warmup_scheduler = initialize_lr_scheduler(cfg, optimizer)
    loss_function, loss_weight_scheduler = initialize_loss(cfg, train_dataset)
    train_loader, val_loader = initialize_dataloader(cfg, train_dataset, val_dataset, train_sampler, val_sampler)

    # start training
    model.train()
    avg_loss = 0
    max_indicator = float('-inf')
    # Most recent validation indicator. It stays None when no validation epoch runs
    # (e.g. epochs < eval_interval), so the final HPO report is skipped instead of
    # raising NameError.
    indicator = None
    for epoch in range(1, cfg.train.epochs + 1):
        # resampling weight update
        if cfg.dist.distributed:
            train_sampler.set_epoch(epoch)
        elif train_sampler:
            train_sampler.step()

        # update loss weights
        if loss_weight_scheduler:
            weight = loss_weight_scheduler.step()
            loss_function.weight = weight.to(device)

        # warmup scheduler update
        if warmup_scheduler and not warmup_scheduler.is_finish():
            warmup_scheduler.step()

        epoch_loss = 0
        estimator.reset()
        progress = tqdm(enumerate(train_loader), total=len(train_loader)) if cfg.base.progress else enumerate(train_loader)
        for step, train_data in progress:
            X, y = train_data
            X = X.cuda(cfg.dist.gpu) if cfg.dist.distributed else X.to(device)
            y = y.cuda(cfg.dist.gpu) if cfg.dist.distributed else y.to(device)
            y = select_target_type(y, cfg.train.criterion)

            # forward
            y_pred = model(X)
            loss = loss_function(y_pred, y)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # metrics: the update is done, so the collectives below work on
            # detached tensors rather than on nodes of the autograd graph.
            loss = loss.detach()
            y_pred = y_pred.detach()
            if cfg.dist.distributed:
                # average the loss and gather predictions/targets from every rank
                all_reduce(loss, ReduceOp.SUM)
                loss = loss / cfg.dist.world_size
                y_pred_list = [torch.zeros_like(y_pred) for _ in range(cfg.dist.world_size)]
                y_list = [torch.zeros_like(y) for _ in range(cfg.dist.world_size)]
                all_gather(y_pred_list, y_pred)
                all_gather(y_list, y)
                y_pred = torch.cat(y_pred_list, dim=0)
                y = torch.cat(y_list, dim=0)

            # Tracked on every rank (the loss is already rank-averaged above) so that
            # ReduceLROnPlateau.step(avg_loss) sees the same value everywhere. Tracking
            # it only on the main rank made the other ranks step on a constant 0, and
            # their learning rates then drifted away from rank 0's.
            epoch_loss += loss.item()
            avg_loss = epoch_loss / (step + 1)
            if is_main(cfg):
                estimator.update(y_pred, y)
                if cfg.base.progress:
                    progress.set_description(
                        'epoch: [{} / {}], loss: {:.6f}'.format(epoch, cfg.train.epochs, avg_loss)
                    )

        # Built after the loop so an empty train_loader does not raise NameError.
        if is_main(cfg) and not cfg.base.progress:
            print('epoch: [{} / {}], loss: {:.6f}'.format(epoch, cfg.train.epochs, avg_loss))

        if is_main(cfg):
            train_scores = estimator.get_scores(4)
            scores_txt = ', '.join(['{}: {}'.format(metric, score) for metric, score in train_scores.items()])
            print('Training metrics:', scores_txt)

        curr_lr = optimizer.param_groups[0]['lr']
        if is_main(cfg) and logger:
            for metric, score in train_scores.items():
                logger.add_scalar('training {}'.format(metric), score, epoch)
            logger.add_scalar('training loss', avg_loss, epoch)
            logger.add_scalar('learning rate', curr_lr, epoch)

        # The sample grid is only ever written to the logger. Without the `logger`
        # guard, sample_view=True with logger=None crashed on None.add_image.
        if is_main(cfg) and logger and cfg.train.sample_view:
            samples = torchvision.utils.make_grid(X)
            samples = inverse_normalize(samples, cfg.data.mean, cfg.data.std)
            logger.add_image('input samples', samples, epoch, dataformats='CHW')

        # validation performance
        if epoch % cfg.train.eval_interval == 0:
            eval(cfg, model, val_loader, cfg.train.criterion, estimator, device)
            val_scores = estimator.get_scores(6)
            scores_txt = ['{}: {}'.format(metric, score) for metric, score in val_scores.items()]
            print_msg('Validation metrics:', scores_txt)
            if is_main(cfg) and logger:
                for metric, score in val_scores.items():
                    logger.add_scalar('validation {}'.format(metric), score, epoch)

            # save model
            indicator = val_scores[cfg.train.indicator]
            if is_main(cfg) and indicator > max_indicator:
                save_weights(model, os.path.join(cfg.base.save_path, 'best_validation_weights.pt'))
                max_indicator = indicator
                print_msg('Best {} in validation set. Model save at {}'.format(cfg.train.indicator, cfg.base.save_path))

            if is_main(cfg) and cfg.base.HPO:
                report_intermediate_result(indicator)

        if is_main(cfg) and epoch % cfg.train.save_interval == 0:
            save_weights(model, os.path.join(cfg.base.save_path, 'epoch_{}.pt'.format(epoch)))

        # update learning rate (only once warmup, if configured, has finished)
        if lr_scheduler and (not warmup_scheduler or warmup_scheduler.is_finish()):
            if cfg.solver.lr_scheduler == 'reduce_on_plateau':
                lr_scheduler.step(avg_loss)
            else:
                lr_scheduler.step()

    # save final model
    if is_main(cfg):
        save_weights(model, os.path.join(cfg.base.save_path, 'final_weights.pt'))

    if is_main(cfg) and cfg.base.HPO and indicator is not None:
        report_final_result(indicator)

    if is_main(cfg) and logger:
        logger.close()
def evaluate(cfg, model, test_dataset, estimator):
    """Evaluate ``model`` on ``test_dataset`` and print the results on the main process.

    A ``DistributedSampler`` shards the test set in distributed mode. Otherwise the
    loader shuffles on its own. Batch size, worker count and pin-memory come from
    ``cfg.train``. After the pass, the main process prints every score
    (``estimator.get_scores(6)``) followed by the confusion matrix.

    Args:
        cfg: experiment configuration.
        model: network to evaluate.
        test_dataset: dataset to evaluate on.
        estimator: metric accumulator; reset and filled by ``eval``.
    """
    if cfg.dist.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    else:
        sampler = None

    loader = DataLoader(
        test_dataset,
        batch_size=cfg.train.batch_size,
        shuffle=sampler is None,
        sampler=sampler,
        num_workers=cfg.train.num_workers,
        pin_memory=cfg.train.pin_memory,
    )

    print('Running on Test set...')
    eval(cfg, model, loader, cfg.train.criterion, estimator, cfg.base.device)

    if not is_main(cfg):
        return

    print('================Finished================')
    for name, value in estimator.get_scores(6).items():
        print('{}: {}'.format(name, value))
    print('Confusion Matrix:')
    print(estimator.get_conf_mat())
    print('========================================')
def eval(cfg, model, dataloader, criterion, estimator, device):
    """Run ``model`` over ``dataloader`` and accumulate predictions into ``estimator``.

    Note: this name shadows the ``eval`` builtin inside this module. It is kept for
    compatibility with existing callers.

    The estimator is reset first, so afterwards it reflects only this pass.
    - Inputs and targets are moved to ``cuda(cfg.dist.gpu)`` in distributed mode and
      to ``device`` otherwise.
    - Targets are converted with ``select_target_type(y, criterion)``.
    - In distributed mode, predictions and targets from all ranks are all-gathered
      before ``estimator.update``, so every rank accumulates the full set.

    The model is put in eval mode and autograd is disabled for the pass. Both are
    restored to their previous state afterwards, including when an exception
    escapes. Previously the code unconditionally forced train mode and
    re-enabled gradients.

    Args:
        cfg: configuration; reads ``cfg.dist.distributed``/``gpu``/``world_size``.
        model: network to evaluate.
        dataloader: iterable of ``(X, y)`` batches.
        criterion: criterion name forwarded to ``select_target_type``.
        estimator: metric accumulator exposing ``reset`` and ``update``.
        device: target device when not distributed.
    """
    was_training = model.training
    model.eval()
    estimator.reset()
    try:
        with torch.no_grad():
            for X, y in dataloader:
                X = X.cuda(cfg.dist.gpu) if cfg.dist.distributed else X.to(device)
                y = y.cuda(cfg.dist.gpu) if cfg.dist.distributed else y.to(device)
                y = select_target_type(y, criterion)

                y_pred = model(X)
                if cfg.dist.distributed:
                    world_size = cfg.dist.world_size
                    y_pred_list = [torch.zeros_like(y_pred) for _ in range(world_size)]
                    y_list = [torch.zeros_like(y) for _ in range(world_size)]
                    all_gather(y_pred_list, y_pred)
                    all_gather(y_list, y)
                    y_pred = torch.cat(y_pred_list, dim=0)
                    y = torch.cat(y_list, dim=0)
                estimator.update(y_pred, y)
    finally:
        # restore the caller's mode (train() calls this while in train mode)
        model.train(was_training)
# define train/validation samplers (distributed sharding or weighted resampling)
def initialize_sampler(cfg, train_dataset, val_dataset):
    """Create the train/validation samplers selected by ``cfg.data.sampling_strategy``.

    Distributed runs shard both datasets with ``DistributedSampler`` using
    ``cfg.dist.world_size`` and ``cfg.dist.rank``. Any strategy other than
    'instance_balanced' aborts through ``exit_with_error``, because resampling is
    not supported together with distributed training.

    Non-distributed runs use no validation sampler. The training sampler is:

    * 'class_balanced'         -> ``ScheduledWeightedSampler(train_dataset, 1)``
    * 'progressively_balanced' -> ``ScheduledWeightedSampler`` with
      ``cfg.data.sampling_weights_decay_rate``
    * 'instance_balanced'      -> None (the DataLoader then shuffles by itself)

    Returns:
        (train_sampler, val_sampler); either may be None.

    Raises:
        NotImplementedError: unknown strategy in non-distributed mode.
    """
    sampling_strategy = cfg.data.sampling_strategy
    if cfg.dist.distributed:
        if sampling_strategy != 'instance_balanced':
            # Implicit literal concatenation. The previous backslash continuation
            # inside the string embedded the next line's indentation in the message.
            msg = ('Resampling is not allowed when distributed parallel is applied. '
                   'Please set sampling_strategy to instance_balanced.')
            exit_with_error(msg)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=cfg.dist.world_size,
            rank=cfg.dist.rank
        )
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset,
            num_replicas=cfg.dist.world_size,
            rank=cfg.dist.rank
        )
    else:
        val_sampler = None
        if sampling_strategy == 'class_balanced':
            train_sampler = ScheduledWeightedSampler(train_dataset, 1)
        elif sampling_strategy == 'progressively_balanced':
            train_sampler = ScheduledWeightedSampler(train_dataset, cfg.data.sampling_weights_decay_rate)
        elif sampling_strategy == 'instance_balanced':
            train_sampler = None
        else:
            raise NotImplementedError('Not implemented resampling strategy.')

    return train_sampler, val_sampler
# define data loader
def initialize_dataloader(cfg, train_dataset, val_dataset, train_sampler, val_sampler):
    """Wrap the train and validation datasets in DataLoaders.

    Both loaders share ``cfg.train.batch_size``, ``num_workers`` and ``pin_memory``.
    A loader shuffles by itself only when it was given no sampler. The training
    loader drops the last incomplete batch; the validation loader keeps it.

    Returns:
        (train_loader, val_loader)
    """
    shared = dict(
        batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        pin_memory=cfg.train.pin_memory,
    )

    def build(dataset, sampler, drop_last):
        return DataLoader(
            dataset,
            shuffle=sampler is None,
            sampler=sampler,
            drop_last=drop_last,
            **shared
        )

    return build(train_dataset, train_sampler, True), build(val_dataset, val_sampler, False)
# define loss and loss weights scheduler
def initialize_loss(cfg, train_dataset):
    """Build the training loss (wrapped in ``WarpedLoss``) and an optional loss-weight scheduler.

    ``cfg.train.criterion`` selects the loss, which is constructed with the keyword
    arguments in ``cfg.criterion_args[<criterion>]``. Supported names:
    'cross_entropy', 'mean_square_error', 'mean_absolute_error', 'smooth_L1',
    'kappa_loss', 'focal_loss'.

    ``cfg.train.loss_weight`` applies to 'cross_entropy' only:

    * 'balance'  -> ``LossWeightsScheduler(train_dataset, 1)``
    * 'dynamic'  -> ``LossWeightsScheduler(train_dataset, cfg.train.loss_weight_decay_rate)``
    * list/tuple -> fixed per-class ``weight`` tensor (float32, on ``cfg.base.device``)
    * otherwise  -> unweighted

    Returns:
        (loss_function, loss_weight_scheduler). The scheduler is None unless
        'balance' or 'dynamic' was requested for cross entropy.

    Raises:
        ValueError: an explicit weight list whose length differs from
            ``len(train_dataset.classes)``.
        NotImplementedError: unknown criterion.
    """
    criterion = cfg.train.criterion
    criterion_args = cfg.criterion_args[criterion]

    weight = None
    loss_weight_scheduler = None
    loss_weight = cfg.train.loss_weight
    if criterion == 'cross_entropy':
        if loss_weight == 'balance':
            loss_weight_scheduler = LossWeightsScheduler(train_dataset, 1)
        elif loss_weight == 'dynamic':
            loss_weight_scheduler = LossWeightsScheduler(train_dataset, cfg.train.loss_weight_decay_rate)
        elif isinstance(loss_weight, (list, tuple)):
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            num_classes = len(train_dataset.classes)
            if len(loss_weight) != num_classes:
                raise ValueError(
                    'loss_weight has {} entries but the dataset has {} classes.'.format(
                        len(loss_weight), num_classes
                    )
                )
            weight = torch.as_tensor(loss_weight, dtype=torch.float32, device=cfg.base.device)
        loss = nn.CrossEntropyLoss(weight=weight, **criterion_args)
    elif criterion == 'mean_square_error':
        loss = nn.MSELoss(**criterion_args)
    elif criterion == 'mean_absolute_error':
        loss = nn.L1Loss(**criterion_args)
    elif criterion == 'smooth_L1':
        loss = nn.SmoothL1Loss(**criterion_args)
    elif criterion == 'kappa_loss':
        loss = KappaLoss(**criterion_args)
    elif criterion == 'focal_loss':
        loss = FocalLoss(**criterion_args)
    else:
        raise NotImplementedError('Not implemented loss function.')

    loss_function = WarpedLoss(loss, criterion)
    return loss_function, loss_weight_scheduler
# define optimizer
def initialize_optimizer(cfg, model):
    """Create the optimizer named by ``cfg.solver.optimizer`` over ``model.parameters()``.

    * 'SGD'   -> ``torch.optim.SGD`` with learning rate, momentum, nesterov and weight decay
    * 'ADAM'  -> ``torch.optim.Adam`` with learning rate and weight decay (default betas)
    * 'ADAMW' -> ``torch.optim.AdamW`` with learning rate, ``cfg.solver.adamw_betas``
      and weight decay

    Raises:
        NotImplementedError: unknown optimizer name.
    """
    solver = cfg.solver
    name = solver.optimizer
    lr = solver.learning_rate
    decay = solver.weight_decay
    momentum = solver.momentum
    nesterov = solver.nesterov
    betas = solver.adamw_betas

    if name == 'SGD':
        return torch.optim.SGD(
            model.parameters(),
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=decay,
        )
    if name == 'ADAM':
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)
    if name == 'ADAMW':
        return torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=decay)
    raise NotImplementedError('Not implemented optimizer.')
# define learning rate scheduler
def initialize_lr_scheduler(cfg, optimizer):
    """Create the learning-rate scheduler and the optional warmup scheduler for ``optimizer``.

    ``cfg.solver.lr_scheduler`` names the scheduler: 'cosine', 'multiple_steps',
    'reduce_on_plateau', 'exponential' or 'clipped_cosine'. It is constructed with
    the keyword arguments in ``cfg.scheduler_args[<name>]``. A falsy name means no
    scheduler. When ``cfg.train.warmup_epochs > 0``, a
    ``WarmupLRScheduler(optimizer, warmup_epochs, cfg.solver.learning_rate)`` is
    created as well.

    Returns:
        (lr_scheduler, warmup_scheduler); either may be None.

    Raises:
        NotImplementedError: unknown scheduler name.
    """
    warmup_epochs = cfg.train.warmup_epochs
    base_lr = cfg.solver.learning_rate
    name = cfg.solver.lr_scheduler

    torch_schedulers = {
        'cosine': torch.optim.lr_scheduler.CosineAnnealingLR,
        'multiple_steps': torch.optim.lr_scheduler.MultiStepLR,
        'reduce_on_plateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
        'exponential': torch.optim.lr_scheduler.ExponentialLR,
    }

    lr_scheduler = None
    if name:
        kwargs = cfg.scheduler_args[name]
        if name in torch_schedulers:
            lr_scheduler = torch_schedulers[name](optimizer, **kwargs)
        elif name == 'clipped_cosine':
            # project scheduler; its name is only resolved when actually requested
            lr_scheduler = ClippedCosineAnnealingLR(optimizer, **kwargs)
        else:
            raise NotImplementedError('Not implemented learning rate scheduler.')

    warmup_scheduler = WarmupLRScheduler(optimizer, warmup_epochs, base_lr) if warmup_epochs > 0 else None
    return lr_scheduler, warmup_scheduler