diff --git a/desc/backend.py b/desc/backend.py index 5fb6a49bd..d68fc5344 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -273,7 +273,7 @@ def bodyfun(state): xk1, fk1 = backtrack(xk1, fk1, d) return xk1, fk1, k1 + 1 - state = guess, res(guess), 0 + state = guess, res(guess), 0.0 state = jax.lax.while_loop(condfun, bodyfun, state) if full_output: return state[0], state[1:] @@ -401,7 +401,7 @@ def bodyfun(state): state = ( jnp.atleast_1d(jnp.asarray(guess)), jnp.atleast_1d(resfun(guess)), - 0, + 0.0, ) state = jax.lax.while_loop(condfun, bodyfun, state) if full_output: diff --git a/tests/test_equilibrium.py b/tests/test_equilibrium.py index 175daa8f1..cfab9107d 100644 --- a/tests/test_equilibrium.py +++ b/tests/test_equilibrium.py @@ -88,7 +88,6 @@ def test_map_coordinates_derivative(): with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(3, 3, 0, 6, 6, 0) inbasis = ["alpha", "phi", "rho"] - outbasis = ["rho", "theta_PEST", "zeta"] rho = np.linspace(0.01, 0.99, 20) theta = np.linspace(0, np.pi, 20, endpoint=False) @@ -104,12 +103,35 @@ def test_map_coordinates_derivative(): def foo(params, in_coords): out = eq.map_coordinates( in_coords, - inbasis, - outbasis, + ("rho", "alpha", "zeta"), # for this test, zeta==phi + ("rho", "theta_PEST", "zeta"), + np.array([rho, theta, zeta]).T, + params, + period=(2 * np.pi, 2 * np.pi, np.inf), + maxiter=40, + ) + return out + + J1 = jax.jit(jax.jacfwd(foo))(eq.params_dict, in_coords) + J2 = jax.jit(jax.jacrev(foo))(eq.params_dict, in_coords) + for j1, j2 in zip(J1.values(), J2.values()): + assert ~np.any(np.isnan(j1)) + assert ~np.any(np.isnan(j2)) + np.testing.assert_allclose(j1, j2) + + # Check map_coordinates with full_output is still runs without errors + # this time _map_clebsch_coordinates is called inside map_coordinates + @jax.jit + def foo(params, in_coords): + out, _ = eq.map_coordinates( + in_coords, + ("rho", "alpha", "zeta"), # for this test, zeta==phi + ("rho", "theta", "zeta"), np.array([rho, theta, zeta]).T, params, period=(2 * np.pi, 2 * np.pi, np.inf), maxiter=40, + full_output=True, ) return out @@ -129,6 +151,7 @@ def foo(params, in_coords): flux_coords = nodes.copy() flux_coords[:, 1] += coords["lambda"] + # this will call _map_PEST_coordinates inside map_coordinates @jax.jit def bar(L_lmn): geom_coords = eq.map_coordinates(