Skip to content

Commit

Permalink
Add ability to compute length along field line with fft
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Oct 6, 2024
1 parent 56dcbb2 commit 30d511e
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 52 deletions.
2 changes: 1 addition & 1 deletion desc/integrals/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def eval1d(self, z, cheb=None):
----------
z : jnp.ndarray
Shape (..., *cheb.shape[:-2], z.shape[-1]).
Coordinates in [sef.domain[0], ∞).
Coordinates in [self.domain[0], ∞).
The coordinates z ∈ ℝ are assumed isomorphic to (x, y) ∈ ℝ² where
``z // domain`` yields the index into the proper Chebyshev series
along the second to last axis of ``cheb`` and ``z % domain`` is
Expand Down
83 changes: 64 additions & 19 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from desc.integrals.interp_utils import (
cheb_from_dct,
cheb_pts,
idct_non_uniform,
interp_rfft2,
irfft2_non_uniform,
polyder_vec,
Expand Down Expand Up @@ -92,15 +93,17 @@ def _transform_to_clebsch(grid, nodes, f, is_reshaped=False):
f = grid.meshgrid_reshape(f, "rtz")

M, N = nodes.shape[-3], nodes.shape[-2]
nodes = nodes.reshape(*nodes.shape[:-3], M * N, 2)
return FourierChebyshevSeries(
f=interp_rfft2(
# Interpolate to nodes in Clebsch space,
# which is not a tensor product node set in DESC space.
xq=nodes.reshape(*nodes.shape[:-3], M * N, 2),
xq0=nodes[..., 0],
xq1=nodes[..., 1],
f=f[..., jnp.newaxis, :, :],
domain1=(0, 2 * jnp.pi / grid.NFP),
axes=(-1, -2),
).reshape(*nodes.shape[:-3], M, N),
).reshape(*nodes.shape[:-2], M, N),
domain=(0, 2 * jnp.pi),
)

Expand Down Expand Up @@ -185,12 +188,12 @@ def _transform_to_clebsch_1d(grid, alpha, theta, B, N_B, is_reshaped=False):
T = FourierChebyshevSeries(f=theta, domain=(0, 2 * jnp.pi)).compute_cheb(alpha)
T.stitch()
theta = T.evaluate(N_B)
xq = jnp.stack(
[theta, jnp.broadcast_to(cheb_pts(N_B, domain=T.domain), theta.shape)], axis=-1
).reshape(*alpha.shape[:-1], alpha.shape[-1] * N_B, 2)
zeta = jnp.broadcast_to(cheb_pts(N_B, domain=T.domain), theta.shape)

shape = (*alpha.shape[:-1], alpha.shape[-1] * N_B)
B = interp_rfft2(
xq=xq,
theta.reshape(shape),
zeta.reshape(shape),
f=B[..., jnp.newaxis, :, :],
domain1=(0, 2 * jnp.pi / grid.NFP),
axes=(-1, -2),
Expand Down Expand Up @@ -501,8 +504,8 @@ def __init__(
self._T, self._B = _transform_to_clebsch_1d(
grid, alpha, theta, data["|B|"] / Bref, N_B, is_reshaped
)
self._b_sup_z = _transform_to_desc(
grid, jnp.abs(data["B^zeta"]) / data["|B|"] * Lref, is_reshaped
self._B_sup_z = _transform_to_desc(
grid, jnp.abs(data["B^zeta"]) * Lref / Bref, is_reshaped
)
assert self._T.M == self._B.M == num_transit
assert self._T.N == theta.shape[-1]
Expand Down Expand Up @@ -743,24 +746,26 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot):
Shape (num_pitch, ) or (num_pitch, L).
f : list[jnp.ndarray]
Shape (m, n) or (L, m, n).
f : list[jnp.ndarray]
f_vec : list[jnp.ndarray]
Shape (m, n, 3) or (L, m, n, 3).
"""
z1, z2 = points
shape = [*z1.shape, self._x.size]

# This is ζ along the field line.
# These are the ζ coordinates of the quadrature points.
# Shape is (num_pitch, L, number of points to interpolate onto).
zeta = flatten_matrix(
bijection_from_disc(self._x, z1[..., jnp.newaxis], z2[..., jnp.newaxis])
)
# Note self._T expects shape (num_pitch, L) if T.cheb.shape[0] is L.
# These are the (θ, ζ) coordinates of the quadrature points.
Q = jnp.stack([self._T.eval1d(zeta), zeta], axis=-1)
# These are the θ coordinates of the quadrature points.
theta = self._T.eval1d(zeta)

b_sup_z = irfft2_non_uniform(
xq=Q,
a=self._b_sup_z[..., jnp.newaxis, :, :],
B_sup_z = irfft2_non_uniform(
theta,
zeta,
a=self._B_sup_z[..., jnp.newaxis, :, :],
M=self._n,
N=self._m,
domain1=(0, 2 * jnp.pi / self._NFP),
Expand All @@ -769,7 +774,8 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot):
B = self._B.eval1d(zeta)
f = [
interp_rfft2(
Q,
theta,
zeta,
f_i[..., jnp.newaxis, :, :],
domain1=(0, 2 * jnp.pi / self._NFP),
axes=(-1, -2),
Expand All @@ -778,7 +784,8 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot):
]
f_vec = [
interp_rfft2(
Q[..., jnp.newaxis, :],
theta[..., jnp.newaxis],
zeta[..., jnp.newaxis],
f_i[..., jnp.newaxis, :, :, :],
domain1=(0, 2 * jnp.pi / self._NFP),
axes=(-2, -3),
Expand All @@ -794,7 +801,8 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot):
pitch=1 / pitch_inv[..., jnp.newaxis],
zeta=zeta,
)
/ b_sup_z
* B
/ B_sup_z
)
.reshape(shape)
.dot(self._w)
Expand All @@ -806,13 +814,50 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot):
_check_interp(
# num_alpha is 1, num_rho, num_pitch, num_well, num_quad
(1, *shape),
*map(_swap_pl, (zeta, b_sup_z, B)),
*map(_swap_pl, (zeta, B_sup_z, B)),
result,
list(map(_swap_pl, f)),
plot,
)
return result

def compute_length(self, quad=None):
"""Compute the proper length of the field line ∫ dℓ / |B|.
Parameters
----------
quad : tuple[jnp.ndarray]
Quadrature points xₖ and weights wₖ for the approximate evaluation
of the integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ).
Should not use more points than half Chebyshev resolution of |B|.
Returns
-------
length : jnp.ndarray
Shape (L, ).
"""
# Gauss quadrature captures double frequency of Chebyshev series.
x, w = leggauss(self._B.N // 2) if quad is None else quad

# TODO: There exits a fast transform from Chebyshev series to Legendre nodes.
theta = idct_non_uniform(x, self._T.cheb[..., jnp.newaxis, :], self._T.N)
zeta = jnp.broadcast_to(bijection_from_disc(x, 0, 2 * jnp.pi), theta.shape)

shape = (-1, self._T.M * w.size) # (num_rho, num transit * w.size)
B_sup_z = irfft2_non_uniform(
theta.reshape(shape),
zeta.reshape(shape),
a=self._B_sup_z[..., jnp.newaxis, :, :],
M=self._n,
N=self._m,
domain1=(0, 2 * jnp.pi / self._NFP),
axes=(-1, -2),
).reshape(-1, self._T.M, w.size)

# Gradient of change of variable bijection from [−1, 1] → [0, 2π] is π.
return (1 / B_sup_z).dot(w).sum(axis=-1) * jnp.pi

def plot(self, l, pitch_inv=None, **kwargs):
"""Plot the field line and bounce points of the given pitch angles.
Expand Down
63 changes: 35 additions & 28 deletions desc/integrals/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,26 +218,29 @@ def irfft_non_uniform(xq, a, n, domain=(0, 2 * jnp.pi), axis=-1):


def interp_rfft2(
xq, f, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1)
xq0, xq1, f, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1)
):
"""Interpolate real-valued ``f`` to ``xq`` with FFT.
"""Interpolate real-valued ``f`` to coordinates ``(xq0,xq1)`` with FFT.
Parameters
----------
xq : jnp.ndarray
Shape (..., 2).
Real query points where interpolation is desired.
Shape ``xq.shape[:-1]`` must broadcast with shape ``np.delete(f.shape,axes)``.
Last axis must hold coordinates for a given point. The coordinates stored
along ``xq[...,0]`` (``xq[...,1]``) must be the same coordinate enumerated
across axis ``min(axes)`` (``max(axes)``) of the function values ``f``.
xq0 : jnp.ndarray
Real query points of coordinate in ``domain0`` where interpolation is desired.
Shape must broadcast with shape ``np.delete(a.shape,axes)``.
The coordinates stored here must be the same coordinate enumerated
across axis ``min(axes)`` of the function values ``f``.
xq1 : jnp.ndarray
Real query points of coordinate in ``domain1`` where interpolation is desired.
Shape must broadcast with shape ``np.delete(a.shape,axes)``.
The coordinates stored here must be the same coordinate enumerated
across axis ``max(axes)`` of the function values ``f``.
f : jnp.ndarray
Shape (..., f.shape[-2], f.shape[-1]).
Real function values on uniform tensor-product grid over an open period.
domain0 : tuple[float]
Domain of coordinate specified by ``xq[...,0]`` over which samples were taken.
Domain of coordinate specified by ``xq0`` over which samples were taken.
domain1 : tuple[float]
Domain of coordinate specified by ``xq[...,1]`` over which samples were taken.
Domain of coordinate specified by ``xq1`` over which samples were taken.
axes : tuple[int]
Axes along which to transform.
The real transform is done along ``axes[1]``, so it will be more
Expand All @@ -251,26 +254,28 @@ def interp_rfft2(
"""
a = rfft2(f, axes=axes, norm="forward")
fq = irfft2_non_uniform(
xq, a, f.shape[axes[0]], f.shape[axes[1]], domain0, domain1, axes
xq0, xq1, a, f.shape[axes[0]], f.shape[axes[1]], domain0, domain1, axes
)
return fq


def irfft2_non_uniform(
xq, a, M, N, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1)
xq0, xq1, a, M, N, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1)
):
"""Evaluate Fourier coefficients ``a`` at ``xq``.
"""Evaluate Fourier coefficients ``a`` at coordinates ``(xq0,xq1)``.
Parameters
----------
xq : jnp.ndarray
Shape (..., 2).
Real query points where interpolation is desired.
Last axis must hold coordinates for a given point.
Shape ``xq.shape[:-1]`` must broadcast with shape ``np.delete(a.shape,axes)``.
Last axis must hold coordinates for a given point. The coordinates stored
along ``xq[...,0]`` (``xq[...,1]``) must be the same coordinate enumerated
across axis ``min(axes)`` (``max(axes)``) of the Fourier coefficients ``a``.
xq0 : jnp.ndarray
Real query points of coordinate in ``domain0`` where interpolation is desired.
Shape must broadcast with shape ``np.delete(a.shape,axes)``.
The coordinates stored here must be the same coordinate enumerated
across axis ``min(axes)`` of the Fourier coefficients ``a``.
xq1 : jnp.ndarray
Real query points of coordinate in ``domain1`` where interpolation is desired.
Shape must broadcast with shape ``np.delete(a.shape,axes)``.
The coordinates stored here must be the same coordinate enumerated
across axis ``max(axes)`` of the Fourier coefficients ``a``.
a : jnp.ndarray
Shape (..., a.shape[-2], a.shape[-1]).
Fourier coefficients ``a=rfft2(f,axes=axes,norm="forward")``.
Expand All @@ -279,9 +284,9 @@ def irfft2_non_uniform(
N : int
Spectral resolution of ``a`` along ``axes[1]``.
domain0 : tuple[float]
Domain of coordinate specified by ``xq[...,0]`` over which samples were taken.
Domain of coordinate specified by ``xq0`` over which samples were taken.
domain1 : tuple[float]
Domain of coordinate specified by ``xq[...,1]`` over which samples were taken.
Domain of coordinate specified by ``xq1`` over which samples were taken.
axes : tuple[int]
Axes along which to transform.
Expand All @@ -291,7 +296,7 @@ def irfft2_non_uniform(
Real function value at query points.
"""
errorif(not (len(axes) == xq.shape[-1] == 2), msg="This is a 2D transform.")
errorif(len(axes) != 2, msg="This is a 2D transform.")
errorif(a.ndim < 2, msg=f"Dimension mismatch, a.shape: {a.shape}.")

# |a| << |basis|, so move a instead of basis
Expand All @@ -307,13 +312,15 @@ def irfft2_non_uniform(
domain = (domain0, domain1)
m = jnp.fft.fftfreq(M, d=np.diff(domain[idx[0]]) / (2 * jnp.pi) / M)
n = jnp.fft.rfftfreq(N, d=np.diff(domain[idx[1]]) / (2 * jnp.pi) / N)
xq = xq - jnp.array([domain0[0], domain1[0]])
xq0 = xq0 - domain0[0]
xq1 = xq1 - domain1[0]
xq = (xq0, xq1)

basis = jnp.exp(
1j
* (
(m * xq[..., idx[0], jnp.newaxis])[..., jnp.newaxis]
+ (n * xq[..., idx[1], jnp.newaxis])[..., jnp.newaxis, :]
(m * xq[idx[0]][..., jnp.newaxis])[..., jnp.newaxis]
+ (n * xq[idx[1]][..., jnp.newaxis])[..., jnp.newaxis, :]
)
)
fq = 2.0 * (basis * a).real.sum(axis=(-2, -1))
Expand Down
20 changes: 20 additions & 0 deletions tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,26 @@ def test_bounce2d_checks(self):

# 10. Plotting
fig, ax = bounce.plot(l, pitch_inv[l], include_legend=False, show=False)

length = bounce.compute_length()
np.testing.assert_allclose(
length,
# Computed through <L|r,a> with Simpson's rule at 800 nodes (over-resolved).
# The difference is likely not due to quadrature error, rather interpolation
# error as the data points for ``bounce.compute_length()`` come from Fourier
# series of |B|, while those for <L|r,a> come from Fourier series of
# plasma boundary. Also, currently JAX has a bug with DCT and FFT that limit
# the accuracy to something comparable to 32 bit.
[
384.77892007,
361.60220181,
345.33817065,
333.00781712,
352.16277188,
440.09424799,
],
rtol=3e-3,
)
return fig

@staticmethod
Expand Down
9 changes: 5 additions & 4 deletions tests/test_interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,18 +244,19 @@ def test_interp_rfft(self, func, n, domain):
)
def test_interp_rfft2(self, func, m, n, domain0, domain1):
"""Test non-uniform FFT interpolation."""
xq = np.array([[7.34, 1.10134, 2.28, 1e3 * np.e], [1.1, 3.78432, 8.542, 0]]).T
theta = np.array([7.34, 1.10134, 2.28, 1e3 * np.e])
zeta = np.array([1.1, 3.78432, 8.542, 0])
x = np.linspace(domain0[0], domain0[1], m, endpoint=False)
y = np.linspace(domain1[0], domain1[1], n, endpoint=False)
x, y = map(np.ravel, list(np.meshgrid(x, y, indexing="ij")))
truth = func(xq[..., 0], xq[..., 1])
truth = func(theta, zeta)
f = func(x, y).reshape(m, n)
np.testing.assert_allclose(
interp_rfft2(xq, f, domain0, domain1, axes=(-2, -1)),
interp_rfft2(theta, zeta, f, domain0, domain1, axes=(-2, -1)),
truth,
)
np.testing.assert_allclose(
interp_rfft2(xq, f, domain0, domain1, axes=(-1, -2)),
interp_rfft2(theta, zeta, f, domain0, domain1, axes=(-1, -2)),
truth,
)

Expand Down

0 comments on commit 30d511e

Please sign in to comment.