Skip to content

Commit

Permalink
Fix AdaBelief implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Nov 9, 2024
1 parent c4fd723 commit 723e6bf
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def init_fn(params):
def update_fn(updates, state, params=None):
del params
mu = otu.tree_update_moment(updates, state.mu, b1, 1)
prediction_error = jax.tree.map(lambda g, m: g - m, updates, state.mu)
prediction_error = otu.tree_sub(updates, mu)
nu = otu.tree_update_moment_per_elem_norm(prediction_error, state.nu, b2, 2)
nu = jax.tree.map(lambda v: v + eps_root, nu)
count_inc = numerics.safe_increment(state.count)
Expand Down

0 comments on commit 723e6bf

Please sign in to comment.