Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize boozer transform over multiple surfaces #1197

Merged
merged 15 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 87 additions & 27 deletions desc/compute/_omnigenity.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,38 @@
@register_compute_fun(
name="B_theta_mn",
label="B_{\\theta, m, n}",
units="T \\cdot m}",
units="T \\cdot m",
units_long="Tesla * meters",
description="Fourier coefficients for covariant poloidal component of "
"magnetic field.",
dim=1,
params=[],
transforms={"B": [[0, 0, 0]]},
transforms={"B": [[0, 0, 0]], "grid": []},
profiles=[],
coordinates="rtz",
data=["B_theta"],
resolution_requirement="tz",
grid_requirement={"is_meshgrid": True},
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
resolution_requirement="tz",
)
def _B_theta_mn(params, transforms, profiles, data, **kwargs):
data["B_theta_mn"] = transforms["B"].fit(data["B_theta"])
B_theta = transforms["grid"].meshgrid_reshape(data["B_theta"], "rtz")

def fitfun(x):
return transforms["B"].fit(x.flatten(order="F"))

B_theta_mn = vmap(fitfun)(B_theta)
# modes stored as shape(rho, mn) flattened
data["B_theta_mn"] = B_theta_mn.flatten()
return data


# TODO: do math to change definition of nu so that we can just use B_zeta_mn here
@register_compute_fun(
name="B_phi_mn",
label="B_{\\phi, m, n}",
units="T \\cdot m}",
units="T \\cdot m",
units_long="Tesla * meters",
description="Fourier coefficients for covariant toroidal component of "
"magnetic field in (ρ,θ,ϕ) coordinates.",
Expand All @@ -53,13 +61,21 @@ def _B_theta_mn(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["B_phi|r,t"],
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
resolution_requirement="tz",
grid_requirement={"is_meshgrid": True},
aliases="B_zeta_mn", # TODO: remove when phi != zeta
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
)
def _B_phi_mn(params, transforms, profiles, data, **kwargs):
data["B_phi_mn"] = transforms["B"].fit(data["B_phi|r,t"])
B_phi = transforms["grid"].meshgrid_reshape(data["B_phi|r,t"], "rtz")

def fitfun(x):
return transforms["B"].fit(x.flatten(order="F"))

B_zeta_mn = vmap(fitfun)(B_phi)
# modes stored as shape(rho, mn) flattened
data["B_phi_mn"] = B_zeta_mn.flatten()
return data


Expand All @@ -72,15 +88,16 @@ def _B_phi_mn(params, transforms, profiles, data, **kwargs):
+ "Boozer Coordinates'",
dim=1,
params=[],
transforms={"w": [[0, 0, 0]], "B": [[0, 0, 0]]},
transforms={"w": [[0, 0, 0]], "B": [[0, 0, 0]], "grid": []},
profiles=[],
coordinates="rtz",
data=["B_theta_mn", "B_phi_mn"],
grid_requirement={"is_meshgrid": True},
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
)
def _w_mn(params, transforms, profiles, data, **kwargs):
w_mn = jnp.zeros((transforms["w"].basis.num_modes,))
w_mn = jnp.zeros((transforms["grid"].num_rho, transforms["w"].basis.num_modes))
Bm = transforms["B"].basis.modes[:, 1]
Bn = transforms["B"].basis.modes[:, 2]
wm = transforms["w"].basis.modes[:, 1]
Expand All @@ -89,15 +106,19 @@ def _w_mn(params, transforms, profiles, data, **kwargs):
mask_t = (Bm[:, None] == -wm) & (Bn[:, None] == wn) & (wm != 0)
mask_z = (Bm[:, None] == wm) & (Bn[:, None] == -wn) & (wm == 0) & (wn != 0)

num_t = (mask_t @ sign(wn)) * data["B_theta_mn"]
num_t = (mask_t @ sign(wn)) * data["B_theta_mn"].reshape(
(transforms["grid"].num_rho, -1)
)
den_t = mask_t @ jnp.abs(wm)
num_z = (mask_z @ sign(wm)) * data["B_phi_mn"]
num_z = (mask_z @ sign(wm)) * data["B_phi_mn"].reshape(
(transforms["grid"].num_rho, -1)
)
den_z = mask_z @ jnp.abs(NFP * wn)

w_mn = jnp.where(mask_t.any(axis=0), mask_t.T @ safediv(num_t, den_t), w_mn)
w_mn = jnp.where(mask_z.any(axis=0), mask_z.T @ safediv(num_z, den_z), w_mn)
w_mn = jnp.where(mask_t.any(axis=0), (mask_t.T @ safediv(num_t, den_t).T).T, w_mn)
w_mn = jnp.where(mask_z.any(axis=0), (mask_z.T @ safediv(num_z, den_z).T).T, w_mn)

data["w_Boozer_mn"] = w_mn
data["w_Boozer_mn"] = w_mn.flatten()
return data


Expand All @@ -110,16 +131,22 @@ def _w_mn(params, transforms, profiles, data, **kwargs):
+ "'Transformation from VMEC to Boozer Coordinates'",
dim=1,
params=[],
transforms={"w": [[0, 0, 0]]},
transforms={"w": [[0, 0, 0]], "grid": []},
profiles=[],
coordinates="rtz",
data=["w_Boozer_mn"],
resolution_requirement="tz",
grid_requirement={"is_meshgrid": True},
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
)
def _w(params, transforms, profiles, data, **kwargs):
data["w_Boozer"] = transforms["w"].transform(data["w_Boozer_mn"])
grid = transforms["grid"]
w_mn = data["w_Boozer_mn"].reshape((grid.num_rho, -1))
w = vmap(transforms["w"].transform)(w_mn) # shape(rho, theta*zeta)
w = w.reshape((grid.num_rho, grid.num_theta, grid.num_zeta), order="F")
w = jnp.moveaxis(w, 0, 1)
data["w_Boozer"] = w.flatten(order="F")
return data


Expand All @@ -132,16 +159,24 @@ def _w(params, transforms, profiles, data, **kwargs):
+ "'Transformation from VMEC to Boozer Coordinates', poloidal derivative",
dim=1,
params=[],
transforms={"w": [[0, 1, 0]]},
transforms={"w": [[0, 1, 0]], "grid": []},
profiles=[],
coordinates="rtz",
data=["w_Boozer_mn"],
resolution_requirement="tz",
grid_requirement={"is_meshgrid": True},
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
)
def _w_t(params, transforms, profiles, data, **kwargs):
data["w_Boozer_t"] = transforms["w"].transform(data["w_Boozer_mn"], dt=1)
grid = transforms["grid"]
w_mn = data["w_Boozer_mn"].reshape((grid.num_rho, -1))
# need to close over dt which can't be vmapped
fun = lambda x: transforms["w"].transform(x, dt=1)
w_t = vmap(fun)(w_mn) # shape(rho, theta*zeta)
w_t = w_t.reshape((grid.num_rho, grid.num_theta, grid.num_zeta), order="F")
w_t = jnp.moveaxis(w_t, 0, 1)
data["w_Boozer_t"] = w_t.flatten(order="F")
return data


Expand All @@ -154,16 +189,24 @@ def _w_t(params, transforms, profiles, data, **kwargs):
+ "'Transformation from VMEC to Boozer Coordinates', toroidal derivative",
dim=1,
params=[],
transforms={"w": [[0, 0, 1]]},
transforms={"w": [[0, 0, 1]], "grid": []},
profiles=[],
coordinates="rtz",
data=["w_Boozer_mn"],
resolution_requirement="tz",
grid_requirement={"is_meshgrid": True},
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
)
def _w_z(params, transforms, profiles, data, **kwargs):
data["w_Boozer_z"] = transforms["w"].transform(data["w_Boozer_mn"], dz=1)
grid = transforms["grid"]
w_mn = data["w_Boozer_mn"].reshape((grid.num_rho, -1))
# need to close over dz which can't be vmapped
fun = lambda x: transforms["w"].transform(x, dz=1)
w_z = vmap(fun)(w_mn) # shape(rho, theta*zeta)
w_z = w_z.reshape((grid.num_rho, grid.num_theta, grid.num_zeta), order="F")
w_z = jnp.moveaxis(w_z, 0, 1)
data["w_Boozer_z"] = w_z.flatten(order="F")
return data


Expand Down Expand Up @@ -290,21 +333,38 @@ def _sqrtg_B(params, transforms, profiles, data, **kwargs):
description="Boozer harmonics of magnetic field",
dim=1,
params=[],
transforms={"B": [[0, 0, 0]]},
transforms={"B": [[0, 0, 0]], "grid": []},
profiles=[],
coordinates="rtz",
data=["sqrt(g)_B", "|B|", "rho", "theta_B", "zeta_B"],
resolution_requirement="tz",
grid_requirement={"is_meshgrid": True},
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
)
def _B_mn(params, transforms, profiles, data, **kwargs):
nodes = jnp.array([data["rho"], data["theta_B"], data["zeta_B"]]).T
norm = 2 ** (3 - jnp.sum((transforms["B"].basis.modes == 0), axis=1))
data["|B|_mn"] = (
norm # 1 if m=n=0, 2 if m=0 or n=0, 4 if m!=0 and n!=0
* (transforms["B"].basis.evaluate(nodes).T @ (data["sqrt(g)_B"] * data["|B|"]))
/ transforms["B"].grid.num_nodes
grid = transforms["grid"]

def fun(rho, theta_B, zeta_B, sqrtg_B, B):
# this fits Boozer modes on a single surface
nodes = jnp.array([rho, theta_B, zeta_B]).T
B_mn = (
norm # 1 if m=n=0, 2 if m=0 or n=0, 4 if m!=0 and n!=0
* (transforms["B"].basis.evaluate(nodes).T @ (sqrtg_B * B))
/ transforms["B"].grid.num_nodes
)
return B_mn

def reshape(x):
return grid.meshgrid_reshape(x, "rtz").reshape((grid.num_rho, -1))

rho, theta_B, zeta_B, sqrtg_B, B = map(
reshape,
(data["rho"], data["theta_B"], data["zeta_B"], data["sqrt(g)_B"], data["|B|"]),
)
B_mn = vmap(fun)(rho, theta_B, zeta_B, sqrtg_B, B)
data["|B|_mn"] = B_mn.flatten()
return data


Expand Down
9 changes: 9 additions & 0 deletions desc/compute/data_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def register_compute_fun( # noqa: C901
aliases=None,
parameterization="desc.equilibrium.equilibrium.Equilibrium",
resolution_requirement="",
grid_requirement=None,
source_grid_requirement=None,
**kwargs,
):
Expand Down Expand Up @@ -110,6 +111,11 @@ def register_compute_fun( # noqa: C901
If the computation simply performs pointwise operations, instead of a
reduction (such as integration) over a coordinate, then an empty string may
be used to indicate no requirements.
grid_requirement : dict
Attributes of the grid that the compute function requires.
Also assumes dependencies were computed on such a grid.
As an example, quantities that require tensor product grids over 2 or more
coordinates may specify ``grid_requirement={"is_meshgrid": True}``.
source_grid_requirement : dict
Attributes of the source grid that the compute function requires.
Also assumes dependencies were computed on such a grid.
Expand All @@ -130,6 +136,8 @@ def register_compute_fun( # noqa: C901
aliases = []
if source_grid_requirement is None:
source_grid_requirement = {}
if grid_requirement is None:
grid_requirement = {}
if not isinstance(parameterization, (tuple, list)):
parameterization = [parameterization]
if not isinstance(aliases, (tuple, list)):
Expand Down Expand Up @@ -168,6 +176,7 @@ def _decorator(func):
"dependencies": deps,
"aliases": aliases,
"resolution_requirement": resolution_requirement,
"grid_requirement": grid_requirement,
"source_grid_requirement": source_grid_requirement,
}
for p in parameterization:
Expand Down
32 changes: 29 additions & 3 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def _parse_parameterization(p):
return module + "." + klass.__qualname__


def compute(parameterization, names, params, transforms, profiles, data=None, **kwargs):
def compute( # noqa: C901
parameterization, names, params, transforms, profiles, data=None, **kwargs
):
"""Compute the quantity given by name on grid.

Parameters
Expand Down Expand Up @@ -88,6 +90,15 @@ def compute(parameterization, names, params, transforms, profiles, data=None, **
if "grid" in transforms:

def check_fun(name):
reqs = data_index[p][name]["grid_requirement"]
for req in reqs:
errorif(
not hasattr(transforms["grid"], req)
or reqs[req] != getattr(transforms["grid"], req),
AttributeError,
f"Expected grid with '{req}:{reqs[req]}' to compute {name}.",
)

reqs = data_index[p][name]["source_grid_requirement"]
errorif(
reqs and not hasattr(transforms["grid"], "source_grid"),
Expand Down Expand Up @@ -517,6 +528,7 @@ def get_transforms(

"""
from desc.basis import DoubleFourierSeries
from desc.grid import LinearGrid
from desc.transform import Transform

method = "jitable" if jitable or kwargs.get("method") == "jitable" else "auto"
Expand Down Expand Up @@ -556,8 +568,15 @@ def get_transforms(
)
transforms[c] = c_transform
elif c == "B": # used for Boozer transform
# assume grid is a meshgrid but only care about a single surface
if grid.num_rho > 1:
theta = grid.nodes[grid.unique_theta_idx, 1]
zeta = grid.nodes[grid.unique_zeta_idx, 2]
grid_B = LinearGrid(theta=theta, zeta=zeta, NFP=grid.NFP, sym=grid.sym)
else:
grid_B = grid
transforms["B"] = Transform(
grid,
grid_B,
DoubleFourierSeries(
M=kwargs.get("M_booz", 2 * obj.M),
N=kwargs.get("N_booz", 2 * obj.N),
Expand All @@ -570,8 +589,15 @@ def get_transforms(
method=method,
)
elif c == "w": # used for Boozer transform
# assume grid is a meshgrid but only care about a single surface
if grid.num_rho > 1:
theta = grid.nodes[grid.unique_theta_idx, 1]
zeta = grid.nodes[grid.unique_zeta_idx, 2]
grid_w = LinearGrid(theta=theta, zeta=zeta, NFP=grid.NFP, sym=grid.sym)
else:
grid_w = grid
transforms["w"] = Transform(
grid,
grid_w,
DoubleFourierSeries(
M=kwargs.get("M_booz", 2 * obj.M),
N=kwargs.get("N_booz", 2 * obj.N),
Expand Down
17 changes: 7 additions & 10 deletions desc/objectives/_omnigenity.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class QuasisymmetryBoozer(_Objective):
reverse mode and forward over reverse mode respectively.
grid : Grid, optional
Collocation grid containing the nodes to evaluate at.
Must be a LinearGrid with a single flux surface and sym=False.
Must be a LinearGrid with sym=False.
Defaults to ``LinearGrid(M=M_booz, N=N_booz)``.
helicity : tuple, optional
Type of quasi-symmetry (M, N). Default = quasi-axisymmetry (1, 0).
Expand Down Expand Up @@ -122,12 +122,6 @@ def build(self, use_jit=True, verbose=1):
grid = self._grid

errorif(grid.sym, ValueError, "QuasisymmetryBoozer grid must be non-symmetric")
errorif(
grid.num_rho != 1,
ValueError,
"QuasisymmetryBoozer grid must be on a single surface. "
"To target multiple surfaces, use multiple objectives.",
)
warnif(
grid.num_theta < 2 * eq.M,
RuntimeWarning,
Expand Down Expand Up @@ -195,7 +189,7 @@ def compute(self, params, constants=None):
Returns
-------
f : ndarray
Quasi-symmetry flux function error at each node (T^3).
Symmetry breaking harmonics of B (T).

"""
if constants is None:
Expand All @@ -207,8 +201,11 @@ def compute(self, params, constants=None):
transforms=constants["transforms"],
profiles=constants["profiles"],
)
B_mn = constants["matrix"] @ data["|B|_mn"]
return B_mn[constants["idx"]]
B_mn = data["|B|_mn"].reshape((constants["transforms"]["grid"].num_rho, -1))
B_mn = constants["matrix"] @ B_mn.T
# output order = (rho, mn).flatten(), ie all the surfaces concatenated
# one after the other
return B_mn[constants["idx"]].T.flatten()

@property
def helicity(self):
Expand Down
Loading
Loading