diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index 7a5c4d8cf..c14b94978 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -53,7 +53,9 @@ def check_points(self, points, pitch_inv, *, plot=True, **kwargs): """Check that bounce points are computed correctly.""" @abstractmethod - def integrate(self, integrand, pitch_inv, f=None, weight=None, points=None): + def integrate( + self, integrand, pitch_inv, f=None, weight=None, points=None, *, quad=None + ): """Bounce integrate ∫ f(λ, ℓ) dℓ.""" @@ -82,8 +84,9 @@ class Bounce2D(Bounce): the particle's guiding center trajectory traveling in the direction of increasing field-line-following coordinate ζ. - Brief description of algorithm - ------------------------------ + + Overview + -------- Magnetic field line with label α, defined by B = ∇ρ × ∇α, is determined from α : ρ, θ, ζ ↦ θ + λ(ρ,θ,ζ) − ι(ρ) [ζ + ω(ρ,θ,ζ)] Interpolate Fourier-Chebyshev series to DESC poloidal coordinate. @@ -102,8 +105,8 @@ class Bounce2D(Bounce): In that case, supply the single valued parts, which will be interpolated with FFTs, and use the provided coordinates θ,ζ ∈ ℝ to compose G. - Notes for developers - -------------------- + Notes + ----- For applications which reduce to computing a nonlinear function of distance along field lines between bounce points, it is required to identify these points with field-line-following coordinates. (In the special case of a linear @@ -234,8 +237,8 @@ class Bounce2D(Bounce): * 2D interpolation enables tracing the field line for many toroidal transits. * The drawback is that evaluating a Fourier series with resolution F at Q non-uniform quadrature points takes 𝒪([F+Q] log[F] log[1/ε]) time - whereas cubic splines take 𝒪(C Q) time. Still, F decreases as - NFP increases whereas C increases, and Q >> F and C. + whereas cubic splines take 𝒪(C Q) time. However, as NFP increases, + F decreases whereas C increases. Also, Q >> F and Q >> C. Attributes ---------- @@ -288,7 +291,7 @@ def __init__( theta : jnp.ndarray Shape (num rho, X, Y). DESC coordinates θ sourced from the Clebsch coordinates - ``FourierChebyshevSeries.nodes(M,N,L,domain=(0,2*jnp.pi))``. + ``FourierChebyshevSeries.nodes(X,Y,rho,domain=(0,2*jnp.pi))``. Y_B : int Desired Chebyshev spectral resolution for |B|. Default is to double the resolution of ``theta``. @@ -325,8 +328,11 @@ def __init__( Flag for debugging. Must be false for JAX transformations. spline : bool Whether to use cubic splines to compute bounce points. - Default is true. This is useful since the efficient root-finding - on Chebyshev series algorithm is not yet implemented. + Default is true, because the algorithm for efficient root-finding on + Chebyshev series algorithm is not yet implemented. + When using splines, it is recommended to reduce the ``num_well`` + parameter in the ``points`` method from ``3*Y_B*num_transit`` to + ``Y_B*num_transit``. """ errorif(grid.sym, NotImplementedError, msg="Need grid that works with FFTs.") @@ -413,9 +419,8 @@ def compute_theta(eq, X=16, Y=32, rho=1.0, clebsch=None, **kwargs): rho : float or jnp.ndarray Flux surfaces labels in [0, 1] on which to compute. clebsch : jnp.ndarray - Optional, Clebsch coordinate tensor-product grid (ρ, α, ζ). - ``FourierChebyshevSeries.nodes(M,N,L,domain=(0,2*jnp.pi))``. - If given, ``rho`` is ignored. + Optional, precomputed Clebsch coordinate tensor-product grid (ρ, α, ζ). + ``FourierChebyshevSeries.nodes(X,Y,rho,domain=(0,2*jnp.pi))``. kwargs Additional parameters to supply to the coordinate mapping function. See ``desc.equilibrium.Equilibrium.map_coordinates``. @@ -559,6 +564,7 @@ def integrate( *, check=False, plot=False, + quad=None, ): """Bounce integrate ∫ f(λ, ℓ) dℓ. @@ -607,6 +613,9 @@ def integrate( plot : bool Whether to plot the quantities in the integrand interpolated to the quadrature points of each integral. Ignored if ``check`` is false. + quad : tuple[jnp.ndarray] + Optional quadrature points and weights. If given this overrides + the quadrature chosen when this object was made. Returns ------- @@ -626,7 +635,15 @@ def integrate( pitch_inv = atleast_nd(self._c["T(z)"].cheb.ndim - 1, pitch_inv).T result = self._integrate( - integrand, pitch_inv, setdefault(f, []), z1, z2, check, plot + self._x if quad is None else quad[0], + self._w if quad is None else quad[1], + integrand, + pitch_inv, + setdefault(f, []), + z1, + z2, + check, + plot, ) if weight is not None: errorif( @@ -645,7 +662,7 @@ def integrate( ) return _swap_pl(result) - def _integrate(self, integrand, pitch_inv, f, z1, z2, check, plot): + def _integrate(self, x, w, integrand, pitch_inv, f, z1, z2, check, plot): """Bounce integrate ∫ f(λ, ℓ) dℓ. Parameters @@ -665,10 +682,10 @@ def _integrate(self, integrand, pitch_inv, f, z1, z2, check, plot): """ if not isinstance(f, (list, tuple)): f = [f] - shape = [*z1.shape, self._x.size] # num pitch, num rho, num well, num quad + shape = [*z1.shape, x.size] # num pitch, num rho, num well, num quad # ζ ∈ ℝ and θ ∈ ℝ coordinates of quadrature points zeta = flatten_matrix( - bijection_from_disc(self._x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]) + bijection_from_disc(x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]) ) theta = self._c["T(z)"].eval1d(zeta) @@ -708,7 +725,7 @@ def _integrate(self, integrand, pitch_inv, f, z1, z2, check, plot): integrand(*f, B=B, pitch=1 / pitch_inv[..., jnp.newaxis], zeta=zeta) * B / B_sup_z - ).reshape(shape).dot(self._w) * grad_bijection_from_disc(z1, z2) + ).reshape(shape).dot(w) * grad_bijection_from_disc(z1, z2) if check: shape[-3], shape[0] = shape[0], shape[-3] @@ -722,7 +739,7 @@ def _integrate(self, integrand, pitch_inv, f, z1, z2, check, plot): return result - def compute_length(self, quad=None): + def compute_fieldline_length(self, quad=None): """Compute the proper length of the field line ∫ dℓ / |B|. Parameters @@ -730,7 +747,7 @@ def compute_length(self, quad=None): quad : tuple[jnp.ndarray] Quadrature points xₖ and weights wₖ for the approximate evaluation of the integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ). - Number of points equal to half the Chebyshev resolution of |B| works well. + Resolution equal to half the Chebyshev resolution of |B| works well. Returns ------- @@ -876,8 +893,8 @@ class Bounce1D(Bounce): the particle's guiding center trajectory traveling in the direction of increasing field-line-following coordinate ζ. - Notes for developers - -------------------- + Notes + ----- For applications which reduce to computing a nonlinear function of distance along field lines between bounce points, it is required to identify these points with field-line-following coordinates. (In the special case of a linear diff --git a/tests/test_integrals.py b/tests/test_integrals.py index f8fec9b62..63f4d820f 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -1060,7 +1060,6 @@ def test_quad_compare(self, is_strong, B): """Compare quadratures in W-shaped wells.""" x = np.linspace(-1, 1, 1000) plt.plot(x, B(x)) - plt.show() def func(x): w1 = jnp.sqrt(jnp.clip(2 - B(x), 0, jnp.inf)) @@ -1069,7 +1068,6 @@ def func(x): return w1 plt.plot(x, func(x)) - plt.show() truth, info = quadax.quadts(func, interval=(-1, 1)) print("\n" + 50 * "---" + f"\nTrue value: {truth}, neval: {info[1]}") @@ -1760,8 +1758,8 @@ def test_bounce2d_checks(self): print("ρ:", rho[l]) np.testing.assert_allclose( - bounce.compute_length(), - # Computed data below through with Simpson's rule at 800 nodes. + bounce.compute_fieldline_length(), + # Computed below through "fieldline length" with Simpson's rule 800 points. # The difference is likely due to interpolation and floating point error. # (On the version of JAX on which rtol was set, there is a bug with DCT # and FFT that limit the accuracy to something comparable to 32 bit).