Skip to content

Commit

Permalink
Merge branch 'ku/fourier_bounce' into ku/fourier_bounce_neo
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Oct 17, 2024
2 parents fd723a6 + aedb90f commit 8b5cdb8
Show file tree
Hide file tree
Showing 11 changed files with 1,025 additions and 447 deletions.
4 changes: 2 additions & 2 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
from jax.lax import cond, fori_loop, scan, switch, while_loop
from jax.nn import softmax as softargmax
from jax.numpy import bincount, flatnonzero, repeat, take
from jax.numpy.fft import irfft, rfft, rfft2
from jax.numpy.fft import ifft, irfft, irfft2, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.scipy.special import gammaln, logsumexp
Expand Down Expand Up @@ -407,7 +407,7 @@ def tangent_solve(g, y):
jit = lambda func, *args, **kwargs: func
execute_on_cpu = lambda func: func
import scipy.optimize
from numpy.fft import irfft, rfft, rfft2 # noqa: F401
from numpy.fft import ifft, irfft, irfft2, rfft, rfft2 # noqa: F401
from scipy.fft import dct, idct # noqa: F401
from scipy.integrate import odeint # noqa: F401
from scipy.linalg import ( # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_basis_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3285,7 +3285,7 @@ def _periodic_grad_alpha(params, transforms, profiles, data, **kwargs):
units_long="Inverse meters",
description=(
"Gradient of field line label, which is perpendicular to the field line, "
"periodic component"
"secular component"
),
dim=3,
params=[],
Expand Down
6 changes: 0 additions & 6 deletions desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,9 +1961,6 @@ def _gbdrift(params, transforms, profiles, data, **kwargs):

@register_compute_fun(
name="periodic(gbdrift)",
# Exact definition of the magnetic drifts taken from
# eqn. 48 of Introduction to Quasisymmetry by Landreman
# https://tinyurl.com/54udvaa4
label="\\mathrm{periodic}(\\nabla \\vert B \\vert)_{\\mathrm{drift}}",
units="1/(T-m^{2})",
units_long="inverse Tesla meters^2",
Expand All @@ -1987,9 +1984,6 @@ def _periodic_gbdrift(params, transforms, profiles, data, **kwargs):

@register_compute_fun(
name="secular(gbdrift)",
# Exact definition of the magnetic drifts taken from
# eqn. 48 of Introduction to Quasisymmetry by Landreman
# https://tinyurl.com/54udvaa4
label="\\mathrm{secular}(\\nabla \\vert B \\vert)_{\\mathrm{drift}}",
units="1/(T-m^{2})",
units_long="inverse Tesla meters^2",
Expand Down
145 changes: 90 additions & 55 deletions desc/integrals/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,34 @@
)


# TODO: Generalize this beyond ζ = ϕ or just map to Clebsch with ϕ.
def get_alpha(alpha_0, iota, num_transit, period):
"""Get sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) of field line.
Parameters
----------
alpha_0 : float
Starting field line poloidal label.
iota : jnp.ndarray
Shape (iota.size, ).
Rotational transform normalized by 2π.
num_transit : float
Number of ``period``s to follow field line.
period : float
Toroidal period after which to update label.
Returns
-------
alpha : jnp.ndarray
Shape (iota.size, num_transit).
Sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) that specify field line.
"""
# Δϕ (∂α/∂ϕ) = Δϕ ι̅ = Δϕ ι/2π = Δϕ data["iota"]
alpha = alpha_0 + period * jnp.expand_dims(iota, -1) * jnp.arange(num_transit)
return alpha


@partial(jnp.vectorize, signature="(m),(m)->(m)")
def _in_epigraph_and(is_intersect, df_dy_sign, /):
"""Set and epigraph of function f with the given set of points.
Expand Down Expand Up @@ -101,12 +129,24 @@ def _chebcast(cheb, arr):


class FourierChebyshevSeries(IOAble):
"""Fourier-Chebyshev series.
"""Real-valued Fourier-Chebyshev series.
f(x, y) = ∑ₘₙ aₘₙ ψₘ(x) Tₙ(y)
where ψₘ are trigonometric polynomials on [0, 2π]
and Tₙ are Chebyshev polynomials on [−yₘᵢₙ, yₘₐₓ].
Examples
--------
Let the magnetic field be B = ∇ρ × ∇x. This basis will then parameterize
maps in Clebsch coordinates. Passing in a sequence of x values tracking
the field line (see ``get_alpha``) to the ``compute_cheb`` method will
generate a 1D parameterization of f along the field line.
This is useful to interpolate f ≝ θ and use the map x, ζ ↦ θ(x, ζ) to
compute quantities along field lines via evaluating Fourier series
parameterized in DESC computational coordinates θ, ζ, where the Fourier
transform is more condensed when NFP > 1.
Notes
-----
Performance may improve significantly
Expand All @@ -126,46 +166,41 @@ class FourierChebyshevSeries(IOAble):
Attributes
----------
M : int
X : int
Fourier spectral resolution.
N : int
Y : int
Chebyshev spectral resolution.
"""

def __init__(self, f, domain=(-1, 1), lobatto=False):
"""Interpolate Fourier-Chebyshev series to ``f``."""
self.M = f.shape[-2]
self.N = f.shape[-1]
self.X = f.shape[-2]
self.Y = f.shape[-1]
errorif(domain[0] > domain[-1], msg="Got inverted domain.")
self.domain = tuple(domain)
errorif(lobatto, NotImplementedError, "JAX hasn't implemented type 1 DCT.")
self.lobatto = bool(lobatto)
self._c = FourierChebyshevSeries._transform(f, self.lobatto)

@staticmethod
def _transform(f, lobatto):
N = f.shape[-1]
return rfft(
dct(f, type=2 - lobatto, axis=-1) / (N - lobatto),
self._c = rfft(
dct(f, type=2 - lobatto, axis=-1) / (self.Y - lobatto),
axis=-2,
norm="forward",
)

@staticmethod
def nodes(M, N, L=None, domain=(-1, 1), lobatto=False):
def nodes(X, Y, L=None, domain=(-1, 1), lobatto=False):
"""Tensor product grid of optimal collocation nodes for this basis.
Parameters
----------
M : int
X : int
Grid resolution in x direction. Preferably power of 2.
N : int
Y : int
Grid resolution in y direction. Preferably power of 2.
L : int or jnp.ndarray
Optional, resolution in radial direction of domain [0, 1].
May also be an array of coordinates values. If given, then the
returned ``coords`` is a 3D tensor-product with shape (L * M * N, 3).
returned ``coords`` is a 3D tensor-product with shape (L * X * Y, 3).
domain : tuple[float]
Domain for y coordinates. Default is [-1, 1].
lobatto : bool
Expand All @@ -175,12 +210,12 @@ def nodes(M, N, L=None, domain=(-1, 1), lobatto=False):
Returns
-------
coords : jnp.ndarray
Shape (M * N, 2).
Shape (X * Y, 2).
Grid of (x, y) points for optimal interpolation.
"""
x = fourier_pts(M)
y = cheb_pts(N, domain, lobatto)
x = fourier_pts(X)
y = cheb_pts(Y, domain, lobatto)
if L is None:
coords = (x, y)
else:
Expand All @@ -190,30 +225,30 @@ def nodes(M, N, L=None, domain=(-1, 1), lobatto=False):
coords = tuple(map(jnp.ravel, jnp.meshgrid(*coords, indexing="ij")))
return jnp.column_stack(coords)

def evaluate(self, M, N):
"""Evaluate Fourier-Chebyshev series.
def evaluate(self, X, Y):
"""Evaluate Fourier-Chebyshev series on tensor-product grid.
Parameters
----------
M : int
X : int
Grid resolution in x direction. Preferably power of 2.
N : int
Y : int
Grid resolution in y direction. Preferably power of 2.
Returns
-------
fq : jnp.ndarray
Shape (..., M, N)
Shape (..., X, Y)
Fourier-Chebyshev series evaluated at
``FourierChebyshevSeries.nodes(M,N,L,self.domain,self.lobatto)``.
``FourierChebyshevSeries.nodes(X,Y,L,self.domain,self.lobatto)``.
"""
return idct(
irfft(self._c, n=M, axis=-2, norm="forward"),
irfft(self._c, n=X, axis=-2, norm="forward"),
type=2 - self.lobatto,
n=N,
n=Y,
axis=-1,
) * (N - self.lobatto)
) * (Y - self.lobatto)

def harmonics(self):
"""Spectral coefficients aₘₙ of the interpolating trigonometric polynomial.
Expand All @@ -224,12 +259,12 @@ def harmonics(self):
Returns
-------
a_mn : jnp.ndarray
Shape (..., M, N).
Shape (..., X, Y).
Real valued spectral coefficients for Fourier-Chebyshev series.
"""
a_mn = harmonic(cheb_from_dct(self._c), self.M, axis=-2)
assert a_mn.shape[-2:] == (self.M, self.N)
a_mn = harmonic(cheb_from_dct(self._c), self.X, axis=-2)
assert a_mn.shape[-2:] == (self.X, self.Y)
return a_mn

def compute_cheb(self, x):
Expand All @@ -250,9 +285,9 @@ def compute_cheb(self, x):
x = jnp.atleast_1d(x)[..., jnp.newaxis]
# Add axis to broadcast against multiple x values.
cheb = cheb_from_dct(
irfft_non_uniform(x, self._c[..., jnp.newaxis, :, :], self.M, axis=-2)
irfft_non_uniform(x, self._c[..., jnp.newaxis, :, :], self.X, axis=-2)
)
assert cheb.shape[-2:] == (x.shape[-2], self.N)
assert cheb.shape[-2:] == (x.shape[-2], self.Y)
return PiecewiseChebyshevSeries(cheb, self.domain)


Expand All @@ -265,7 +300,7 @@ class PiecewiseChebyshevSeries(IOAble):
Parameters
----------
cheb : jnp.ndarray
Shape (..., M, N).
Shape (..., X, Y).
Chebyshev coefficients αₙ(x) for f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y).
domain : tuple[float]
Domain for y coordinates. Default is [-1, 1].
Expand All @@ -279,12 +314,12 @@ def __init__(self, cheb, domain=(-1, 1)):
self.domain = tuple(domain)

@property
def M(self):
def X(self):
"""Number of cuts."""
return self.cheb.shape[-2]

@property
def N(self):
def Y(self):
"""Chebyshev spectral resolution."""
return self.cheb.shape[-1]

Expand All @@ -297,26 +332,26 @@ def stitch(self):
dfx = f_1[..., :-1] - f_0[..., 1:] # Δf = f(xᵢ, y₁) - f(xᵢ₊₁, y₀)
self.cheb = self.cheb.at[..., 1:, 0].add(dfx.cumsum(axis=-1))

def evaluate(self, N):
"""Evaluate Chebyshev series at N Chebyshev points.
def evaluate(self, Y):
"""Evaluate Chebyshev series at Y Chebyshev points.
Evaluate each function in this set
{ fₓ | fₓ : y ↦ ∑ₙ₌₀ᴺ⁻¹ aₙ(x) Tₙ(y) }
at y points given by the N Chebyshev points.
at y points given by the Y Chebyshev points.
Parameters
----------
N : int
Y : int
Grid resolution in y direction. Preferably power of 2.
Returns
-------
fq : jnp.ndarray
Shape (..., M, N)
Chebyshev series evaluated at N Chebyshev points.
Shape (..., X, Y)
Chebyshev series evaluated at Y Chebyshev points.
"""
return idct(dct_from_cheb(self.cheb), type=2, n=N, axis=-1) * N
return idct(dct_from_cheb(self.cheb), type=2, n=Y, axis=-1) * Y

def isomorphism_to_C1(self, y):
"""Return coordinates z ∈ ℂ isomorphic to (x, y) ∈ ℂ².
Expand Down Expand Up @@ -377,7 +412,7 @@ def eval1d(self, z, cheb=None):
along the second to last axis of ``cheb`` and ``z % domain`` is
the coordinate value on the domain of that Chebyshev series.
cheb : jnp.ndarray
Shape (..., M, N).
Shape (..., X, Y).
Chebyshev coefficients to use. If not given, uses ``self.cheb``.
Returns
Expand All @@ -392,7 +427,7 @@ def eval1d(self, z, cheb=None):
x_idx, y = self.isomorphism_to_C2(z)
y = bijection_to_disc(y, self.domain[0], self.domain[1])
# Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z])
# are held in cheb with shape (..., num cheb series, N).
# are held in cheb with shape (..., num cheb series, Y).
cheb = jnp.take_along_axis(cheb, x_idx[..., jnp.newaxis], axis=-2)
f = idct_non_uniform(y, cheb, N)
assert f.shape == z.shape
Expand All @@ -412,7 +447,7 @@ def intersect2d(self, k=0.0, *, eps=_eps):
Returns
-------
y : jnp.ndarray
Shape (..., *cheb.shape[:-1], N - 1).
Shape (..., *cheb.shape[:-1], Y - 1).
Solutions yᵢ of f(x, yᵢ) = k(x), in ascending order.
is_intersect : jnp.ndarray
Shape y.shape.
Expand All @@ -425,7 +460,7 @@ def intersect2d(self, k=0.0, *, eps=_eps):
c = _subtract_first(_chebcast(self.cheb, k), k)
# roots yᵢ of f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y) - k(x)
y = chebroots_vec(c)
assert y.shape == (*c.shape[:-1], self.N - 1)
assert y.shape == (*c.shape[:-1], self.Y - 1)

# Intersects must satisfy y ∈ [-1, 1].
# Pick sentinel such that only distinct roots are considered intersects.
Expand All @@ -435,8 +470,8 @@ def intersect2d(self, k=0.0, *, eps=_eps):
y = jnp.where(is_intersect, y.real, 0.0)

# TODO: Multipoint evaluation with FFT.
# Chapter 10, https://doi.org/10.1017/CBO9781139856065.
n = jnp.arange(self.N)
# See note in integrals/basis.py near line 145.
n = jnp.arange(self.Y)
# ∂f/∂y = ∑ₙ₌₀ᴺ⁻¹ aₙ(x) n Uₙ₋₁(y)
# sign ∂f/∂y = sign ∑ₙ₌₀ᴺ⁻¹ aₙ(x) n sin(n arcos y)
df_dy_sign = jnp.sign(
Expand Down Expand Up @@ -476,10 +511,10 @@ def intersect1d(self, k=0.0, *, num_intersect=None, pad_value=0.0):
"""
errorif(
self.N < 2,
self.Y < 2,
NotImplementedError,
"This method requires a Chebyshev spectral resolution of N > 1, "
f"but got N = {self.N}.",
"This method requires a Chebyshev spectral resolution of Y > 1, "
f"but got Y = {self.Y}.",
)

# Add axis to use same k over all Chebyshev series of the piecewise spline.
Expand Down Expand Up @@ -522,7 +557,7 @@ def _check_shape(self, z1, z2, k):
# Same but back dim already exists.
z1 = atleast_nd(self.cheb.ndim, z1)
z2 = atleast_nd(self.cheb.ndim, z2)
# Cheb has shape (..., M, N) and others
# Cheb has shape (..., X, Y) and others
# have shape (K, ..., W)
errorif(not (z1.ndim == z2.ndim == k.ndim == self.cheb.ndim))
return z1, z2, k
Expand Down Expand Up @@ -631,7 +666,7 @@ def plot1d(
Parameters
----------
cheb : jnp.ndarray
Shape (M, N).
Shape (X, Y).
Piecewise Chebyshev series f.
num : int
Number of points to evaluate ``cheb`` for plot.
Expand Down Expand Up @@ -669,7 +704,7 @@ def plot1d(
legend = {}
z = jnp.linspace(
start=self.domain[0],
stop=self.domain[0] + (self.domain[1] - self.domain[0]) * self.M,
stop=self.domain[0] + (self.domain[1] - self.domain[0]) * self.X,
num=num,
)
_add2legend(legend, ax.plot(z, self.eval1d(z, cheb), label=vlabel))
Expand Down
Loading

0 comments on commit 8b5cdb8

Please sign in to comment.