Skip to content

Commit

Permalink
Updating epsilon and gamma objectives with Bounce2D
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Oct 20, 2024
1 parent 3227c8f commit 8a8e8f2
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 175 deletions.
12 changes: 6 additions & 6 deletions desc/compute/_neoclassical.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""Compute functions for neoclassical transport.
Performance will improve significantly by resolving these GitHub issues.
* ``1154`` Improve coordinate mapping performance
* ``1294`` Nonuniform fast transforms
* ``1303`` Patch for differentiable code with dynamic shapes
* ``1206`` Upsample data above midplane to full grid assuming stellarator symmetry
* ``1034`` Optimizers/objectives with auxilary output
* ``1154`` Improve coordinate mapping performance
* ``1294`` Nonuniform fast transforms
* ``1303`` Patch for differentiable code with dynamic shapes
* ``1206`` Upsample data above midplane to full grid assuming stellarator symmetry
* ``1034`` Optimizers/objectives with auxiliary output
If memory is still an issue, consider computing one pitch at a time. This
can be done by copy-pasting the code given at
https://github.com/PlasmaControl/DESC/pull/1003#discussion_r1780459450.
Note that imap supports computing in batches, so that can also be used.
Make sure to benchmark whether this reduces memory in an optimization.
"""

from functools import partial
Expand Down
12 changes: 9 additions & 3 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ def map_coordinates( # noqa: C901
f"don't have recipe to compute partial derivative {key}",
)

profiles = get_profiles(inbasis + basis_derivs, eq)
profiles = (
kwargs["profiles"]
if "profiles" in kwargs
else get_profiles(inbasis + basis_derivs, eq)
)

# TODO: make this work for permutations of in/out basis
if outbasis == ("rho", "theta", "zeta"):
Expand All @@ -114,7 +118,9 @@ def map_coordinates( # noqa: C901
iota = kwargs.pop("iota")
else:
if profiles["iota"] is None:
profiles["iota"] = eq.get_profile(["iota", "iota_r"], params=params)
profiles["iota"] = eq.get_profile(
["iota", "iota_r"], params=params, **kwargs
)
iota = profiles["iota"].compute(Grid(coords, sort=False, jitable=True))
return _map_clebsch_coordinates(
coords=coords,
Expand Down Expand Up @@ -143,7 +149,7 @@ def map_coordinates( # noqa: C901

# do surface average to get iota once
if "iota" in profiles and profiles["iota"] is None:
profiles["iota"] = eq.get_profile(["iota", "iota_r"], params=params)
profiles["iota"] = eq.get_profile(["iota", "iota_r"], params=params, **kwargs)
params["i_l"] = profiles["iota"].params

rhomin = kwargs.pop("rhomin", tol / 10)
Expand Down
10 changes: 5 additions & 5 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,13 +909,13 @@ def need_src(name):
# the compute logic assume input data is evaluated on those coordinates.
# We exclude these from the depXdx sets below since the grids we will
# use to compute those dependencies are coordinate-blind.
# Example, "<L|r,a>" has coordinates="r", but requires computing on
# field line following source grid.
# Example, "fieldline length" has coordinates="r", but requires computing
# on field line following source grid.
return bool(data_index[p][name]["source_grid_requirement"])

# Need to call _grow_seeds so that some other quantity like K = 2 * <L|r,a>,
# which does not need a source grid to evaluate, does not compute <L|r,a> on a
# grid that does not follow field lines.
# Need to call _grow_seeds so that e.g. "effective ripple*" which does not
# need a source grid to evaluate, still computes "effective ripple 3/2*"
# on a grid whose source grid follows field lines.
# Maybe this can help explain:
# https://github.com/PlasmaControl/DESC/pull/1024#discussion_r1664918897.
need_src_deps = _grow_seeds(p, set(filter(need_src, deps)), deps)
Expand Down
19 changes: 11 additions & 8 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,13 @@ def meshgrid_reshape(self, x, order):
x = jnp.transpose(x, newax)
return x

def to_numpy(self):
"""Convert all jax array attributes to numpy arrays."""
for attr in self.__dict__:
value = getattr(self, attr)
if isinstance(value, jnp.ndarray):
setattr(self, attr, np.array(value))


class Grid(_Grid):
"""Collocation grid with custom node placement.
Expand Down Expand Up @@ -808,9 +815,9 @@ def create_meshgrid(
a, b, c = jnp.atleast_1d(*nodes)
if spacing is None:
errorif(coordinates[0] != "r", NotImplementedError)
da = _midpoint_spacing(a)
db = _periodic_spacing(b, period[1])[1]
dc = _periodic_spacing(c, period[2])[1] * NFP
da = _midpoint_spacing(a, jnp=jnp)
db = _periodic_spacing(b, period[1], jnp=jnp)[1]
dc = _periodic_spacing(c, period[2], jnp=jnp)[1] * NFP
else:
da, db, dc = spacing

Expand Down Expand Up @@ -839,10 +846,7 @@ def create_meshgrid(
repeat(unique_a_idx // b.size, b.size, total_repeat_length=a.size * b.size),
c.size,
)
inverse_b_idx = jnp.tile(
unique_b_idx,
a.size * c.size,
)
inverse_b_idx = jnp.tile(unique_b_idx, a.size * c.size)
inverse_c_idx = repeat(unique_c_idx // (a.size * b.size), (a.size * b.size))
return Grid(
nodes=nodes,
Expand All @@ -853,7 +857,6 @@ def create_meshgrid(
NFP=NFP,
sort=False,
is_meshgrid=True,
jitable=True,
_unique_rho_idx=unique_a_idx,
_unique_poloidal_idx=unique_b_idx,
_unique_zeta_idx=unique_c_idx,
Expand Down
11 changes: 10 additions & 1 deletion desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def reshape_data(grid, *arys):
# θ(α, ζ) since these are related to lambda.

@staticmethod
def compute_theta(eq, X=16, Y=32, rho=1.0, clebsch=None, **kwargs):
def compute_theta(eq, X=16, Y=32, rho=1.0, iota=None, clebsch=None, **kwargs):
"""Return DESC coordinates θ of (α,ζ) Fourier Chebyshev basis nodes.
Parameters
Expand All @@ -417,10 +417,16 @@ def compute_theta(eq, X=16, Y=32, rho=1.0, clebsch=None, **kwargs):
Grid resolution in toroidal direction for Clebsch coordinate grid.
Preferably power of 2.
rho : float or jnp.ndarray
Shape (num rho, ).
Flux surfaces labels in [0, 1] on which to compute.
iota : float or jnp.ndarray
Shape (num rho, ).
Optional, rotational transform on the flux surfaces to compute on.
clebsch : jnp.ndarray
Shape (num rho * X * Y, 3).
Optional, precomputed Clebsch coordinate tensor-product grid (ρ, α, ζ).
``FourierChebyshevSeries.nodes(X,Y,rho,domain=(0,2*jnp.pi))``.
If supplied ``rho`` is ignored.
kwargs
Additional parameters to supply to the coordinate mapping function.
See ``desc.equilibrium.Equilibrium.map_coordinates``.
Expand All @@ -435,6 +441,9 @@ def compute_theta(eq, X=16, Y=32, rho=1.0, clebsch=None, **kwargs):
"""
if clebsch is None:
clebsch = FourierChebyshevSeries.nodes(X, Y, rho, domain=(0, 2 * jnp.pi))
if iota is not None:
iota = jnp.atleast_1d(iota)
kwargs["iota"] = jnp.broadcast_to(iota, shape=(Y, X, iota.size)).T.ravel()
return eq.map_coordinates(
coords=clebsch,
inbasis=("rho", "alpha", "zeta"),
Expand Down
Loading

0 comments on commit 8a8e8f2

Please sign in to comment.