From e10b2e4caeb8217124de75981a4a9e10a93dab02 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Sat, 19 Oct 2024 16:41:59 -0400 Subject: [PATCH] fix constant_offset_surface function --- desc/equilibrium/coords.py | 2 ++ desc/geometry/surface.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index b4a90abad..d217a51f0 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -367,6 +367,7 @@ def fixup(x, *args): fixup=fixup, tol=tol, maxiter=maxiter, + full_output=full_output, **kwargs, ) ) @@ -479,6 +480,7 @@ def fixup(x, *args): fixup=fixup, tol=tol, maxiter=maxiter, + full_output=full_output, **kwargs, ) ) diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 817eee4a1..ffeba95ae 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -741,10 +741,19 @@ def fun_jax(zeta_hat, theta, zeta): n, r, r_offset = n_and_r_jax(nodes) return jnp.arctan(r_offset[0, 1] / r_offset[0, 0]) - zeta - vecroot = jit(vmap(lambda x0, *p: root_scalar(fun_jax, x0, jac=None, args=p))) - zetas, (res, niter) = vecroot( - grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2] + vecroot = jit( + vmap( + lambda x0, *p: root_scalar( + fun_jax, x0, jac=None, args=p, full_output=full_output + ) + ) ) + if full_output: + zetas, (res, niter) = vecroot( + grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2] + ) + else: + zetas = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]) zetas = np.asarray(zetas) nodes = np.vstack((np.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T