@@ -1,11 +1,14 @@
 from collections import namedtuple
 
 import jax
+from optax._src.transform import scale_by_adam
 
 Optimizer = namedtuple("Optimizer", ["init", "step"])
 
 
 def create_optimizer(value_grad_func, hyper_parameters):
+    """A simple gradient ascent optimizer."""
+
     step_sizes = hyper_parameters.step_sizes
 
     def init(parameters):
@@ -24,3 +27,42 @@ def step(trace, parameters, opt_state):
         return parameters, value, opt_state, gradients
 
     return Optimizer(init, step)
+
+
+def create_adam_optimizer(
+        value_grad_func,
+        hyper_parameters,
+        b1=0.9,
+        b2=0.999,
+        eps=1e-8,
+        eps_root=0.0,
+        mu_dtype=None):
+    """The Adam optimizer for maximization of the given function."""
+
+    step_sizes = hyper_parameters.step_sizes
+
+    adam_transform = scale_by_adam(
+        b1=b1, b2=b2,
+        eps=eps, eps_root=eps_root,
+        mu_dtype=mu_dtype)
+
+    def init(parameters):
+        return adam_transform.init(parameters)
+
+    def step(trace, parameters, opt_state):
+
+        # get value and gradient
+        value, gradients = value_grad_func(trace, parameters)
+
+        # Adam update
+        updates, opt_state = adam_transform.update(gradients, opt_state)
+
+        # update parameters with step size
+        parameters = jax.tree_util.tree_map(
+            lambda p, s, u: p + s * u, parameters, step_sizes, updates
+        )
+
+        # return updated parameters, current value, and optimizer state
+        return parameters, value, opt_state, gradients
+
+    return Optimizer(init, step)
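
For illustration, a minimal usage sketch of create_adam_optimizer: it assumes value_grad_func returns a (value, gradients) pair for a given trace and parameter pytree, and that hyper_parameters.step_sizes is a pytree matching the parameters. The toy objective, the HyperParameters container, and the trace values below are hypothetical placeholders, not part of the module.

import jax
import jax.numpy as jnp

# hypothetical value/gradient function: maximize -(w - trace)^2
def value_grad_func(trace, parameters):
    def value_fn(params):
        return -jnp.sum((params["w"] - trace) ** 2)
    return jax.value_and_grad(value_fn)(parameters)

# hypothetical hyper-parameter container with per-parameter step sizes
class HyperParameters:
    step_sizes = {"w": 1e-2}

optimizer = create_adam_optimizer(value_grad_func, HyperParameters())

parameters = {"w": jnp.zeros(3)}
opt_state = optimizer.init(parameters)
trace = jnp.array([1.0, 2.0, 3.0])

for _ in range(1000):
    parameters, value, opt_state, gradients = optimizer.step(
        trace, parameters, opt_state)

# parameters["w"] approaches trace as the objective is maximized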