diff --git a/desc/integrals/basis.py b/desc/integrals/basis.py index dbcb7372a..8369ce583 100644 --- a/desc/integrals/basis.py +++ b/desc/integrals/basis.py @@ -372,7 +372,7 @@ def eval1d(self, z, cheb=None): ---------- z : jnp.ndarray Shape (..., *cheb.shape[:-2], z.shape[-1]). - Coordinates in [sef.domain[0], ∞). + Coordinates in [self.domain[0], ∞). The coordinates z ∈ ℝ are assumed isomorphic to (x, y) ∈ ℝ² where ``z // domain`` yields the index into the proper Chebyshev series along the second to last axis of ``cheb`` and ``z % domain`` is diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index 7aa2d425a..dadde2662 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -20,6 +20,7 @@ from desc.integrals.interp_utils import ( cheb_from_dct, cheb_pts, + idct_non_uniform, interp_rfft2, irfft2_non_uniform, polyder_vec, @@ -92,15 +93,17 @@ def _transform_to_clebsch(grid, nodes, f, is_reshaped=False): f = grid.meshgrid_reshape(f, "rtz") M, N = nodes.shape[-3], nodes.shape[-2] + nodes = nodes.reshape(*nodes.shape[:-3], M * N, 2) return FourierChebyshevSeries( f=interp_rfft2( # Interpolate to nodes in Clebsch space, # which is not a tensor product node set in DESC space. - xq=nodes.reshape(*nodes.shape[:-3], M * N, 2), + xq0=nodes[..., 0], + xq1=nodes[..., 1], f=f[..., jnp.newaxis, :, :], domain1=(0, 2 * jnp.pi / grid.NFP), axes=(-1, -2), - ).reshape(*nodes.shape[:-3], M, N), + ).reshape(*nodes.shape[:-2], M, N), domain=(0, 2 * jnp.pi), ) @@ -185,12 +188,12 @@ def _transform_to_clebsch_1d(grid, alpha, theta, B, N_B, is_reshaped=False): T = FourierChebyshevSeries(f=theta, domain=(0, 2 * jnp.pi)).compute_cheb(alpha) T.stitch() theta = T.evaluate(N_B) - xq = jnp.stack( - [theta, jnp.broadcast_to(cheb_pts(N_B, domain=T.domain), theta.shape)], axis=-1 - ).reshape(*alpha.shape[:-1], alpha.shape[-1] * N_B, 2) + zeta = jnp.broadcast_to(cheb_pts(N_B, domain=T.domain), theta.shape) + shape = (*alpha.shape[:-1], alpha.shape[-1] * N_B) B = interp_rfft2( - xq=xq, + theta.reshape(shape), + zeta.reshape(shape), f=B[..., jnp.newaxis, :, :], domain1=(0, 2 * jnp.pi / grid.NFP), axes=(-1, -2), @@ -501,8 +504,8 @@ def __init__( self._T, self._B = _transform_to_clebsch_1d( grid, alpha, theta, data["|B|"] / Bref, N_B, is_reshaped ) - self._b_sup_z = _transform_to_desc( - grid, jnp.abs(data["B^zeta"]) / data["|B|"] * Lref, is_reshaped + self._B_sup_z = _transform_to_desc( + grid, jnp.abs(data["B^zeta"]) * Lref / Bref, is_reshaped ) assert self._T.M == self._B.M == num_transit assert self._T.N == theta.shape[-1] @@ -743,24 +746,26 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot): Shape (num_pitch, ) or (num_pitch, L). f : list[jnp.ndarray] Shape (m, n) or (L, m, n). - f : list[jnp.ndarray] + f_vec : list[jnp.ndarray] Shape (m, n, 3) or (L, m, n, 3). """ z1, z2 = points shape = [*z1.shape, self._x.size] - # This is ζ along the field line. + # These are the ζ coordinates of the quadrature points. + # Shape is (num_pitch, L, number of points to interpolate onto). zeta = flatten_matrix( bijection_from_disc(self._x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]) ) # Note self._T expects shape (num_pitch, L) if T.cheb.shape[0] is L. - # These are the (θ, ζ) coordinates of the quadrature points. - Q = jnp.stack([self._T.eval1d(zeta), zeta], axis=-1) + # These are the θ coordinates of the quadrature points. + theta = self._T.eval1d(zeta) - b_sup_z = irfft2_non_uniform( - xq=Q, - a=self._b_sup_z[..., jnp.newaxis, :, :], + B_sup_z = irfft2_non_uniform( + theta, + zeta, + a=self._B_sup_z[..., jnp.newaxis, :, :], M=self._n, N=self._m, domain1=(0, 2 * jnp.pi / self._NFP), @@ -769,7 +774,8 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot): B = self._B.eval1d(zeta) f = [ interp_rfft2( - Q, + theta, + zeta, f_i[..., jnp.newaxis, :, :], domain1=(0, 2 * jnp.pi / self._NFP), axes=(-1, -2), @@ -778,7 +784,8 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot): ] f_vec = [ interp_rfft2( - Q[..., jnp.newaxis, :], + theta[..., jnp.newaxis], + zeta[..., jnp.newaxis], f_i[..., jnp.newaxis, :, :, :], domain1=(0, 2 * jnp.pi / self._NFP), axes=(-2, -3), @@ -794,7 +801,8 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot): pitch=1 / pitch_inv[..., jnp.newaxis], zeta=zeta, ) - / b_sup_z + * B + / B_sup_z ) .reshape(shape) .dot(self._w) @@ -806,13 +814,50 @@ def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot): _check_interp( # num_alpha is 1, num_rho, num_pitch, num_well, num_quad (1, *shape), - *map(_swap_pl, (zeta, b_sup_z, B)), + *map(_swap_pl, (zeta, B_sup_z, B)), result, list(map(_swap_pl, f)), plot, ) return result + def compute_length(self, quad=None): + """Compute the proper length of the field line ∫ dℓ / |B|. + + Parameters + ---------- + quad : tuple[jnp.ndarray] + Quadrature points xₖ and weights wₖ for the approximate evaluation + of the integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ). + Should not use more points than half Chebyshev resolution of |B|. + + Returns + ------- + length : jnp.ndarray + Shape (L, ). + + """ + # Gauss quadrature captures double frequency of Chebyshev series. + x, w = leggauss(self._B.N // 2) if quad is None else quad + + # TODO: There exits a fast transform from Chebyshev series to Legendre nodes. + theta = idct_non_uniform(x, self._T.cheb[..., jnp.newaxis, :], self._T.N) + zeta = jnp.broadcast_to(bijection_from_disc(x, 0, 2 * jnp.pi), theta.shape) + + shape = (-1, self._T.M * w.size) # (num_rho, num transit * w.size) + B_sup_z = irfft2_non_uniform( + theta.reshape(shape), + zeta.reshape(shape), + a=self._B_sup_z[..., jnp.newaxis, :, :], + M=self._n, + N=self._m, + domain1=(0, 2 * jnp.pi / self._NFP), + axes=(-1, -2), + ).reshape(-1, self._T.M, w.size) + + # Gradient of change of variable bijection from [−1, 1] → [0, 2π] is π. + return (1 / B_sup_z).dot(w).sum(axis=-1) * jnp.pi + def plot(self, l, pitch_inv=None, **kwargs): """Plot the field line and bounce points of the given pitch angles. diff --git a/desc/integrals/interp_utils.py b/desc/integrals/interp_utils.py index df10490d7..b600c6f84 100644 --- a/desc/integrals/interp_utils.py +++ b/desc/integrals/interp_utils.py @@ -218,26 +218,29 @@ def irfft_non_uniform(xq, a, n, domain=(0, 2 * jnp.pi), axis=-1): def interp_rfft2( - xq, f, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1) + xq0, xq1, f, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1) ): - """Interpolate real-valued ``f`` to ``xq`` with FFT. + """Interpolate real-valued ``f`` to coordinates ``(xq0,xq1)`` with FFT. Parameters ---------- - xq : jnp.ndarray - Shape (..., 2). - Real query points where interpolation is desired. - Shape ``xq.shape[:-1]`` must broadcast with shape ``np.delete(f.shape,axes)``. - Last axis must hold coordinates for a given point. The coordinates stored - along ``xq[...,0]`` (``xq[...,1]``) must be the same coordinate enumerated - across axis ``min(axes)`` (``max(axes)``) of the function values ``f``. + xq0 : jnp.ndarray + Real query points of coordinate in ``domain0`` where interpolation is desired. + Shape must broadcast with shape ``np.delete(a.shape,axes)``. + The coordinates stored here must be the same coordinate enumerated + across axis ``min(axes)`` of the function values ``f``. + xq1 : jnp.ndarray + Real query points of coordinate in ``domain1`` where interpolation is desired. + Shape must broadcast with shape ``np.delete(a.shape,axes)``. + The coordinates stored here must be the same coordinate enumerated + across axis ``max(axes)`` of the function values ``f``. f : jnp.ndarray Shape (..., f.shape[-2], f.shape[-1]). Real function values on uniform tensor-product grid over an open period. domain0 : tuple[float] - Domain of coordinate specified by ``xq[...,0]`` over which samples were taken. + Domain of coordinate specified by ``xq0`` over which samples were taken. domain1 : tuple[float] - Domain of coordinate specified by ``xq[...,1]`` over which samples were taken. + Domain of coordinate specified by ``xq1`` over which samples were taken. axes : tuple[int] Axes along which to transform. The real transform is done along ``axes[1]``, so it will be more @@ -251,26 +254,28 @@ def interp_rfft2( """ a = rfft2(f, axes=axes, norm="forward") fq = irfft2_non_uniform( - xq, a, f.shape[axes[0]], f.shape[axes[1]], domain0, domain1, axes + xq0, xq1, a, f.shape[axes[0]], f.shape[axes[1]], domain0, domain1, axes ) return fq def irfft2_non_uniform( - xq, a, M, N, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1) + xq0, xq1, a, M, N, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1) ): - """Evaluate Fourier coefficients ``a`` at ``xq``. + """Evaluate Fourier coefficients ``a`` at coordinates ``(xq0,xq1)``. Parameters ---------- - xq : jnp.ndarray - Shape (..., 2). - Real query points where interpolation is desired. - Last axis must hold coordinates for a given point. - Shape ``xq.shape[:-1]`` must broadcast with shape ``np.delete(a.shape,axes)``. - Last axis must hold coordinates for a given point. The coordinates stored - along ``xq[...,0]`` (``xq[...,1]``) must be the same coordinate enumerated - across axis ``min(axes)`` (``max(axes)``) of the Fourier coefficients ``a``. + xq0 : jnp.ndarray + Real query points of coordinate in ``domain0`` where interpolation is desired. + Shape must broadcast with shape ``np.delete(a.shape,axes)``. + The coordinates stored here must be the same coordinate enumerated + across axis ``min(axes)`` of the Fourier coefficients ``a``. + xq1 : jnp.ndarray + Real query points of coordinate in ``domain1`` where interpolation is desired. + Shape must broadcast with shape ``np.delete(a.shape,axes)``. + The coordinates stored here must be the same coordinate enumerated + across axis ``max(axes)`` of the Fourier coefficients ``a``. a : jnp.ndarray Shape (..., a.shape[-2], a.shape[-1]). Fourier coefficients ``a=rfft2(f,axes=axes,norm="forward")``. @@ -279,9 +284,9 @@ def irfft2_non_uniform( N : int Spectral resolution of ``a`` along ``axes[1]``. domain0 : tuple[float] - Domain of coordinate specified by ``xq[...,0]`` over which samples were taken. + Domain of coordinate specified by ``xq0`` over which samples were taken. domain1 : tuple[float] - Domain of coordinate specified by ``xq[...,1]`` over which samples were taken. + Domain of coordinate specified by ``xq1`` over which samples were taken. axes : tuple[int] Axes along which to transform. @@ -291,7 +296,7 @@ def irfft2_non_uniform( Real function value at query points. """ - errorif(not (len(axes) == xq.shape[-1] == 2), msg="This is a 2D transform.") + errorif(len(axes) != 2, msg="This is a 2D transform.") errorif(a.ndim < 2, msg=f"Dimension mismatch, a.shape: {a.shape}.") # |a| << |basis|, so move a instead of basis @@ -307,13 +312,15 @@ def irfft2_non_uniform( domain = (domain0, domain1) m = jnp.fft.fftfreq(M, d=np.diff(domain[idx[0]]) / (2 * jnp.pi) / M) n = jnp.fft.rfftfreq(N, d=np.diff(domain[idx[1]]) / (2 * jnp.pi) / N) - xq = xq - jnp.array([domain0[0], domain1[0]]) + xq0 = xq0 - domain0[0] + xq1 = xq1 - domain1[0] + xq = (xq0, xq1) basis = jnp.exp( 1j * ( - (m * xq[..., idx[0], jnp.newaxis])[..., jnp.newaxis] - + (n * xq[..., idx[1], jnp.newaxis])[..., jnp.newaxis, :] + (m * xq[idx[0]][..., jnp.newaxis])[..., jnp.newaxis] + + (n * xq[idx[1]][..., jnp.newaxis])[..., jnp.newaxis, :] ) ) fq = 2.0 * (basis * a).real.sum(axis=(-2, -1)) diff --git a/tests/test_integrals.py b/tests/test_integrals.py index 0aabb9fbe..a04b391c2 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -1625,6 +1625,26 @@ def test_bounce2d_checks(self): # 10. Plotting fig, ax = bounce.plot(l, pitch_inv[l], include_legend=False, show=False) + + length = bounce.compute_length() + np.testing.assert_allclose( + length, + # Computed through with Simpson's rule at 800 nodes (over-resolved). + # The difference is likely not due to quadrature error, rather interpolation + # error as the data points for ``bounce.compute_length()`` come from Fourier + # series of |B|, while those for come from Fourier series of + # plasma boundary. Also, currently JAX has a bug with DCT and FFT that limit + # the accuracy to something comparable to 32 bit. + [ + 384.77892007, + 361.60220181, + 345.33817065, + 333.00781712, + 352.16277188, + 440.09424799, + ], + rtol=3e-3, + ) return fig @staticmethod diff --git a/tests/test_interp_utils.py b/tests/test_interp_utils.py index fcc0ea90e..e0843acd7 100644 --- a/tests/test_interp_utils.py +++ b/tests/test_interp_utils.py @@ -244,18 +244,19 @@ def test_interp_rfft(self, func, n, domain): ) def test_interp_rfft2(self, func, m, n, domain0, domain1): """Test non-uniform FFT interpolation.""" - xq = np.array([[7.34, 1.10134, 2.28, 1e3 * np.e], [1.1, 3.78432, 8.542, 0]]).T + theta = np.array([7.34, 1.10134, 2.28, 1e3 * np.e]) + zeta = np.array([1.1, 3.78432, 8.542, 0]) x = np.linspace(domain0[0], domain0[1], m, endpoint=False) y = np.linspace(domain1[0], domain1[1], n, endpoint=False) x, y = map(np.ravel, list(np.meshgrid(x, y, indexing="ij"))) - truth = func(xq[..., 0], xq[..., 1]) + truth = func(theta, zeta) f = func(x, y).reshape(m, n) np.testing.assert_allclose( - interp_rfft2(xq, f, domain0, domain1, axes=(-2, -1)), + interp_rfft2(theta, zeta, f, domain0, domain1, axes=(-2, -1)), truth, ) np.testing.assert_allclose( - interp_rfft2(xq, f, domain0, domain1, axes=(-1, -2)), + interp_rfft2(theta, zeta, f, domain0, domain1, axes=(-1, -2)), truth, )