-
Notifications
You must be signed in to change notification settings - Fork 8
/
train.py
438 lines (362 loc) · 15 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
"""Training script : based on https://github.com/google/flax/blob/main/examples/imagenet/train.py"""
import functools
import time
from typing import Any
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow as tf
from absl import logging
from clu import metric_writers
from clu import periodic_actions
from flax import jax_utils
from flax import optim
from flax import traverse_util
from flax.training import checkpoints
from flax.training import common_utils
from flax.training import train_state
from jax import lax
from jax import random
import models
import utils
from data import input_pipeline
try:
# Older Versions Compatibilty
from flax.optim import DynamicScale
except ImportError:
from flax.optim.dynamic_scale import DynamicScale
def initialized(key, image_size, model):
    """Initialise the model variables on a dummy all-ones image batch.

    Args:
      key: PRNG key. It is split three times to derive the 'params',
        'dropout' and 'perturb_queries' RNG streams given to `model.init`.
      image_size: Side length of the square, 3-channel dummy input.
      model: Object exposing `init(rngs, inputs, train=...)` and `dtype`.

    Returns:
      A `(params, batch_stats)` pair. `batch_stats` is None when the
      initialised variables contain no 'batch_stats' entry.
    """
    dummy_batch = jnp.ones((1, image_size, image_size, 3), model.dtype)

    @jax.jit
    def _jitted_init(*init_args):
        return model.init(*init_args, train=True)

    # Each split consumes the remaining key, so the order must not change.
    params_key, key = random.split(key)
    dropout_key, key = random.split(key)
    perturb_key, key = random.split(key)

    init_rngs = {
        'params': params_key,
        'dropout': dropout_key,
        'perturb_queries': perturb_key,
    }
    variables = _jitted_init(init_rngs, dummy_batch)
    return variables['params'], variables.get('batch_stats', None)
def cross_entropy_loss(logits, labels, num_classes=10):
    """Return the mean softmax cross-entropy of `logits` against `labels`.

    Args:
      logits: Unnormalised class scores.
      labels: Either rank-1 labels, which are expanded to one-hot vectors
        with `num_classes` entries, or labels of any other rank, which are
        used as given.
      num_classes: One-hot width used for rank-1 labels only.

    Returns:
      The cross-entropy averaged over all examples.
    """
    # Rank-1 labels are class indices; anything else is used unchanged.
    targets = (
        labels
        if len(labels.shape) != 1
        else common_utils.onehot(labels, num_classes=num_classes))
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return jnp.mean(per_example)
def compute_metrics(logits, labels, num_classes=10):
    """Compute loss and accuracy, averaged over the 'batch' mapped axis.

    Args:
      logits: Unnormalised class scores.
      labels: Rank-1 labels compared directly with the arg-max prediction,
        or labels of another rank whose last-axis arg-max is compared.
      num_classes: Forwarded to `cross_entropy_loss`.

    Returns:
      Dict with 'loss' and 'accuracy', each reduced with `lax.pmean` over
      the axis named 'batch'.
    """
    predictions = jnp.argmax(logits, -1)
    targets = labels if len(labels.shape) == 1 else jnp.argmax(labels, -1)
    local_metrics = {
        'loss': cross_entropy_loss(logits, labels, num_classes),
        'accuracy': jnp.mean(predictions == targets),
    }
    # Requires an enclosing mapped axis named 'batch'.
    return lax.pmean(local_metrics, axis_name='batch')
def create_learning_rate_fn(
        config: ml_collections.ConfigDict,
        base_learning_rate: float,
        steps_per_epoch: int):
    """Build a schedule: linear warm-up followed by the configured decay.

    Args:
      config: Reads `warmup_epochs`, `learning_rate_schedule` and, for the
        cosine schedule, `num_epochs`.
      base_learning_rate: Value reached at the end of warm-up and used as the
        starting value of the post-warm-up schedule.
      steps_per_epoch: Number of optimisation steps in one epoch.

    Returns:
      An optax schedule joining warm-up and the post-warm-up schedule.

    Raises:
      ValueError: If `config.learning_rate_schedule` is neither 'cosine'
        nor 'const'.
    """
    warmup_steps = config.warmup_epochs * steps_per_epoch
    warmup = optax.linear_schedule(
        init_value=0.,
        end_value=base_learning_rate,
        transition_steps=warmup_steps)

    schedule_name = config.learning_rate_schedule
    if schedule_name == 'const':
        after_warmup = optax.constant_schedule(base_learning_rate)
    elif schedule_name == 'cosine':
        # Decay over the epochs left after warm-up, but at least one epoch.
        decay_epochs = max(config.num_epochs - config.warmup_epochs, 1)
        after_warmup = optax.cosine_decay_schedule(
            init_value=base_learning_rate,
            decay_steps=decay_epochs * steps_per_epoch)
    else:
        raise ValueError(f'Unsuported learning rate schedule {schedule_name}')

    return optax.join_schedules(
        schedules=[warmup, after_warmup],
        boundaries=[warmup_steps])
def train_step(state, batch, learning_rate_fn, config, rng):
    """Perform a single training step.

    Args:
      state: Train state providing `apply_fn`, `params`, `batch_stats`,
        `dynamic_scale`, `opt_state`, `step` and `apply_gradients`.
      batch: Mapping with 'image' and 'label' entries.
      learning_rate_fn: Callable mapping the step to a learning rate; used
        here only to report the rate in the metrics.
      config: Reads `weight_decay` and `optim`, and is passed to
        `input_pipeline.get_num_classes_from_config`.
      rng: PRNG key, split into the 'dropout' and 'perturb_queries' streams.

    Returns:
      `(new_state, metrics)`. `metrics` holds the values from
      `compute_metrics` plus 'learning_rate' and, when dynamic loss scaling
      is active, 'scale'.
    """
    dropout_rng, perturb_queries_rng = jax.random.split(rng, 2)
    num_classes = input_pipeline.get_num_classes_from_config(config)

    def loss_fn(params):
        """Loss function used for training."""
        variables = (
            {'params': params, 'batch_stats': state.batch_stats}
            if state.batch_stats is not None else {'params': params})
        logits, new_model_state = state.apply_fn(
            variables,
            batch['image'],
            mutable=['batch_stats'],
            train=True,
            rngs={'dropout': dropout_rng,
                  'perturb_queries': perturb_queries_rng})
        loss = cross_entropy_loss(
            logits, batch['label'], num_classes=num_classes)
        # Explicit L2 penalty only for 'adam' and 'sgd'; leaves with
        # ndim <= 1 are excluded from it.
        if config.weight_decay > 0.0 and config.optim in ['adam', 'sgd']:
            weight_penalty_params = jax.tree_leaves(params)
            weight_l2 = sum(jnp.sum(x ** 2)
                            for x in weight_penalty_params
                            if x.ndim > 1)
            loss = loss + config.weight_decay * 0.5 * weight_l2
        return loss, (new_model_state, logits)

    step = state.step
    dynamic_scale = state.dynamic_scale
    lr = learning_rate_fn(step)

    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(
            loss_fn, has_aux=True, axis_name='batch')
        dynamic_scale, is_fin, aux, grads = grad_fn(state.params)
        # dynamic loss takes care of averaging gradients across replicas
    else:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        aux, grads = grad_fn(state.params)
        # Re-use same axis_name as in the call to `pmap(...train_step...)`.
        grads = lax.pmean(grads, axis_name='batch')

    # aux is (loss, (new_model_state, logits)).
    new_model_state, logits = aux[1]
    metrics = compute_metrics(logits, batch['label'], num_classes)
    metrics['learning_rate'] = lr

    new_state = state.apply_gradients(
        grads=grads, batch_stats=new_model_state.get('batch_stats', None))
    if dynamic_scale:
        # if is_fin == False the gradients contain Inf/NaNs and optimizer
        # state and params should be restored (= skip this step).
        # NOTE(review): jax.tree_multimap no longer exists in recent JAX
        # releases; kept for the JAX version this file targets.
        new_state = new_state.replace(
            opt_state=jax.tree_multimap(
                functools.partial(jnp.where, is_fin),
                new_state.opt_state,
                state.opt_state),
            params=jax.tree_multimap(
                functools.partial(jnp.where, is_fin),
                new_state.params,
                state.params),
            # Store the updated scale; without this the next step would
            # start again from the stale scale held in the old state.
            dynamic_scale=dynamic_scale)
        metrics['scale'] = dynamic_scale.scale

    return new_state, metrics
def eval_step(state, batch, config):
    """Run the model in inference mode on one batch and return its metrics.

    Args:
      state: Train state providing `apply_fn`, `params` and `batch_stats`.
      batch: Mapping with 'image' and 'label' entries.
      config: Passed to `input_pipeline.get_num_classes_from_config`.

    Returns:
      The metrics dict produced by `compute_metrics`.
    """
    variables = {'params': state.params}
    if state.batch_stats is not None:
        variables['batch_stats'] = state.batch_stats

    logits = state.apply_fn(
        variables, batch['image'], train=False, mutable=False)
    num_classes = input_pipeline.get_num_classes_from_config(config)
    return compute_metrics(logits, batch['label'], num_classes)
def prepare_tf_data(xs):
    """Convert an input batch from tf Tensors to numpy arrays.

    Every leaf of `xs` is converted with its `_numpy()` method. Rank-4
    results are reshaped so the leading axis is split into
    `(local_device_count, -1)`; arrays of any other rank are returned as
    converted.

    Args:
      xs: Pytree whose leaves expose `_numpy()`.

    Returns:
      A pytree with the same structure holding numpy arrays.
    """
    local_device_count = jax.local_device_count()

    def _prepare(x):
        # Use _numpy() for zero-copy conversion between TF and NumPy.
        x = x._numpy()  # pylint: disable=protected-access

        # reshape (host_batch_size, height, width, 3) to
        # (local_devices, device_batch_size, height, width, 3)
        if len(x.shape) == 4:
            return x.reshape((local_device_count, -1) + x.shape[1:])
        return x

    # jax.tree_util.tree_map is available in old and new JAX releases; the
    # top-level jax.tree_map alias is not.
    return jax.tree_util.tree_map(_prepare, xs)
def create_input_iter(dataset_builder, config, rng, batch_size, train, image_size, dtype, cache):
    """Create a device-prefetching iterator over one dataset split.

    The split comes from `input_pipeline.create_split`; each element is
    passed through `prepare_tf_data` and then prefetched to device with a
    buffer size of 10.

    Args:
      dataset_builder: Forwarded to `input_pipeline.create_split`.
      config: Forwarded to `input_pipeline.create_split`.
      rng: Forwarded to `input_pipeline.create_split`.
      batch_size: Forwarded to `input_pipeline.create_split`.
      train: Forwarded as `train=`.
      image_size: Forwarded as `image_size=`.
      dtype: Forwarded as `dtype=`.
      cache: Forwarded as `cache=`.

    Returns:
      The iterator returned by `jax_utils.prefetch_to_device`.
    """
    split = input_pipeline.create_split(
        dataset_builder,
        config,
        rng,
        batch_size,
        train=train,
        image_size=image_size,
        dtype=dtype,
        cache=cache)
    numpy_batches = map(prepare_tf_data, split)
    return jax_utils.prefetch_to_device(numpy_batches, 10)
class TrainState(train_state.TrainState):
    """Train state extended with batch statistics and a dynamic loss scale."""
    # Batch statistics collection; `create_train_state` may pass None here
    # when the initialised variables have no 'batch_stats' entry.
    batch_stats: Any
    # Loss-scaling object; `create_train_state` passes None unless
    # half precision is requested on GPU.
    dynamic_scale: DynamicScale
def restore_checkpoint(state, workdir):
    """Restore `state` from a checkpoint in `workdir`.

    Thin wrapper around `checkpoints.restore_checkpoint`, with `workdir` as
    the checkpoint directory and `state` as the restore target.
    """
    restored = checkpoints.restore_checkpoint(ckpt_dir=workdir, target=state)
    return restored
def save_checkpoint(state, workdir):
    """Write a checkpoint of `state` to `workdir` from process 0 only.

    The first entry along the leading axis of every leaf is taken (the state
    of the first replica), fetched to host, and saved under the state's
    current step. `keep=3` is passed to the checkpoint writer.

    Args:
      state: Replicated train state; every leaf is indexed with `[0]`.
      workdir: Directory the checkpoint is written to.
    """
    if jax.process_index() != 0:
        return

    # get train state from the first replica
    # jax.tree_util.tree_map is available in old and new JAX releases; the
    # top-level jax.tree_map alias is not.
    first_replica = jax.tree_util.tree_map(lambda x: x[0], state)
    host_state = jax.device_get(first_replica)
    step = int(host_state.step)
    checkpoints.save_checkpoint(workdir, host_state, step, keep=3)
# `lax.pmean` needs a named mapped axis, which only exists inside `pmap`.
# This pmapped helper therefore averages its input across all devices.
cross_replica_mean = jax.pmap(
    lambda x: lax.pmean(x, axis_name='x'),
    axis_name='x')
def sync_batch_stats(state):
    """Average the batch statistics across replicas.

    Each device keeps its own running batch statistics; they are averaged
    with `cross_replica_mean` before evaluation. A state whose
    `batch_stats` is None is returned unchanged.
    """
    if state.batch_stats is None:
        return state
    averaged = cross_replica_mean(state.batch_stats)
    return state.replace(batch_stats=averaged)
def create_train_state(rng, config: ml_collections.ConfigDict,
                       model, image_size, learning_rate_fn):
    """Create initial training state.

    Args:
      rng: PRNG key passed to `initialized`.
      config: Reads `half_precision` and `optim`, plus the settings of the
        chosen optimizer: `momentum` for 'sgd'; `weight_decay`,
        `grad_clip_max_norm` and `optim_wd_ignore` for 'adamw'.
      model: Model whose `apply` becomes the state's `apply_fn`.
      image_size: Side length of the dummy input used for initialisation.
      learning_rate_fn: Schedule handed to the optimizer.

    Returns:
      A freshly created `TrainState`.

    Raises:
      ValueError: If `config.optim` is not 'sgd', 'adam' or 'adamw'.
    """
    platform = jax.local_devices()[0].platform
    # Dynamic loss scaling is only enabled for half precision on GPU. The
    # module-level `DynamicScale` import covers both flax module layouts.
    if config.half_precision and platform == 'gpu':
        dynamic_scale = DynamicScale()
    else:
        dynamic_scale = None

    params, batch_stats = initialized(rng, image_size, model)
    params = params.unfreeze()

    # TODO(marar): wrap in util function:
    def _flattened_traversal(fn):
        """Lift `fn(path, value)` to a function over a nested dict."""
        def mask(data):
            flat = traverse_util.flatten_dict(data)
            return traverse_util.unflatten_dict(
                {k: fn(k, v) for k, v in flat.items()})
        return mask

    if config.optim == 'sgd':
        tx = optax.sgd(
            learning_rate=learning_rate_fn,
            momentum=config.momentum,
            nesterov=True,
        )
    elif config.optim == 'adam':
        tx = optax.adam(learning_rate=learning_rate_fn)
    elif config.optim == 'adamw':
        # Parameters whose final path component is listed in
        # `config.optim_wd_ignore` get AdamW without weight decay; all other
        # parameters get AdamW with `config.weight_decay`.
        # NOTE(review): the clip threshold is hard-coded to 1.0 and
        # `config.grad_clip_max_norm` only toggles clipping — confirm that
        # the configured value is not meant to be the threshold.
        tx = optax.chain(
            optax.clip_by_global_norm(1.0) if config.grad_clip_max_norm else optax.identity(),
            optax.masked(
                optax.adamw(learning_rate=learning_rate_fn,
                            weight_decay=config.weight_decay),
                mask=_flattened_traversal(
                    lambda path, _: path[-1] not in config.optim_wd_ignore)),
            optax.masked(
                optax.adamw(learning_rate=learning_rate_fn, weight_decay=0.0),
                mask=_flattened_traversal(
                    lambda path, _: path[-1] in config.optim_wd_ignore)),
        )
    else:
        # Fail with a clear message instead of an unbound `tx` below.
        raise ValueError(f'Unsupported optimizer {config.optim}')

    state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
        batch_stats=batch_stats,
        dynamic_scale=dynamic_scale)
    return state
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
    """Execute model training and evaluation loop.

    Args:
      config: Hyperparameter configuration for training and evaluation. A
        negative `config.seed` is replaced in place by a random seed.
      workdir: Directory where the tensorboard summaries are written to.
        It is also used for checkpoints and profiling output.

    Returns:
      Final TrainState.

    Raises:
      ValueError: If `config.batch_size` is not divisible by the number of
        devices.
    """
    if config.seed < 0:
        # Draw a random seed; temporarily unlock a locked config to store it.
        islocked = config.is_locked
        if islocked:
            config.unlock()
        config.seed = np.random.randint(0, 1000000)
        if islocked:
            config.lock()

    writer = metric_writers.create_default_writer(
        logdir=workdir, just_logging=jax.process_index() != 0)

    rng = random.PRNGKey(config.seed)

    image_size = config.input_size

    if config.batch_size % jax.device_count() > 0:
        raise ValueError('Batch size must be divisible by the number of devices')
    # NOTE(review): divides by the local device count; presumably
    # `input_pipeline.create_split` expects this value — confirm.
    local_batch_size = config.batch_size // jax.local_device_count()

    platform = jax.local_devices()[0].platform

    if config.half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    # Build input pipeline.
    dataset_builder = input_pipeline.get_dataset_builder(config)
    rng, data_rng = jax.random.split(rng)
    # Give every process its own data RNG stream.
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    data_rng_train, data_rng_val = jax.random.split(data_rng)
    train_iter = create_input_iter(
        dataset_builder, config, data_rng_train, batch_size=local_batch_size,
        train=True, image_size=image_size,
        dtype=input_dtype,
        cache=config.cache)
    # TODO(marar): data_rng_val is not necessary!
    eval_iter = create_input_iter(
        dataset_builder, config, data_rng_val, batch_size=local_batch_size,
        train=False, image_size=image_size,
        dtype=input_dtype,
        cache=config.cache)

    steps_per_epoch = (
        dataset_builder.info.splits['train'].num_examples // config.batch_size
    )

    # -1 means "derive the value from the dataset size".
    if config.num_train_steps == -1:
        num_steps = int(steps_per_epoch * config.num_epochs)
    else:
        num_steps = config.num_train_steps

    if config.steps_per_eval == -1:
        num_validation_examples = input_pipeline.get_num_eval_examples(
            dataset_builder, config)
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * config.checkpoint_every_epochs

    # Scale the configured learning rate linearly with batch size / 256.
    base_learning_rate = config.learning_rate * config.batch_size / 256.

    num_classes = dataset_builder.info.features['label'].num_classes
    model = models.create_model(config=config, num_classes=num_classes)

    learning_rate_fn = create_learning_rate_fn(
        config, base_learning_rate, steps_per_epoch)

    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
    if config.get('fine_tune', False):
        state = utils.init_from_pretrained(state, config)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    p_train_step = jax.pmap(
        functools.partial(train_step, learning_rate_fn=learning_rate_fn, config=config),
        axis_name='batch')
    p_eval_step = jax.pmap(
        functools.partial(eval_step, config=config), axis_name='batch')

    rng, drop_out_rng = jax.random.split(rng, 2)
    drop_out_rng = jax.random.fold_in(drop_out_rng, jax.process_index())

    train_metrics = []
    hooks = []
    if jax.process_index() == 0:
        hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)]
    train_metrics_last_t = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        # One key per step, then one key per local device.
        drop_out_rng_step = jax.random.fold_in(drop_out_rng, step)
        drop_out_rng_step_all = jax.random.split(drop_out_rng_step,
                                                 jax.local_device_count())
        state, metrics = p_train_step(state, batch, rng=drop_out_rng_step_all)
        for h in hooks:
            h(step)
        if step == step_offset:
            logging.info('Initial compilation completed.')

        if config.log_every_steps:
            train_metrics.append(metrics)
            if (step + 1) % config.log_every_steps == 0:
                train_metrics = common_utils.get_metrics(train_metrics)
                # jax.tree_util.tree_map is available in old and new JAX
                # releases; the top-level jax.tree_map alias is not.
                mean_metrics = jax.tree_util.tree_map(
                    lambda x: x.mean(), train_metrics)
                summary = {
                    f'train_{k}': v
                    for k, v in mean_metrics.items()
                }
                summary['steps_per_second'] = config.log_every_steps / (
                    time.time() - train_metrics_last_t)
                writer.write_scalars(step + 1, summary)
                train_metrics = []
                train_metrics_last_t = time.time()

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                         epoch, summary['loss'], summary['accuracy'] * 100)
            writer.write_scalars(
                step + 1, {f'eval_{key}': val for key, val in summary.items()})
            writer.flush()

        # Checkpoint periodically and always after the final step.
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

    return state