Logging gradient norms before clipping #1026
Is there an easy way to log the norm of the gradients before clipping, when using a setup like the following?

```python
grad_accum_steps = 4
optim = optax.chain(optax.adamw(1e-5, mask=mask), optax.clip_by_global_norm(1.0))
optim = optax.MultiSteps(optim, every_k_schedule=grad_accum_steps)
```
You may simply create a custom GradientTransformation that does not touch the updates; it just computes the norm and puts it in the state (or even prints it if you want). Then you build the usual chain, except that you insert that custom transform just before the clipping.
You can then fetch the gradient norm from the overall state using optax.tree_utils.tree_get.
Something along the following lines:
```python
import typing
import jax
import jax.numpy as jnp
import optax

class RecordNormState(typing.NamedTuple):
  grad_norm: jax.Array

def record_norm():
  # Passes updates through unchanged and only records their global L2 norm.
  def init_fn(params):
    return RecordNormState(grad_norm=jnp.asarray(0.0))
  def update_fn(updates, state, params=None):
    return updates, RecordNormState(grad_norm=optax.tree_utils.tree_l2_norm(updates))
  return optax.GradientTransformation(init_fn, update_fn)
```
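For completeness, here is a minimal usage sketch of how this could plug into the chain from your question (assuming `mask`, `params`, and `grads` are defined as in your training setup, and that your optax version provides `optax.tree_utils.tree_get`):

```python
grad_accum_steps = 4
optim = optax.chain(
    optax.adamw(1e-5, mask=mask),
    record_norm(),  # inserted just before the clipping
    optax.clip_by_global_norm(1.0),
)
optim = optax.MultiSteps(optim, every_k_schedule=grad_accum_steps)

opt_state = optim.init(params)
updates, opt_state = optim.update(grads, opt_state, params)
# tree_get searches the nested optimizer state for the field named "grad_norm".
grad_norm = optax.tree_utils.tree_get(opt_state, "grad_norm")
```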