Skip to content

Commit 5ace529

Browse files
committed
Add Adam optimizer
1 parent d03d54b commit 5ace529

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

blinx/optimizer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from collections import namedtuple
22

33
import jax
4+
from optax._src.transform import scale_by_adam
45

56
Optimizer = namedtuple("Optimizer", ["init", "step"])
67

78

89
def create_optimizer(value_grad_func, hyper_parameters):
10+
"""A simple gradient ascent optimizer."""
11+
912
step_sizes = hyper_parameters.step_sizes
1013

1114
def init(parameters):
@@ -24,3 +27,42 @@ def step(trace, parameters, opt_state):
2427
return parameters, value, opt_state, gradients
2528

2629
return Optimizer(init, step)
30+
31+
32+
def create_adam_optimizer(
33+
value_grad_func,
34+
hyper_parameters,
35+
b1=0.9,
36+
b2=0.999,
37+
eps=1e-8,
38+
eps_root=0.0,
39+
mu_dtype=None):
40+
"""The Adam optimizer for maximization of the given function."""
41+
42+
step_sizes = hyper_parameters.step_sizes
43+
44+
adam_transform = scale_by_adam(
45+
b1=b1, b2=b2,
46+
eps=eps, eps_root=eps_root,
47+
mu_dtype=mu_dtype)
48+
49+
def init(parameters):
50+
return adam_transform.init(parameters)
51+
52+
def step(trace, parameters, opt_state):
53+
54+
# get value and gradient
55+
value, gradients = value_grad_func(trace, parameters)
56+
57+
# Adam update
58+
updates, opt_state = adam_transform.update(gradients, opt_state)
59+
60+
# update parameters with step size
61+
parameters = jax.tree_util.tree_map(
62+
lambda p, s, u: p + s * u, parameters, step_sizes, updates
63+
)
64+
65+
# return updated parameters, current value, and optimizer state
66+
return parameters, value, opt_state, gradients
67+
68+
return Optimizer(init, step)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ authors = [
1515
{ name = 'Alex Hillsley', email = 'hillsleya@janelia.hhmi.org' },
1616
]
1717
dynamic = ["version"]
18-
dependencies = ["numpy", "scipy", "jax", "jaxlib", "scikit-image", "tqdm"]
18+
dependencies = ["numpy", "scipy", "jax", "jaxlib", "optax", "scikit-image", "tqdm"]
1919

2020
[project.optional-dependencies]
2121
dev = ["pre-commit", "pytest", "pytest-cov", "ruff", "twine", "build"]

0 commit comments

Comments
 (0)