Skip to content

Commit

Permalink
make full_output case also differentiable, increase coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed Oct 19, 2024
1 parent e10b2e4 commit c40eaf1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
4 changes: 2 additions & 2 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def bodyfun(state):
xk1, fk1 = backtrack(xk1, fk1, d)
return xk1, fk1, k1 + 1

state = guess, res(guess), 0
state = guess, res(guess), 0.0
state = jax.lax.while_loop(condfun, bodyfun, state)
if full_output:
return state[0], state[1:]
Expand Down Expand Up @@ -401,7 +401,7 @@ def bodyfun(state):
state = (
jnp.atleast_1d(jnp.asarray(guess)),
jnp.atleast_1d(resfun(guess)),
0,
0.0,
)
state = jax.lax.while_loop(condfun, bodyfun, state)
if full_output:
Expand Down
29 changes: 26 additions & 3 deletions tests/test_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def test_map_coordinates_derivative():
with pytest.warns(UserWarning, match="Reducing radial"):
eq.change_resolution(3, 3, 0, 6, 6, 0)
inbasis = ["alpha", "phi", "rho"]
outbasis = ["rho", "theta_PEST", "zeta"]

rho = np.linspace(0.01, 0.99, 20)
theta = np.linspace(0, np.pi, 20, endpoint=False)
Expand All @@ -104,12 +103,35 @@ def test_map_coordinates_derivative():
def foo(params, in_coords):
out = eq.map_coordinates(
in_coords,
inbasis,
outbasis,
("rho", "alpha", "zeta"), # for this test, zeta==phi
("rho", "theta_PEST", "zeta"),
np.array([rho, theta, zeta]).T,
params,
period=(2 * np.pi, 2 * np.pi, np.inf),
maxiter=40,
)
return out

J1 = jax.jit(jax.jacfwd(foo))(eq.params_dict, in_coords)
J2 = jax.jit(jax.jacrev(foo))(eq.params_dict, in_coords)
for j1, j2 in zip(J1.values(), J2.values()):
assert ~np.any(np.isnan(j1))
assert ~np.any(np.isnan(j2))
np.testing.assert_allclose(j1, j2)

# Check map_coordinates with full_output is still runs without errors
# this time _map_clebsch_coordinates is called inside map_coordinates
@jax.jit
def foo(params, in_coords):
out, _ = eq.map_coordinates(
in_coords,
("rho", "alpha", "zeta"), # for this test, zeta==phi
("rho", "theta", "zeta"),
np.array([rho, theta, zeta]).T,
params,
period=(2 * np.pi, 2 * np.pi, np.inf),
maxiter=40,
full_output=True,
)
return out

Expand All @@ -129,6 +151,7 @@ def foo(params, in_coords):
flux_coords = nodes.copy()
flux_coords[:, 1] += coords["lambda"]

# this will call _map_PEST_coordinates inside map_coordinates
@jax.jit
def bar(L_lmn):
geom_coords = eq.map_coordinates(
Expand Down

0 comments on commit c40eaf1

Please sign in to comment.