Skip to content

Commit

Permalink
Merge pull request #1130 from carlosgmartin:fix_adabelief
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695382019
  • Loading branch information
OptaxDev committed Nov 11, 2024
2 parents a3d9683 + 723e6bf commit db6db9f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def init_fn(params):
def update_fn(updates, state, params=None):
del params
mu = otu.tree_update_moment(updates, state.mu, b1, 1)
prediction_error = jax.tree.map(lambda g, m: g - m, updates, state.mu)
prediction_error = otu.tree_sub(updates, mu)
nu = otu.tree_update_moment_per_elem_norm(prediction_error, state.nu, b2, 2)
nu = jax.tree.map(lambda v: v + eps_root, nu)
count_inc = numerics.safe_increment(state.count)
Expand Down

0 comments on commit db6db9f

Please sign in to comment.