Skip to content

Commit

Permalink
Merge branch 'master' into integrate_on_boundary
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Sep 19, 2024
2 parents a93b98b + 75eafcc commit 66b402a
Show file tree
Hide file tree
Showing 16 changed files with 329 additions and 169 deletions.
40 changes: 40 additions & 0 deletions desc/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,46 @@ def flip_helicity(eq):
return eq


def flip_theta(eq):
"""Change the gauge freedom of the poloidal angle of an Equilibrium.
Equivalent to redefining theta_new = theta_old + π
Parameters
----------
eq : Equilibrium or iterable of Equilibrium
Equilibria to redefine the poloidal angle of.
Returns
-------
eq : Equilibrium or iterable of Equilibrium
Same as input, but with the poloidal angle redefined.
"""
# maybe it's iterable:
if hasattr(eq, "__len__"):
for e in eq:
flip_theta(e)
return eq

rone = np.ones_like(eq.R_lmn)
rone[eq.R_basis.modes[:, 1] % 2 == 1] *= -1
eq.R_lmn *= rone

zone = np.ones_like(eq.Z_lmn)
zone[eq.Z_basis.modes[:, 1] % 2 == 1] *= -1
eq.Z_lmn *= zone

lone = np.ones_like(eq.L_lmn)
lone[eq.L_basis.modes[:, 1] % 2 == 1] *= -1
eq.L_lmn *= lone

eq.axis = eq.get_axis()
eq.surface = eq.get_surface_at(rho=1)

return eq


def rescale(
eq, L=("R0", None), B=("B0", None), scale_pressure=True, copy=False, verbose=0
):
Expand Down
6 changes: 3 additions & 3 deletions desc/compute/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,7 @@ def _Z_zzz(params, transforms, profiles, data, **kwargs):
label="\\alpha",
units="~",
units_long="None",
description="Field line label, defined on [0, 2pi)",
description="Field line label",
dim=1,
params=[],
transforms={},
Expand All @@ -1503,7 +1503,7 @@ def _Z_zzz(params, transforms, profiles, data, **kwargs):
data=["theta_PEST", "phi", "iota"],
)
def _alpha(params, transforms, profiles, data, **kwargs):
data["alpha"] = (data["theta_PEST"] - data["iota"] * data["phi"]) % (2 * jnp.pi)
data["alpha"] = data["theta_PEST"] - data["iota"] * data["phi"]
return data


Expand Down Expand Up @@ -3077,7 +3077,7 @@ def _theta(params, transforms, profiles, data, **kwargs):
data=["theta", "lambda"],
)
def _theta_PEST(params, transforms, profiles, data, **kwargs):
data["theta_PEST"] = (data["theta"] + data["lambda"]) % (2 * jnp.pi)
data["theta_PEST"] = data["theta"] + data["lambda"]
return data


Expand Down
131 changes: 71 additions & 60 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def _periodic(x, period):
return jnp.where(jnp.isfinite(period), x % period, x)


def _fixup_residual(r, period):
r = _periodic(r, period)
# r should be between -period and period
return jnp.where((r > period / 2) & jnp.isfinite(period), -period + r, r)


def map_coordinates( # noqa: C901
eq,
coords,
Expand Down Expand Up @@ -87,9 +93,9 @@ def map_coordinates( # noqa: C901
ValueError,
f"tol must be a positive float, got {tol}",
)
params = setdefault(params, eq.params_dict)
inbasis = tuple(inbasis)
outbasis = tuple(outbasis)
params = setdefault(params, eq.params_dict)

basis_derivs = tuple(f"{X}_{d}" for X in inbasis for d in ("r", "t", "z"))
for key in basis_derivs:
Expand All @@ -111,25 +117,27 @@ def map_coordinates( # noqa: C901
profiles["iota"] = eq.get_profile(["iota", "iota_r"], params=params)
iota = profiles["iota"].compute(Grid(coords, sort=False, jitable=True))
return _map_clebsch_coordinates(
coords,
iota,
params["L_lmn"],
eq.L_basis,
guess[:, 1] if guess is not None else None,
tol,
maxiter,
full_output,
coords=coords,
iota=iota,
L_lmn=params["L_lmn"],
L_basis=eq.L_basis,
guess=guess[:, 1] if guess is not None else None,
period=period[1] if period is not None else np.inf,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
if inbasis == ("rho", "theta_PEST", "zeta"):
return _map_PEST_coordinates(
coords,
params["L_lmn"],
eq.L_basis,
guess[:, 1] if guess is not None else None,
tol,
maxiter,
full_output,
coords=coords,
L_lmn=params["L_lmn"],
L_basis=eq.L_basis,
guess=guess[:, 1] if guess is not None else None,
period=period[1] if period is not None else np.inf,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)

Expand All @@ -139,7 +147,6 @@ def map_coordinates( # noqa: C901
params["i_l"] = profiles["iota"].params

rhomin = kwargs.pop("rhomin", tol / 10)
warnif(period is None, msg="Assuming no periodicity.")
period = np.asarray(setdefault(period, (np.inf, np.inf, np.inf)))
coords = _periodic(coords, period)

Expand All @@ -165,8 +172,7 @@ def compute(y, basis):
@jit
def residual(y, coords):
xk = compute(y, inbasis)
r = _periodic(xk, period) - _periodic(coords, period)
return jnp.where((r > period / 2) & jnp.isfinite(period), -period + r, r)
return _fixup_residual(xk - coords, period)

@jit
def jac(y, coords):
Expand Down Expand Up @@ -212,7 +218,6 @@ def fixup(y, *args):
yk, (res, niter) = vecroot(yk, coords)

out = compute(yk, outbasis)

if full_output:
return out, (res, niter)
return out
Expand Down Expand Up @@ -253,7 +258,7 @@ def _initial_guess_heuristic(yk, coords, inbasis, eq, profiles):
zero = jnp.zeros_like(rho)
grid = Grid(nodes=jnp.column_stack([rho, zero, zero]), sort=False, jitable=True)
iota = profiles["iota"].compute(grid)
theta = (alpha + iota * zeta) % (2 * jnp.pi)
theta = alpha + iota * zeta

yk = jnp.column_stack([rho, theta, zeta])
return yk
Expand Down Expand Up @@ -284,6 +289,7 @@ def _map_PEST_coordinates(
L_lmn,
L_basis,
guess,
period=np.inf,
tol=1e-6,
maxiter=30,
full_output=False,
Expand All @@ -304,6 +310,9 @@ def _map_PEST_coordinates(
guess : jnp.ndarray
Shape (k, ).
Optional initial guess for the computational coordinates.
period : float
Assumed periodicity for ϑ.
Use ``np.inf`` to denote no periodicity.
tol : float
Stopping tolerance.
maxiter : int
Expand All @@ -325,36 +334,25 @@ def _map_PEST_coordinates(
Only returned if ``full_output`` is True.
"""
rho, theta_PEST, zeta = coords.T
theta_PEST = theta_PEST % (2 * np.pi)
# Assume λ=0 for initial guess.
guess = setdefault(guess, theta_PEST)
# noqa: D202

# Root finding for θₖ such that r(θₖ) = ϑₖ(ρ, θₖ, ζ) − ϑ = 0.
def rootfun(theta_DESC, theta_PEST, rho, zeta):
nodes = jnp.array(
[rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()], ndmin=2
)
def rootfun(theta, theta_PEST, rho, zeta):
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A = L_basis.evaluate(nodes)
lmbda = A @ L_lmn
theta_PEST_k = (theta_DESC + lmbda) % (2 * np.pi)
r = theta_PEST_k - theta_PEST
# r should be between -pi and pi
r = jnp.where(r > np.pi, r - 2 * np.pi, r)
r = jnp.where(r < -np.pi, r + 2 * np.pi, r)
return r.squeeze()

def jacfun(theta_DESC, theta_PEST, rho, zeta):
# Valid everywhere except θ such that θ+λ = k 2π where k ∈ ℤ.
nodes = jnp.array(
[rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()], ndmin=2
)
theta_PEST_k = theta + lmbda
return _fixup_residual(theta_PEST_k - theta_PEST, period).squeeze()

def jacfun(theta, theta_PEST, rho, zeta):
# Valid everywhere except θ such that θ+λ = k period where k ∈ ℤ.
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A1 = L_basis.evaluate(nodes, (0, 1, 0))
lmbda_t = jnp.dot(A1, L_lmn)
return 1 + lmbda_t.squeeze()

def fixup(x, *args):
return x % (2 * np.pi)
return _periodic(x, period)

vecroot = jit(
vmap(
Expand All @@ -370,10 +368,15 @@ def fixup(x, *args):
)
)
)
theta_DESC, (res, niter) = vecroot(guess, theta_PEST, rho, zeta)

out = jnp.column_stack([rho, jnp.atleast_1d(theta_DESC.squeeze()), zeta])

rho, theta_PEST, zeta = coords.T
theta, (res, niter) = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
return out, (res, niter)
return out
Expand All @@ -386,6 +389,7 @@ def _map_clebsch_coordinates(
L_lmn,
L_basis,
guess=None,
period=np.inf,
tol=1e-6,
maxiter=30,
full_output=False,
Expand All @@ -409,6 +413,9 @@ def _map_clebsch_coordinates(
guess : jnp.ndarray
Shape (k, ).
Optional initial guess for the computational coordinates.
period : float
Assumed periodicity for α.
Use ``np.inf`` to denote no periodicity.
tol : float
Stopping tolerance.
maxiter : int
Expand All @@ -430,32 +437,25 @@ def _map_clebsch_coordinates(
Only returned if ``full_output`` is True.
"""
rho, alpha, zeta = coords.T
if guess is None:
# Assume λ=0 for initial guess.
guess = (alpha + iota * zeta) % (2 * np.pi)
# noqa: D202

# Root finding for θₖ such that r(θₖ) = αₖ(ρ, θₖ, ζ) − α = 0.
def rootfun(theta, alpha, rho, zeta, iota):
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A = L_basis.evaluate(nodes)
lmbda = A @ L_lmn
alpha_k = theta + lmbda - iota * zeta
r = (alpha_k - alpha) % (2 * np.pi)
# r should be between -pi and pi
r = jnp.where(r > np.pi, r - 2 * np.pi, r)
r = jnp.where(r < -np.pi, r + 2 * np.pi, r)
return r.squeeze()
return _fixup_residual(alpha_k - alpha, period).squeeze()

def jacfun(theta, alpha, rho, zeta, iota):
# Valid everywhere except θ such that θ+λ = k where k ∈ ℤ.
# Valid everywhere except θ such that θ+λ = k period where k ∈ ℤ.
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A1 = L_basis.evaluate(nodes, (0, 1, 0))
lmbda_t = jnp.dot(A1, L_lmn)
return 1 + lmbda_t.squeeze()

def fixup(x, *args):
return x % (2 * np.pi)
return _periodic(x, period)

vecroot = jit(
vmap(
Expand All @@ -471,9 +471,13 @@ def fixup(x, *args):
)
)
)
rho, alpha, zeta = coords.T
if guess is None:
# Assume λ=0 for default initial guess.
guess = alpha + iota * zeta
theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota)
out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])

out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
return out, (res, niter)
return out
Expand Down Expand Up @@ -662,7 +666,14 @@ def to_sfl(


def get_rtz_grid(
eq, radial, poloidal, toroidal, coordinates, period, jitable=True, **kwargs
eq,
radial,
poloidal,
toroidal,
coordinates,
period=(np.inf, np.inf, np.inf),
jitable=True,
**kwargs,
):
"""Return DESC grid in rtz (rho, theta, zeta) coordinates from given coordinates.
Expand All @@ -685,7 +696,7 @@ def get_rtz_grid(
rvp : rho, theta_PEST, phi
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for functions of the given coordinates.
Assumed periodicity of the given coordinates.
Use ``np.inf`` to denote no periodicity.
jitable : bool, optional
If false the returned grid has additional attributes.
Expand Down
2 changes: 1 addition & 1 deletion desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ def compute_theta_coords(
)
return map_coordinates(
self,
flux_coords,
coords=flux_coords,
inbasis=("rho", "theta_PEST", "zeta"),
outbasis=("rho", "theta", "zeta"),
params=self.params_dict if L_lmn is None else {"L_lmn": L_lmn},
Expand Down
2 changes: 2 additions & 0 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,8 @@ def _create_nodes( # noqa: C901
"""
self._NFP = check_posint(NFP, "NFP", False)
self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP)
# TODO:
# https://github.com/PlasmaControl/DESC/pull/1204#pullrequestreview-2246771337
axis = bool(axis)
endpoint = bool(endpoint)
theta_period = self.period[1]
Expand Down
Loading

0 comments on commit 66b402a

Please sign in to comment.