# Generalized Lagrangian Networks | 2020
# Miles Cranmer, Sam Greydanus, Stephan Hoyer (...)
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from functools import partial


# Unconstrained equation of motion: the model directly predicts the
# time derivative of the state from (q, q_t).
def unconstrained_eom(model, state, t=None):
    q, q_t = jnp.split(state, 2)
    return model(q, q_t)


# Lagrangian equation of motion: the acceleration is obtained by inverting the
# Euler-Lagrange equation,
#   q_tt = (d^2L/dq_t^2)^(-1) (dL/dq - (d^2L/dq_t dq) q_t),
# using the pseudo-inverse of the Hessian of L with respect to q_t.
def lagrangian_eom(lagrangian, state, t=None):
    q, q_t = jnp.split(state, 2)
    # Note: the following line assumes q is an angle. Delete it for problems
    # other than the double pendulum.
    q = q % (2 * jnp.pi)
    q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
            @ (jax.grad(lagrangian, 0)(q, q_t)
               - jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
    # The state derivative is scaled by a fixed dt = 1e-1.
    dt = 1e-1
    return dt * jnp.concatenate([q_t, q_tt])


# Same as lagrangian_eom, but without the dt = 1e-1 scaling of the output.
def raw_lagrangian_eom(lagrangian, state, t=None):
    q, q_t = jnp.split(state, 2)
    # Angle wrap (assumes q is an angle, as above).
    q = q % (2 * jnp.pi)
    q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
            @ (jax.grad(lagrangian, 0)(q, q_t)
               - jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
    return jnp.concatenate([q_t, q_tt])
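

# A hedged usage sketch (illustrative; not part of the original file): a quick
# check of raw_lagrangian_eom on a single pendulum with m = l = g = 1, whose
# Lagrangian is L = 0.5*q_t**2 + cos(q) and whose Euler-Lagrange equation gives
# q_tt = -sin(q). The helper name `_demo_raw_eom` is made up here.
def _demo_raw_eom():
    def pendulum_lagrangian(q, q_t):
        return 0.5 * jnp.dot(q_t, q_t) + jnp.sum(jnp.cos(q))
    state = jnp.array([1.0, 0.0])                 # q = 1 rad, q_t = 0
    # Expected output: approximately [0., -sin(1.)] = [0., -0.841]
    return raw_lagrangian_eom(pendulum_lagrangian, state)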


# Advance the state over a step of length Dt using n_updates fixed-step RK4
# sub-steps; returns the total update (new_state - state).
def lagrangian_eom_rk4(lagrangian, state, n_updates, Dt=1e-1, t=None):
    @jax.jit
    def cur_fnc(state):
        q, q_t = jnp.split(state, 2)
        q = q % (2 * jnp.pi)
        q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
                @ (jax.grad(lagrangian, 0)(q, q_t)
                   - jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
        return jnp.concatenate([q_t, q_tt])

    @jax.jit
    def get_update(update):
        dt = Dt / n_updates
        cstate = state + update
        # Classical fourth-order Runge-Kutta sub-step.
        k1 = dt * cur_fnc(cstate)
        k2 = dt * cur_fnc(cstate + k1 / 2)
        k3 = dt * cur_fnc(cstate + k2 / 2)
        k4 = dt * cur_fnc(cstate + k3)
        return update + 1.0 / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)

    update = 0
    for _ in range(n_updates):
        update = get_update(update)
    return update
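

# A hedged usage sketch (illustrative; not part of the original file): take one
# fixed step of length Dt with lagrangian_eom_rk4 and add the returned update
# back onto the state. The helper name `_demo_rk4_step` is made up here.
def _demo_rk4_step():
    def pendulum_lagrangian(q, q_t):
        return 0.5 * jnp.dot(q_t, q_t) + jnp.sum(jnp.cos(q))
    state = jnp.array([jnp.pi / 2, 0.0])          # start at 90 degrees, at rest
    update = lagrangian_eom_rk4(pendulum_lagrangian, state, n_updates=4, Dt=1e-1)
    return state + update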


# Integrate the dynamics with odeint; extra kwargs (e.g. the time grid t) are
# passed straight through to odeint.
def solve_dynamics(dynamics_fn, initial_state, is_lagrangian=True, **kwargs):
    eom = lagrangian_eom if is_lagrangian else unconstrained_eom
    # We currently run odeint on CPUs only, because its cost is dominated by
    # control flow, which is slow on GPUs.
    @partial(jax.jit, backend='cpu')
    def f(initial_state):
        return odeint(partial(eom, dynamics_fn), initial_state, **kwargs)
    return f(initial_state)
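

# A hedged usage sketch (illustrative; not part of the original file): integrate
# the single-pendulum Lagrangian with solve_dynamics. Note that lagrangian_eom
# scales the state derivative by a fixed dt = 1e-1, so the returned trajectory
# evolves ten times slower than physical time. The helper name
# `_demo_solve_dynamics` is made up here.
def _demo_solve_dynamics():
    def pendulum_lagrangian(q, q_t):
        return 0.5 * jnp.dot(q_t, q_t) + jnp.sum(jnp.cos(q))
    t = jnp.linspace(0.0, 20.0, 201)
    x0 = jnp.array([jnp.pi / 2, 0.0])             # 90 degrees, at rest
    # Returns an array of shape (len(t), 2) holding [q, q_t] at each time.
    return solve_dynamics(pendulum_lagrangian, x0, is_lagrangian=True, t=t)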


def custom_init(init_params, seed=0):
    """Do an optimized LNN initialization for a simple uniform-width MLP."""
    import numpy as np
    new_params = []
    rng = jax.random.PRNGKey(seed)
    i = 0
    # Count the layers that actually carry parameters (the Dense layers).
    number_layers = len([0 for l1 in init_params if len(l1) != 0])
    for l1 in init_params:
        if len(l1) == 0:
            # Parameter-free layers (e.g. activations) are kept as empty tuples.
            new_params.append(())
            continue
        new_l1 = []
        for l2 in l1:
            if len(l2.shape) == 1:
                # Zero-init biases.
                new_l1.append(jnp.zeros_like(l2))
            else:
                # Scale the weight std differently for the first, middle,
                # and last Dense layers.
                n = max(l2.shape)
                first = int(i == 0)
                last = int(i == number_layers - 1)
                mid = int((i != 0) * (i != number_layers - 1))
                mid *= i
                std = 1.0 / np.sqrt(n)
                std *= 2.2 * first + 0.58 * mid + n * last
                if std == 0:
                    raise NotImplementedError("Wrong dimensions for MLP")
                new_l1.append(jax.random.normal(rng, l2.shape) * std)
                rng += 1
                i += 1
        new_params.append(new_l1)
    return new_params
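

# A hedged usage sketch (illustrative; not part of the original file): build a
# small MLP with stax and re-initialize its parameters with custom_init. The
# import path jax.example_libraries.stax is used by recent JAX versions (older
# versions expose the same module as jax.experimental.stax); the helper name
# `_demo_custom_init` and the layer sizes are made up here.
def _demo_custom_init():
    from jax.example_libraries import stax
    init_fn, _apply_fn = stax.serial(
        stax.Dense(128), stax.Softplus,
        stax.Dense(128), stax.Softplus,
        stax.Dense(1),
    )
    # Input dimension 4, e.g. [q1, q2, q1_t, q2_t] for a double pendulum.
    _, init_params = init_fn(jax.random.PRNGKey(0), (-1, 4))
    return custom_init(init_params, seed=0)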