-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathxmodel.py
348 lines (294 loc) · 14.9 KB
/
xmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import timm
import torch
import torch.optim as optim
from train_utils import group_weight, count_correct, is_bn, eval_param_norm, yield_optimizer_state
from distributed_utils import get_flat_tensor_from_tensor_sequence, is_main_process, get_flat_grad_from, reduce_value, set_flat_grad_to, set_flat_tensor_to_tensor_sequence
from torch.optim.lr_scheduler import LinearLR
from torch.cuda.amp import autocast
import torch.nn as nn
import wandb
from grad_scaler import KGradScaler, GradScaleTooLargeError
from localopt import LocalOptimizer
import os
import composer.functional as cf
import numpy as np
class XModel:
    """Wrap a model with its optimizer, manual gradient accumulation and logging.

    The optimizer is built by :meth:`create_optimizer`. Each call to
    :meth:`update_step` runs one micro-batch. Its gradients are summed into
    ``self.copied_grad``, and every ``acc_times`` calls the accumulated
    gradient is written back into the model and an optimizer step is taken.
    """

    model: nn.Module
    # Micro-batches accumulated since the last optimizer step.
    acc_ctr: int = 0
    # Norms of the Adam moment buffers: L2 of exp_avg / exp_avg_sq (m, v)
    # and their L1 norms (m1, v1). They are refreshed after an optimizer step
    # only when the optimizer name contains "adam" and args.log_per_step is set.
    m: float = 0.
    v: float = 0.
    m1: float = 0.
    v1: float = 0.
    # Optimizer steps taken so far (set from args.resume_from_step on resume).
    step_ctr: int = 0

    def __init__(self, model, args) -> None:
        """Store the model and the run configuration.

        Args:
            model: the network to train.
            args: run-configuration namespace. Many attributes are read from
                it here and in the other methods (device, optimizer,
                grad_scaler, strong_aug, ...).
        """
        self.args = args
        self.model = model
        self.warmup = args.warmup
        self.device = args.device
        # Stats summed over the current accumulation window. Slots 0-3 are
        # top-1 correct, top-5 correct, summed loss and sample count (see
        # count_step in update_step). Slots 4-7 stay zero here.
        self.cur_stats = torch.zeros(8, device=self.device)
        self.mixing = 0
        self.batch_target_perm = None
        # Parameter name -> gradient summed over the current window.
        self.copied_grad = {}
        if self.args.resume_pth is not None:
            self.step_ctr = self.args.resume_from_step
        if args.grad_scaler:
            self.grad_scaler = KGradScaler(
                init_scale=args.grad_upscale,
                growth_factor=args.grad_scaler_growth_factor,
                backoff_factor=args.grad_scaler_backoff_factor,
                growth_interval=args.grad_scaler_growth_interval,
            )

    def create_optimizer(self, refer_lr, **kwargs):
        """Build ``self.optimizer`` according to ``args.optimizer``.

        Args:
            refer_lr: learning rate given to the underlying torch optimizer.
            **kwargs: for the ``local*`` optimizers, must contain
                ``warmup_steps`` and ``total_steps`` (forwarded to
                LocalOptimizer).

        Raises:
            NotImplementedError: if ``args.optimizer`` is not one of
                'sgd', 'adamw', 'localadamw' or 'localsgd'.
        """
        args = self.args
        if args.group_weight:
            model_param = group_weight(self.model)
            if is_main_process():
                print("Grouping weight")
        else:
            model_param = self.model.parameters()

        def make_sgd():
            return optim.SGD(
                model_param, lr=refer_lr, weight_decay=args.wd,
                momentum=args.momentum, nesterov=args.nesterov,
            )

        def make_adamw():
            return optim.AdamW(
                model_param, lr=refer_lr, weight_decay=args.wd,
                betas=(args.beta1, args.beta2), eps=args.eps,
            )

        def make_local(inner, fields_to_avg):
            # step_ctr is passed for every local variant so that a resumed run
            # continues the LocalOptimizer schedule instead of restarting it.
            return LocalOptimizer(
                optim=inner, warmup_steps=kwargs['warmup_steps'],
                total_steps=kwargs['total_steps'], alpha=args.alpha,
                power=args.power, min_h=args.min_h, init_h=args.init_h,
                step_ctr=self.step_ctr, optim_fields_to_average=fields_to_avg,
            )

        if args.optimizer == 'sgd':
            self.optimizer = make_sgd()
        elif args.optimizer == 'adamw':
            self.optimizer = make_adamw()
        elif args.optimizer == 'localadamw':
            inner = make_adamw()
            fields_to_avg = []
            if args.avg_m:
                fields_to_avg.append('exp_avg')
            if args.avg_v:
                fields_to_avg.append('exp_avg_sq')
            self.optimizer = make_local(inner, fields_to_avg)
        elif args.optimizer == 'localsgd':
            self.optimizer = make_local(make_sgd(), [])
        else:
            raise NotImplementedError
        if is_main_process():
            print(f"Optimizer: {self.args.optimizer}")

    def get_local_step(self):
        """Return the current local-step count: 1 for non-local optimizers."""
        if 'local' not in self.args.optimizer:
            return 1
        return self.optimizer.get_local_step()

    def load_optimizer_state(self):
        """Load optimizer state from ``args.optimizer_resume_pth`` if one is set.

        With ``args.multiple_optimizers`` the path is treated as a directory
        holding one ``rank=<rank>.pt`` file per process. Otherwise it is a
        single state-dict file.
        """
        pth = self.args.optimizer_resume_pth
        if pth is None:
            return
        if self.args.multiple_optimizers:
            pth = os.path.join(pth, f'rank={self.args.rank}.pt')
        opt_state_dict = torch.load(pth, map_location=self.device)
        self.optimizer.load_state_dict(opt_state_dict)
        if not self.args.multiple_optimizers and is_main_process():
            # Debug output, printed once rather than on every rank.
            print(opt_state_dict['state'].keys())

    @staticmethod
    def _sample_entries(vec, indices=(1000, 50000, 100000), prefix="v"):
        """Return ``{f"{prefix}{i}": vec[i]}`` for each index that lies inside ``vec``."""
        size = vec.numel()
        return {f"{prefix}{i}": vec[i] for i in indices if i < size}

    def get_optimizer_state_norm(self):
        """Return norms of the Adam moment buffers plus a few sampled entries.

        Returns:
            ``(||m||_2, ||v||_2, ||m||_1, ||v||_1, samples)``, where m and v
            are the flattened ``exp_avg`` and ``exp_avg_sq`` states.
            ``samples`` maps ``"v<i>"`` to ``v[i]`` for i in 1000, 50000 and
            100000. Indices past the end of v are omitted rather than raising
            IndexError.
        """
        m_vec = get_flat_tensor_from_tensor_sequence(
            yield_optimizer_state(model=self.model, optimizer=self.optimizer, key='exp_avg'))
        v_vec = get_flat_tensor_from_tensor_sequence(
            yield_optimizer_state(model=self.model, optimizer=self.optimizer, key='exp_avg_sq'))
        return (torch.norm(m_vec), torch.norm(v_vec),
                torch.norm(m_vec, p=1), torch.norm(v_vec, p=1),
                self._sample_entries(v_vec))

    @torch.no_grad()
    def _accumulate_grads(self):
        """Add the current ``.grad`` of every trainable parameter into ``self.copied_grad``."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            # A trainable parameter that took no part in the forward pass has
            # no gradient. It is counted as zero so that the flattened vector
            # keeps one slot per trainable parameter.
            # NOTE(review): assumes set_flat_grad_to lays grads out the same
            # way (one slot per trainable parameter) — confirm.
            grad = param.grad if param.grad is not None else torch.zeros_like(param)
            if name in self.copied_grad:
                self.copied_grad[name].add_(grad)
            else:
                self.copied_grad[name] = grad.detach().clone()

    def _flatten_accumulated_grads(self):
        """Concatenate the accumulated gradients into one 1-D tensor, in insertion order."""
        # reshape (not view) also handles non-contiguous grads, e.g. channels_last.
        return torch.cat([grad.reshape(-1) for grad in self.copied_grad.values()])

    def update_step(self, batch_image, batch_target, criterion, acc_times):
        """Run one micro-batch forward/backward; every ``acc_times`` calls, take an optimizer step.

        On the ``acc_times``-th call, the summed gradients are:
        1. flattened;
        2. averaged across processes (non-local optimizers only);
        3. written back into the model;
        4. clipped if ``args.gradient_clipping`` is finite;
        5. applied by the optimizer.

        Args:
            batch_image: input batch fed to the model.
            batch_target: labels for ``batch_image``.
            criterion: loss function called as ``criterion(output, target)``.
            acc_times: micro-batches per optimizer step. Each micro-batch loss
                is divided by it.

        Returns:
            ``(cur_stats, averaged, wandb_dict)``:
            - ``cur_stats``: this micro-batch's 8-slot stats tensor. Slots
              4-7 hold m, v, m1 and v1 only when a step was taken.
            - ``averaged``: ``None`` when no step was taken, ``True`` for
              non-local optimizers, otherwise the value returned by
              ``self.optimizer.step()``.
            - ``wandb_dict``: the per-step logging payload.

        Raises:
            ValueError: if every one of the ``args.grad_scaler_max_retries``
                attempts raised GradScaleTooLargeError.
        """
        args = self.args

        def step(input_img, input_target, **kwargs):
            # One forward/backward pass. Grads are zeroed first because
            # accumulation happens in self.copied_grad, not in param.grad.
            self.optimizer.zero_grad()
            with autocast(dtype=args.dtype):
                output = self.model(input_img)
                if args.strong_aug:
                    # Mixup loss: convex combination of the losses against
                    # the original and the permuted targets.
                    loss_train = ((1 - kwargs['mixing']) * criterion(output, input_target)
                                  + kwargs['mixing'] * criterion(output, kwargs['batch_target_perm']))
                else:
                    loss_train = criterion(output, input_target)
            loss_train = loss_train / acc_times
            rescaled_loss = loss_train * self.grad_scaler.scale if args.grad_scaler else loss_train
            rescaled_loss.backward()
            if args.grad_scaler:
                # Presumably the source of the GradScaleTooLargeError that the
                # retry loop below catches.
                self.grad_scaler.unscale_(self.optimizer)
            return output, loss_train

        def optimizer_step():
            if 'local' in args.optimizer:
                averaged = self.optimizer.step()
            else:
                averaged = True
                self.optimizer.step()
            if args.grad_scaler:
                self.grad_scaler.update()
            return averaged

        def log_step():
            # Per-step wandb payload, built only on the main process.
            h = self.optimizer.get_local_step() if 'local' in args.optimizer else 1
            wandb_dict = {}
            if args.log_per_step and is_main_process():
                wandb_dict = {"train_step": self.step_ctr,
                              "train_step/loss": self.cur_stats[2] / self.cur_stats[3],
                              "train_step/acc1": self.cur_stats[0] / self.cur_stats[3],
                              "train_step/acc5": self.cur_stats[1] / self.cur_stats[3],
                              "train_step/lr": self.optimizer.param_groups[0]['lr'],
                              "train_step/h": h,
                              "train_step/m": self.m,
                              "train_step/v": self.v,
                              "train_step/m1": self.m1,
                              "train_step/v1": self.v1,
                              "train_step/grad_norm": self.grad_norm,
                              "train_step/param_norm": eval_param_norm(self.model)
                              }
                if 'local' in args.optimizer:
                    wandb_dict.update({"train_step/lr_inside": self.optimizer.optim.param_groups[0]['lr']})
                if "adam" in args.optimizer:
                    wandb_dict.update(self.v_samples)
                if args.grad_scaler:
                    wandb_dict.update({'train_step/grad_scaler': self.grad_scaler.scale})
            return wandb_dict

        def count_step(output, loss_train):
            # Stats layout: [top-1 correct, top-5 correct, summed loss,
            # sample count, 0, 0, 0, 0]. The loss is multiplied back by
            # acc_times to undo the per-micro-batch division.
            with torch.no_grad():
                train_correct1, train_correct5 = count_correct(
                    output=output, target=batch_target, topk=(1, 5))
                if args.strong_aug:
                    train_correct1_perm, train_correct5_perm = count_correct(
                        output=output, target=self.batch_target_perm, topk=(1, 5))
                    train_correct1 = (1 - self.mixing) * train_correct1 + self.mixing * train_correct1_perm
                    train_correct5 = (1 - self.mixing) * train_correct5 + self.mixing * train_correct5_perm
                zero = torch.as_tensor(0., dtype=loss_train.dtype, device=loss_train.device)
                return torch.stack([
                    train_correct1, train_correct5,
                    loss_train * batch_image.shape[0] * acc_times,
                    torch.as_tensor(batch_image.shape[0], dtype=loss_train.dtype, device=loss_train.device),
                    zero, zero, zero, zero,
                ])

        # begin gradient step
        if not self.model.training:  # optimize for speed
            self.model.train()
        averaged = None
        kwdct = {}
        input_target = batch_target
        if args.strong_aug:
            # Mixup: mix the images and keep the permuted targets for the loss.
            batch_image_perm, batch_target_perm, mixing = cf.mixup_batch(
                batch_image, batch_target, alpha=args.mixup_alpha)
            self.batch_target_perm = batch_target_perm
            self.mixing = mixing
            kwdct['batch_target_perm'] = batch_target_perm
            kwdct['mixing'] = mixing
            input_img = batch_image_perm
        else:
            input_img = batch_image

        if args.grad_scaler:
            for _ in range(args.grad_scaler_max_retries):
                try:
                    output, loss_train = step(input_img=input_img, input_target=input_target, **kwdct)
                    break
                except GradScaleTooLargeError:
                    continue
            else:  # no attempt succeeded (or zero retries were allowed)
                raise ValueError("Cannot find grad_scaler!")
        else:
            output, loss_train = step(input_img=input_img, input_target=input_target, **kwdct)
        torch.cuda.synchronize()

        cur_stats = count_step(output, loss_train)
        self.cur_stats += cur_stats
        self._accumulate_grads()
        self.acc_ctr += 1

        wandb_dict = {}
        if self.acc_ctr == acc_times:
            with torch.no_grad():
                flat_grad = self._flatten_accumulated_grads()
                # Local optimizers keep per-process gradients; everything else
                # averages them across all processes.
                if 'local' not in args.optimizer:
                    flat_grad = reduce_value(flat_grad, average=True, group=None)
                    torch.cuda.synchronize()
                set_flat_grad_to(self.model, flat_grad)
                # Recorded before clipping.
                self.grad_norm = torch.norm(flat_grad)
            if not np.isinf(args.gradient_clipping):
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.gradient_clipping)
            averaged = optimizer_step()
            self.step_ctr += 1
            if "adam" in args.optimizer and args.log_per_step:
                self.m, self.v, self.m1, self.v1, self.v_samples = self.get_optimizer_state_norm()
            wandb_dict = log_step()
            if is_main_process() and self.step_ctr <= 300 and args.debug:
                print(f"step: {self.step_ctr}, lr: {self.optimizer.param_groups[0]['lr']}")
            cur_stats[4] = self.m
            cur_stats[5] = self.v
            cur_stats[6] = self.m1
            cur_stats[7] = self.v1
            # reset counter, window stats and accumulated gradients
            self.acc_ctr = 0
            self.cur_stats = torch.zeros(8, device=self.device)
            self.copied_grad = {}
        return cur_stats, averaged, wandb_dict

    def save_model_state_dict(self, pth):
        """Save the model's state dict to ``pth``. The model is left in eval mode."""
        if self.model.training:
            self.model.eval()
        torch.save(self.model.state_dict(), pth)

    def save_optimizer_state_dict(self, pth):
        """Save the optimizer's state dict to ``pth``."""
        torch.save(self.optimizer.state_dict(), pth)

    @torch.no_grad()
    def eval_step(self, val_loader, criterion):
        """Evaluate top-1/top-5 accuracy over ``val_loader``, summed across processes.

        The validation loss is not computed: its slot is always zero and
        ``criterion`` is unused.

        Returns:
            A 3-element tensor ``[acc1, acc5, loss]`` (loss is 0).
        """
        if self.model.training:  # optimize for speed
            self.model.eval()
        val_stats = torch.zeros(4, device=self.device)
        for images, targets in val_loader:
            with autocast(dtype=self.args.dtype):
                output = self.model(images)
            val_correct1, val_correct5 = count_correct(
                output=output, target=targets, topk=(1, 5))
            loss_val = torch.zeros_like(val_correct1)
            val_stats += torch.stack([
                val_correct1, val_correct5, loss_val * images.shape[0],
                torch.as_tensor(images.shape[0], dtype=loss_val.dtype, device=loss_val.device)
            ])
        torch.cuda.synchronize()
        val_stats = reduce_value(val_stats, average=False)
        return val_stats[:3] / val_stats[3]

    def update_bn(self, idx, images):
        """Forward one batch so BN running stats become a cumulative mean.

        Sets the momentum of every module for which ``is_bn`` is true
        (presumably BatchNorm layers) to ``1 / (1 + idx)``. With the standard
        BN update rule, this makes the running stats the plain average over
        batches 0..idx. The caller is responsible for train mode.
        """
        for module in self.model.modules():
            if is_bn(module):
                module.momentum = 1 / (1 + idx)
        with torch.no_grad():
            with autocast(dtype=self.args.dtype):
                self.model(images)

    def buffers_to_average(self):
        """Yield model buffers whose names end in "mean" or "var" (e.g. BN running stats)."""
        for name, buffer in self.model.named_buffers():
            if name.endswith("mean") or name.endswith("var"):
                yield buffer

    @torch.no_grad()
    def estimate_BN_params(self, bn_loader):
        """Re-estimate BN running statistics from ``bn_loader`` and average them across processes.

        Each process consumes ``args.bn_batches // args.world_size`` batches.
        """
        if is_main_process():
            print("Estimating BN")
        if not self.model.training:  # optimize for speed
            self.model.train()
        num_batches = self.args.bn_batches // self.args.world_size
        bn_loader_iter = iter(bn_loader)
        try:
            for idx, (images, _targets) in enumerate(bn_loader_iter):
                if idx >= num_batches:
                    break
                self.update_bn(idx, images)
            torch.cuda.synchronize()
        finally:
            # Release the iterator even if a forward pass raised. Not every
            # iterator type defines close(), so the call is optional.
            close = getattr(bn_loader_iter, "close", None)
            if close is not None:
                close()
        flat = get_flat_tensor_from_tensor_sequence(self.buffers_to_average())
        flat = reduce_value(flat, average=True)
        set_flat_tensor_to_tensor_sequence(flat, self.buffers_to_average())