How to make an Optax optimizer in Flax? #994
Answer by @levskaya:

Below is a demo of how to wrap a Flax optimizer as an Optax gradient transformation. Be aware that Flax optimizers return updated params, whereas Optax gradient transforms return updated gradients instead. So there's a slightly tricky, optimizer-dependent `new_params, old_params --> delta` calculation that you need to do, and to be safe you want to make sure the learning rate fed to the wrapped optimizer is ~1.0 so you don't get bad cancellation when recovering the grad from the subtraction.
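To make the convention difference concrete, here is a minimal sketch (not from the original answer) of the two calling styles, assuming `params` and `grads` pytrees already exist and using only the standard `flax.optim` and `optax` entry points:

```python
# Legacy flax.optim style: apply_gradient returns a new Optimizer
# object that already holds the updated parameters.
optimizer = optim.Adam(learning_rate=1e-3).create(params)
optimizer = optimizer.apply_gradient(grads)
new_params_flax = optimizer.target

# Optax style: update returns parameter *deltas* plus new optimizer state;
# you apply the deltas to the params yourself.
tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
new_params_optax = optax.apply_updates(params, updates)
```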
```python
from functools import partial

import numpy as np
import jax
from jax import random, lax, numpy as jnp
import flax
from flax import linen as nn
from flax import optim
import optax
from optax import GradientTransformation

class Foo(nn.Module):
  filters: int

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(self.filters, (2, 2), use_bias=False, dtype=jnp.float32)(x)
    x = nn.relu(x)
    x = jnp.reshape(x, (5, -1))
    x = nn.Dense(3)(x)
    return x

key = random.PRNGKey(0)
x = random.uniform(key, (5,4,4,3))
tgt = random.randint(key, (5,), 0, 3)
# init variables
foo_vars = Foo(filters=3).init(key, x)
# grab a gradient for params
def xent(targets, logits):
  return -jnp.sum(targets * nn.log_softmax(logits), axis=-1)

def lossfn(p):
  y = Foo(filters=3).apply({'params': p}, x)
  tgt_oh = jax.nn.one_hot(tgt, 3)
  return jnp.sum(xent(tgt_oh, y))

grads = jax.grad(lossfn)(foo_vars['params'])
```

The biggest difference is that Optax transforms return updated "grads" rather than updated params.

```python
def make_optax_from_flax_optim(opt_def):
  def init(params):
    return opt_def.create(params).state

  def update(grads, state, params):
    new_params, new_state = opt_def.apply_gradient(
        opt_def.hyper_params, params, state, grads)
    # hacky, but we have to recover the gradient from the param delta and lr;
    # keep the lr passed to the flax optimizer at 1.0 to prevent precision loss.
    def grads_from_params(lr, old, new):
      return (new - old) / lr
    gfun = partial(grads_from_params, opt_def.hyper_params.learning_rate)
    new_grads = jax.tree_multimap(gfun, params, new_params)
    return new_grads, new_state

  return GradientTransformation(init, update)
```
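As a sanity check on the delta-to-grad trick (an addition, not part of the original answer): for plain gradient descent the recovered update is just the negated gradient, which is exactly what an Optax transform is expected to return, since Optax updates are added to the params. A hedged sketch, assuming `flax.optim.GradientDescent` follows the usual `new = old - lr * g` rule and reusing `grads` and `foo_vars` from above:

```python
# With plain SGD at lr = 1.0, (new - old) / lr should equal -g,
# matching Optax's additive-update convention.
sgd_tx = make_optax_from_flax_optim(optim.GradientDescent(learning_rate=1.0))
sgd_state = sgd_tx.init(foo_vars['params'])
sgd_updates, sgd_state = sgd_tx.update(grads, sgd_state, foo_vars['params'])

expected = jax.tree_map(lambda g: -g, grads)
# Max absolute deviation per leaf; should be (numerically) zero everywhere.
diffs = jax.tree_multimap(lambda a, b: jnp.max(jnp.abs(a - b)), sgd_updates, expected)
```

The original demo then compares the wrapped Flax Adam against a real Optax Adam: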
```python
flax_opt = optim.Adam(1.0)
optax_opt = make_optax_from_flax_optim(optim.Adam(1.0))
# flax adam acting like an optax adam
optax_state = optax_opt.init(foo_vars['params'])
new_grads, optax_state = optax_opt.update(grads, optax_state, foo_vars['params'])
# "real" optax adam
ref_state = optax.adam(1.0).init(foo_vars['params'])
ref_grads, ref_state = optax.adam(1.0).update(grads, ref_state, foo_vars['params'])
# check difference between optax and wrapped-flax adam:
jax.tree_map(jnp.sum, jax.tree_multimap(lambda x, y: x - y, new_grads, ref_grads))
```

This returns the per-leaf sums of the element-wise differences between the wrapped-Flax updates and the reference Optax updates, which should all be approximately zero if the wrapper is behaving like the real Optax Adam.
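Finally, a hedged sketch (also an addition) of how the wrapped transform could be dropped into an ordinary Optax-style training step, applying the returned updates with `optax.apply_updates`. Note that the effective step size is whatever learning rate was baked into the wrapped Flax optimizer (1.0 here, per the advice above), so for realistic training one option would be to chain it with `optax.scale(step_size)` while keeping the inner Flax learning rate at 1.0:

```python
# Illustrative only: one jitted training step driven by the wrapped transform.
tx = make_optax_from_flax_optim(optim.Adam(1.0))
opt_state = tx.init(foo_vars['params'])

@jax.jit
def train_step(params, opt_state):
  loss, grads = jax.value_and_grad(lossfn)(params)
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss

params, opt_state, loss = train_step(foo_vars['params'], opt_state)
```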