Skip to content

Commit

Permalink
Fix bug in Fourier bounce with interpolation of b_sup_z
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Aug 27, 2024
1 parent 04f87a3 commit 1a24a43
Show file tree
Hide file tree
Showing 15 changed files with 486 additions and 455 deletions.
2 changes: 1 addition & 1 deletion desc/compute/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.special import roots_legendre

from ..backend import fori_loop, jnp
from ..integrals import surface_averages_map
from ..integrals.surface_integral import surface_averages_map
from .data_index import register_compute_fun


Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from desc.backend import jnp

from ..integrals import surface_averages
from ..integrals.surface_integral import surface_averages
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import (
from ..integrals.surface_integral import (
surface_averages,
surface_integrals_map,
surface_max,
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import surface_averages
from ..integrals.surface_integral import surface_averages
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import cond, jnp

from ..integrals import surface_averages, surface_integrals
from ..integrals.surface_integral import surface_averages, surface_integrals
from .data_index import register_compute_fun
from .utils import cumtrapz, dot, safediv

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import surface_integrals_map
from ..integrals.surface_integral import surface_integrals_map
from .data_index import register_compute_fun
from .utils import dot

Expand Down
168 changes: 85 additions & 83 deletions desc/integrals/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,86 @@ def N(self):
"""Chebyshev spectral resolution."""
return self.cheb.shape[-1]

Check warning on line 320 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L320

Added line #L320 was not covered by tests

def isomorphism_to_C1(self, y):
"""Return coordinates z ∈ ℂ isomorphic to (x, y) ∈ ℂ².
Maps row x of y to z = y + f(x) where f(x) = x * |domain|.
Parameters
----------
y : jnp.ndarray
Shape (..., y.shape[-2], y.shape[-1]).
Second to last axis iterates the rows.
Returns
-------
z : jnp.ndarray
Shape y.shape.
Isomorphic coordinates.
"""
assert y.ndim >= 2
z_shift = jnp.arange(y.shape[-2]) * (self.domain[-1] - self.domain[0])
z = y + z_shift[:, jnp.newaxis]
return z

Check warning on line 343 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L340-L343

Added lines #L340 - L343 were not covered by tests

def isomorphism_to_C2(self, z):
"""Return coordinates (x, y) ∈ ℂ² isomorphic to z ∈ ℂ.
Returns index x and minimum value y such that
z = f(x) + y where f(x) = x * |domain|.
Parameters
----------
z : jnp.ndarray
Shape z.shape.
Returns
-------
x_idx, y_val : (jnp.ndarray, jnp.ndarray)
Shape z.shape.
Isomorphic coordinates.
"""
x_idx, y_val = jnp.divmod(z - self.domain[0], self.domain[-1] - self.domain[0])
x_idx = x_idx.astype(int)
y_val += self.domain[0]
return x_idx, y_val

Check warning on line 366 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L363-L366

Added lines #L363 - L366 were not covered by tests

def eval1d(self, z, cheb=None):
"""Evaluate piecewise Chebyshev series at coordinates z.
Parameters
----------
z : jnp.ndarray
Shape (..., *cheb.shape[:-2], z.shape[-1]).
Coordinates in [sef.domain[0], ∞).
The coordinates z ∈ ℝ are assumed isomorphic to (x, y) ∈ ℝ² where
``z // domain`` yields the index into the proper Chebyshev series
along the second to last axis of ``cheb`` and ``z % domain`` is
the coordinate value on the domain of that Chebyshev series.
cheb : jnp.ndarray
Shape (..., M, N).
Chebyshev coefficients to use. If not given, uses ``self.cheb``.
Returns
-------
f : jnp.ndarray
Shape z.shape.
Chebyshev basis evaluated at z.
"""
cheb = _chebcast(setdefault(cheb, self.cheb), z)
N = cheb.shape[-1]
x_idx, y = self.isomorphism_to_C2(z)
y = bijection_to_disc(y, self.domain[0], self.domain[1])

Check warning on line 394 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L391-L394

Added lines #L391 - L394 were not covered by tests
# Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z])
# are held in cheb with shape (..., num cheb series, N).
cheb = jnp.take_along_axis(cheb, x_idx[..., jnp.newaxis], axis=-2)
f = idct_non_uniform(y, cheb, N)
assert f.shape == z.shape
return f

Check warning on line 400 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L397-L400

Added lines #L397 - L400 were not covered by tests

def intersect2d(self, k=0.0, eps=_eps):
"""Coordinates yᵢ such that f(x, yᵢ) = k(x).
Expand Down Expand Up @@ -434,86 +514,6 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0):
z2 = jnp.where(mask, z2, pad_value)
return z1, z2

Check warning on line 515 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L513-L515

Added lines #L513 - L515 were not covered by tests

def eval1d(self, z, cheb=None):
"""Evaluate piecewise Chebyshev series at coordinates z.
Parameters
----------
z : jnp.ndarray
Shape (..., *cheb.shape[:-2], z.shape[-1]).
Coordinates in [sef.domain[0], ∞).
The coordinates z ∈ ℝ are assumed isomorphic to (x, y) ∈ ℝ² where
``z // domain`` yields the index into the proper Chebyshev series
along the second to last axis of ``cheb`` and ``z % domain`` is
the coordinate value on the domain of that Chebyshev series.
cheb : jnp.ndarray
Shape (..., M, N).
Chebyshev coefficients to use. If not given, uses ``self.cheb``.
Returns
-------
f : jnp.ndarray
Shape z.shape.
Chebyshev basis evaluated at z.
"""
cheb = _chebcast(setdefault(cheb, self.cheb), z)
N = cheb.shape[-1]
x_idx, y = self.isomorphism_to_C2(z)
y = bijection_to_disc(y, self.domain[0], self.domain[1])
# Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z])
# are held in cheb with shape (..., num cheb series, N).
cheb = jnp.take_along_axis(cheb, x_idx[..., jnp.newaxis], axis=-2)
f = idct_non_uniform(y, cheb, N)
assert f.shape == z.shape
return f

def isomorphism_to_C1(self, y):
"""Return coordinates z ∈ ℂ isomorphic to (x, y) ∈ ℂ².
Maps row x of y to z = y + f(x) where f(x) = x * |domain|.
Parameters
----------
y : jnp.ndarray
Shape (..., y.shape[-2], y.shape[-1]).
Second to last axis iterates the rows.
Returns
-------
z : jnp.ndarray
Shape y.shape.
Isomorphic coordinates.
"""
assert y.ndim >= 2
z_shift = jnp.arange(y.shape[-2]) * (self.domain[-1] - self.domain[0])
z = y + z_shift[:, jnp.newaxis]
return z

def isomorphism_to_C2(self, z):
"""Return coordinates (x, y) ∈ ℂ² isomorphic to z ∈ ℂ.
Returns index x and minimum value y such that
z = f(x) + y where f(x) = x * |domain|.
Parameters
----------
z : jnp.ndarray
Shape z.shape.
Returns
-------
x_idx, y_val : (jnp.ndarray, jnp.ndarray)
Shape z.shape.
Isomorphic coordinates.
"""
x_idx, y_val = jnp.divmod(z - self.domain[0], self.domain[-1] - self.domain[0])
x_idx = x_idx.astype(int)
y_val += self.domain[0]
return x_idx, y_val

def _check_shape(self, z1, z2, k):
"""Return shapes that broadcast with (k.shape[0], *self.cheb.shape[:-2], W)."""
# Ensure pitch batch dim exists and add back dim to broadcast with wells.
Expand All @@ -533,7 +533,9 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs):
z1, z2 : jnp.ndarray
Shape must broadcast with (*self.cheb.shape[:-2], W).
``z1``, ``z2`` holds intersects satisfying ∂f/∂y <= 0, ∂f/∂y >= 0,
respectively.
respectively. The points are grouped and ordered such that the
straight line path between the intersects in ``z1`` and ``z2``
resides in the epigraph of f.
k : jnp.ndarray
Shape must broadcast with *self.cheb.shape[:-2].
k such that fₓ(yᵢ) = k.
Expand Down Expand Up @@ -582,8 +584,8 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs):
assert not err_1[idx], "Intersects have an inversion.\n"
assert not err_2[idx], "Detected discontinuity.\n"
assert not err_3[idx], (

Check warning on line 586 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L582-L586

Added lines #L582 - L586 were not covered by tests
"Detected f > k in well, implying a path between z1 and z2 "
"is in hypograph(f). Increase Chebyshev resolution.\n"
"Detected f > k in well, implying the straight line path between "
"z1 and z2 is in hypograph(f). Increase spectral resolution.\n"
f"{f_midpoint[idx][mask[idx]]} > {k[idx] + self._eps}"
)
idx = (slice(None), *l)
Expand Down
52 changes: 39 additions & 13 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from interpax import CubicHermiteSpline
from orthax.legendre import leggauss

from desc.backend import jnp
from desc.backend import jnp, rfft2
from desc.integrals.basis import FourierChebyshevBasis
from desc.integrals.bounce_utils import (
_check_bounce_points,
Expand All @@ -13,12 +13,7 @@
interp_to_argmin_B_soft,
plot_ppoly,
)
from desc.integrals.interp_utils import (
interp_rfft2,
irfft2_non_uniform,
polyder_vec,
transform_to_desc,
)
from desc.integrals.interp_utils import interp_rfft2, irfft2_non_uniform, polyder_vec
from desc.integrals.quad_utils import (
automorphism_sin,
bijection_from_disc,
Expand Down Expand Up @@ -70,13 +65,38 @@ def _transform_to_clebsch(grid, desc_from_clebsch, M, N, B):
# which is not a tensor product node set in DESC space.
xq=desc_from_clebsch[:, 1:].reshape(grid.num_rho, -1, 2),
f=grid.meshgrid_reshape(B, order="rtz")[:, jnp.newaxis],
# Real fft over poloidal since usually num theta > num zeta.
axes=(-1, -2),
).reshape(grid.num_rho, M, N),
domain=Bounce2D.domain,
)
return T, B

Check warning on line 73 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L73

Added line #L73 was not covered by tests


def _transform_to_desc(grid, f):
"""Transform to DESC spectral domain.
Parameters
----------
grid : Grid
Tensor-product grid in (θ, ζ) with uniformly spaced nodes in
(2π × 2π) poloidal and toroidal coordinates.
f : jnp.ndarray
Function evaluated on ``grid``.
Returns
-------
a : jnp.ndarray
Shape (grid.num_rho, grid.num_theta // 2 + 1, grid.num_zeta)
Complex coefficients of 2D real FFT.
"""
f = grid.meshgrid_reshape(f, order="rtz")
a = rfft2(f, axes=(-1, -2), norm="forward")
assert a.shape == (grid.num_rho, grid.num_theta // 2 + 1, grid.num_zeta)
return a

Check warning on line 97 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L94-L97

Added lines #L94 - L97 were not covered by tests


# TODO:
# After GitHub issue #1034 is resolved, we should pass in the previous
# θ(α) coordinates as an initial guess for the next coordinate mapping.
Expand Down Expand Up @@ -292,14 +312,20 @@ def __init__(
)
self._m = grid.num_theta
self._n = grid.num_zeta
self._b_sup_z = jnp.expand_dims(
transform_to_desc(grid, jnp.abs(data["B^zeta"]) / data["|B|"] * Lref),
axis=1,
)
self._x, self._w = get_quadrature(quad, automorphism)

Check warning on line 315 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L313-L315

Added lines #L313 - L315 were not covered by tests

# Compute global splines.
T, B = _transform_to_clebsch(grid, desc_from_clebsch, M, N, data["|B|"] / Bref)
self._b_sup_z = _transform_to_desc(

Check warning on line 318 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L318

Added line #L318 was not covered by tests
grid,
jnp.abs(data["B^zeta"]) / data["|B|"] * Lref,
)[:, jnp.newaxis]
T, B = _transform_to_clebsch(

Check warning on line 322 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L322

Added line #L322 was not covered by tests
grid,
desc_from_clebsch,
M,
N,
data["|B|"] / Bref,
)
# peel off field lines
alphas = get_alpha(

Check warning on line 330 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L330

Added line #L330 was not covered by tests
alpha_0,
Expand Down Expand Up @@ -535,7 +561,7 @@ def _integrate(self, z1, z2, pitch, integrand, f):
pitch=pitch[..., jnp.newaxis, jnp.newaxis],
)
/ irfft2_non_uniform(
Q, self._b_sup_z, self._m, self._n, axes=(-1, -2)
xq=Q, a=self._b_sup_z, M=self._n, N=self._m, axes=(-1, -2)
).reshape(shape),
self._w,
)
Expand Down
4 changes: 2 additions & 2 deletions desc/integrals/bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ def _check_bounce_points(z1, z2, pitch, knots, B, plot=True, **kwargs):
assert not err_2[p, s], "Detected discontinuity.\n"
assert not err_3, (

Check warning on line 279 in desc/integrals/bounce_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_utils.py#L275-L279

Added lines #L275 - L279 were not covered by tests
f"Detected |B| = {Bs_midpoint[mask[p, s]]} > {1 / pitch[p, s] + eps} "
"= 1/λ in well, implying the straight line path between bounce points "
"is in hypograph(|B|). Use more knots.\n"
"= 1/λ in well, implying the straight line path between "
"bounce points is in hypograph(|B|). Use more knots.\n"
)
if plot:
plot_ppoly(

Check warning on line 285 in desc/integrals/bounce_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_utils.py#L284-L285

Added lines #L284 - L285 were not covered by tests
Expand Down
27 changes: 1 addition & 26 deletions desc/integrals/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)):
M : int
Spectral resolution of ``a`` along ``axes[0]``.
N : int
Spectral resolution of ``a`` along ``axes[-1]``.
Spectral resolution of ``a`` along ``axes[1]``.
axes : tuple[int, int]
Axes along which to transform.
Expand Down Expand Up @@ -295,31 +295,6 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)):
return fq

Check warning on line 295 in desc/integrals/interp_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/interp_utils.py#L294-L295

Added lines #L294 - L295 were not covered by tests


def transform_to_desc(grid, f):
"""Transform to DESC spectral domain.
Parameters
----------
grid : Grid
Tensor-product grid in (θ, ζ) with uniformly spaced nodes in
(2π × 2π) poloidal and toroidal coordinates.
f : jnp.ndarray
Function evaluated on ``grid``.
Returns
-------
a : jnp.ndarray
Shape (grid.num_rho, grid.num_theta // 2 + 1, grid.num_zeta)
Complex coefficients of 2D real FFT.
"""
f = grid.meshgrid_reshape(f, order="rtz")
# Real fft done over poloidal since num_theta > num_zeta usually.
a = rfft2(f, axes=(-1, -2), norm="forward")
assert a.shape == (grid.num_rho, grid.num_theta // 2 + 1, grid.num_zeta)
return a


def cheb_from_dct(a, axis=-1):
"""Get discrete Chebyshev transform from discrete cosine transform.
Expand Down
Loading

0 comments on commit 1a24a43

Please sign in to comment.