-
Notifications
You must be signed in to change notification settings - Fork 3
/
engine.py
247 lines (208 loc) · 8.83 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import os
import time
from collections import OrderedDict
from contextlib import suppress
import torch
import torchvision.utils
from timm import utils
from timm.models import model_parameters
def train_one_epoch(
        epoch,
        model,
        loader,
        optimizer,
        loss_fn,
        args,
        device=torch.device('cuda'),
        lr_scheduler=None,
        saver=None,
        output_dir=None,
        amp_autocast=suppress,
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
        _logger=None,
):
    """Train ``model`` for one epoch with gradient accumulation.

    Batches are grouped into runs of ``args.grad_accum_steps``; the optimizer
    is stepped at the end of each run and on the final batch. A trailing
    partial run scales its loss by its own (shorter) length.

    Args:
        epoch: Epoch index; used to offset the global update counter and in
            log / recovery calls.
        model: Called on each input batch. If it exposes ``no_sync`` that
            context is entered on batches that do not trigger an update.
        loader: Sized iterable yielding ``(input, target)`` pairs. Its
            ``mixup_enabled`` attribute is cleared once
            ``args.mixup_off_epoch`` is reached and ``args.prefetcher`` is set.
        optimizer: Provides ``zero_grad``, ``step`` and ``param_groups``;
            ``is_second_order`` and ``sync_lookahead`` are used if present.
        loss_fn: Callable ``(output, target) -> loss``.
        args: Namespace read for ``mixup_off_epoch``, ``prefetcher``,
            ``grad_accum_steps``, ``channels_last``, ``clip_grad``,
            ``clip_mode``, ``distributed``, ``world_size``,
            ``synchronize_step``, ``log_interval``, ``save_images`` and
            ``recovery_interval``.
        device: Target device for batches when ``args.prefetcher`` is false.
        lr_scheduler: Optional; ``step_update`` is called after every
            optimizer update.
        saver: Optional; ``save_recovery`` is called every
            ``args.recovery_interval`` updates.
        output_dir: Directory for sample image dumps when
            ``args.save_images`` is set.
        amp_autocast: Context-manager factory wrapped around forward + loss.
        loss_scaler: Optional callable used instead of the explicit
            backward / clip / step path. NOTE(review): presumably it steps the
            optimizer itself when ``need_update`` is true - confirm against
            the scaler implementation.
        model_ema: Optional; ``update(model)`` is called after every update.
        mixup_fn: Optional callable applied to ``(input, target)`` when the
            prefetcher is not in use.
        _logger: Logger for progress lines; defaults to this module's logger.

    Returns:
        OrderedDict with a single ``'loss'`` entry holding the running
        average tracked by the loss meter.
    """
    if _logger is None:
        # The default used to be dereferenced as None on the first log line.
        import logging
        _logger = logging.getLogger(__name__)

    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    has_no_sync = hasattr(model, "no_sync")
    update_time_m = utils.AverageMeter()
    data_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()

    model.train()

    # `accum_steps` stays constant for the whole epoch: it defines the update
    # boundaries and update indices. Only the loss scaling of the trailing
    # partial group differs (see `cur_accum_steps` below).
    accum_steps = args.grad_accum_steps
    num_batches = len(loader)
    last_accum_steps = num_batches % accum_steps
    updates_per_epoch = (num_batches + accum_steps - 1) // accum_steps
    num_updates = epoch * updates_per_epoch
    last_batch_idx = num_batches - 1
    last_batch_idx_to_accum = num_batches - last_accum_steps
    # Avoid dividing by zero in the progress percentage when the epoch
    # consists of a single optimizer update.
    progress_denom = max(updates_per_epoch - 1, 1)

    def _forward(inputs, target, cur_accum_steps):
        """Forward pass + loss, scaled for the current accumulation group."""
        with amp_autocast():
            output = model(inputs)
            loss = loss_fn(output, target)
        if cur_accum_steps > 1:
            loss /= cur_accum_steps
        return loss

    def _backward(_loss, need_update):
        """Backward pass; clip and step the optimizer when an update is due."""
        if loss_scaler is not None:
            loss_scaler(
                _loss,
                optimizer,
                clip_grad=args.clip_grad,
                clip_mode=args.clip_mode,
                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                create_graph=second_order,
                need_update=need_update,
            )
        else:
            _loss.backward(create_graph=second_order)
            if need_update:
                if args.clip_grad is not None:
                    utils.dispatch_clip_grad(
                        model_parameters(model, exclude_head='agc' in args.clip_mode),
                        value=args.clip_grad,
                        mode=args.clip_mode,
                    )
                optimizer.step()

    data_start_time = update_start_time = time.time()
    optimizer.zero_grad()
    update_sample_count = 0
    for batch_idx, (inputs, target) in enumerate(loader):
        last_batch = batch_idx == last_batch_idx
        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
        update_idx = batch_idx // accum_steps
        # The trailing partial group averages over fewer batches.
        if batch_idx >= last_batch_idx_to_accum:
            cur_accum_steps = last_accum_steps
        else:
            cur_accum_steps = accum_steps

        if not args.prefetcher:
            inputs, target = inputs.to(device), target.to(device)
            if mixup_fn is not None:
                inputs, target = mixup_fn(inputs, target)
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)

        # multiply by accum steps to get equivalent for full update
        data_time_m.update(cur_accum_steps * (time.time() - data_start_time))

        if has_no_sync and not need_update:
            # Enter the model's no_sync context on batches that only
            # accumulate gradients and do not step the optimizer.
            with model.no_sync():
                loss = _forward(inputs, target, cur_accum_steps)
                _backward(loss, need_update)
        else:
            loss = _forward(inputs, target, cur_accum_steps)
            _backward(loss, need_update)

        if not args.distributed:
            # Undo the accumulation scaling so the meter tracks the raw loss.
            losses_m.update(loss.item() * cur_accum_steps, inputs.size(0))
        update_sample_count += inputs.size(0)

        if not need_update:
            data_start_time = time.time()
            continue

        num_updates += 1
        optimizer.zero_grad()
        if model_ema is not None:
            model_ema.update(model)

        if args.synchronize_step and device.type == 'cuda':
            torch.cuda.synchronize()
        time_now = time.time()
        update_time_m.update(time_now - update_start_time)
        update_start_time = time_now

        if update_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                # In distributed mode the loss meter is only updated here, on
                # logging steps, from the reduced loss.
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item() * cur_accum_steps, inputs.size(0))
                update_sample_count *= args.world_size

            if utils.is_primary(args):
                _logger.info(
                    f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                    f'({100. * update_idx / progress_denom:>3.0f}%)] '
                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
                    f'LR: {lr:.3e} '
                    f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                )

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        inputs,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True
                    )

        if saver is not None and args.recovery_interval and (
                (update_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=update_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        update_sample_count = 0
        data_start_time = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
def validate(
        model,
        loader,
        loss_fn,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        log_suffix='',
        _logger=None,
):
    """Evaluate ``model`` over ``loader`` and return average loss / accuracy.

    Args:
        model: Called on each input batch; switched to eval mode first. If it
            returns a tuple or list, only the first element is used.
        loader: Sized iterable yielding ``(input, target)`` pairs.
        loss_fn: Callable ``(output, target) -> loss``.
        args: Namespace read for ``prefetcher``, ``channels_last``, ``tta``,
            ``distributed``, ``world_size`` and ``log_interval``.
        device: Target device for batches when ``args.prefetcher`` is false.
        amp_autocast: Context-manager factory wrapped around forward + loss.
        log_suffix: Appended to the ``'Test'`` prefix of each log line.
        _logger: Logger for progress lines; defaults to this module's logger.

    Returns:
        OrderedDict with ``'loss'``, ``'top1'`` and ``'top5'`` running
        averages.
    """
    if _logger is None:
        # The default used to be dereferenced as None on the first log line.
        import logging
        _logger = logging.getLogger(__name__)

    batch_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()
    top1_m = utils.AverageMeter()
    top5_m = utils.AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (inputs, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                inputs = inputs.to(device)
                target = target.to(device)
            if args.channels_last:
                inputs = inputs.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(inputs)
                if isinstance(output, (tuple, list)):
                    output = output[0]

                # augmentation reduction: average every `tta` consecutive
                # outputs and keep one target per group
                reduce_factor = args.tta
                if reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]

                loss = loss_fn(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                acc1 = utils.reduce_tensor(acc1, args.world_size)
                acc5 = utils.reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            if device.type == 'cuda':
                # Wait for queued GPU work so the batch timing below is meaningful.
                torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), inputs.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
                    f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) '
                    f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) '
                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) '
                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
                )

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

    return metrics