Skip to content

Commit

Permalink
Debugging fourier bounce stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Aug 26, 2024
1 parent 540d062 commit 04f87a3
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 217 deletions.
41 changes: 32 additions & 9 deletions desc/integrals/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ def _subtract(c, k):
def _in_epigraph_and(is_intersect, df_dy_sign):
"""Set and epigraph of function f with the given set of points.
Return only intersects where there is a connected path between
adjacent intersects in the epigraph of a continuous map ``f``.
Return only intersects where the straight line path between adjacent
intersects resides in the epigraph of a continuous map ``f``.
Warnings
--------
Does not support keyword arguments.
Parameters
----------
Expand Down Expand Up @@ -196,7 +200,7 @@ def nodes(M, N, L=None, domain=(-1, 1), lobatto=False):
if L is not None:
if isposint(L):
L = jnp.flipud(jnp.linspace(1, 0, L, endpoint=False))
coords = (L, x, y)
coords = (jnp.atleast_1d(L), x, y)

Check warning on line 203 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L198-L203

Added lines #L198 - L203 were not covered by tests
else:
coords = (x, y)
coords = list(map(jnp.ravel, jnp.meshgrid(*coords, indexing="ij")))
Expand Down Expand Up @@ -388,8 +392,9 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0):
z1, z2 : (jnp.ndarray, jnp.ndarray)
Shape broadcasts with (..., *self.cheb.shape[:-2], num_intersect).
``z1``, ``z2`` holds intersects satisfying ∂f/∂y <= 0, ∂f/∂y >= 0,
respectively. The points are ordered such that the path between
``z1`` and ``z2`` lies in the epigraph of f.
respectively. The points are grouped and ordered such that the
straight line path between the intersects in ``z1`` and ``z2``
resides in the epigraph of f.
"""
errorif(

Check warning on line 400 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L400

Added line #L400 was not covered by tests
Expand Down Expand Up @@ -602,7 +607,7 @@ def plot1d(
klabel=r"$k$",
title=r"Intersects $z$ in epigraph($f$) s.t. $f(z) = k$",
hlabel=r"$z$",
vlabel=r"$f(z)$",
vlabel=r"$f$",
show=True,
):
"""Plot the piecewise Chebyshev series.
Expand Down Expand Up @@ -660,7 +665,7 @@ def plot1d(
)
ax.set_xlabel(hlabel)
ax.set_ylabel(vlabel)
ax.legend(legend.values(), legend.keys())
ax.legend(legend.values(), legend.keys(), loc="lower right")
ax.set_title(title)
plt.tight_layout()
if show:
Expand Down Expand Up @@ -698,5 +703,23 @@ def _plot_intersect(ax, legend, z1, z2, k, k_transparency, klabel):
mask = (z1 - z2) != 0.0
_z1 = z1[mask]
_z2 = z2[mask]
ax.scatter(_z1, jnp.full_like(_z1, k[i]), marker="v", color="tab:red")
ax.scatter(_z2, jnp.full_like(_z2, k[i]), marker="^", color="tab:green")
_add2legend(

Check warning on line 706 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L700-L706

Added lines #L700 - L706 were not covered by tests
legend,
ax.scatter(
_z1,
jnp.full_like(_z1, k[i]),
marker="v",
color="tab:red",
label=r"$z_1$",
),
)
_add2legend(

Check warning on line 716 in desc/integrals/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/basis.py#L716

Added line #L716 was not covered by tests
legend,
ax.scatter(
_z2,
jnp.full_like(_z2, k[i]),
marker="^",
color="tab:green",
label=r"$z_2$",
),
)
82 changes: 51 additions & 31 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,26 +415,46 @@ def bounce_points(self, pitch, num_well=None):
Returns
-------
bp1, bp2 : (jnp.ndarray, jnp.ndarray)
z1, z2 : (jnp.ndarray, jnp.ndarray)
Shape (P, L, num_well).
ζ coordinates of bounce points.
The pairs ``bp1`` and ``bp2`` form left and right integration boundaries,
respectively, for the bounce integrals.
ζ coordinates of bounce points. The points are grouped and ordered such
that the straight line path between the intersects in ``z1`` and ``z2``
resides in the epigraph of |B|.
"""
return self._B.intersect1d(1 / jnp.atleast_2d(pitch), num_well)

Check warning on line 425 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L425

Added line #L425 was not covered by tests

def check_bounce_points(self, bp1, bp2, pitch, plot=True, **kwargs):
"""Check that bounce points are computed correctly and plot them."""
def check_bounce_points(self, z1, z2, pitch, plot=True, **kwargs):
"""Check that bounce points are computed correctly.
Parameters
----------
z1, z2 : (jnp.ndarray, jnp.ndarray)
Shape (P, L, num_well).
ζ coordinates of bounce points. The points are grouped and ordered such
that the straight line path between the intersects in ``z1`` and ``z2``
resides in the epigraph of |B|.
pitch : jnp.ndarray
Shape (P, L).
λ values to evaluate the bounce integral at each field line. λ(ρ) is
specified by ``pitch[...,ρ]`` where in the latter the labels ρ are
interpreted as the index into the last axis that corresponds to that field
line. If two-dimensional, the first axis is the batch axis.
plot : bool
Whether to plot stuff.
kwargs : dict
Keyword arguments into ``ChebyshevBasisSet.plot1d``.
"""
kwargs.setdefault(

Check warning on line 449 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L449

Added line #L449 was not covered by tests
"title",
r"Intersects $\zeta$ in epigraph($\vert B \vert$) s.t. "
r"$\vert B \vert(\zeta) = 1/\lambda$",
)
kwargs.setdefault("klabel", r"$1/\lambda$")
kwargs.setdefault("hlabel", r"$\zeta$")
kwargs.setdefault("vlabel", r"$\vert B \vert(\zeta)$")
self._B.check_intersect1d(bp1, bp2, 1 / pitch, plot, **kwargs)
kwargs.setdefault("vlabel", r"$\vert B \vert$")
self._B.check_intersect1d(z1, z2, 1 / pitch, plot, **kwargs)

Check warning on line 457 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L454-L457

Added lines #L454 - L457 were not covered by tests

def integrate(self, pitch, integrand, f, weight=None, num_well=None):
"""Bounce integrate ∫ f(ℓ) dℓ.
Expand Down Expand Up @@ -487,21 +507,21 @@ def integrate(self, pitch, integrand, f, weight=None, num_well=None):
"""
pitch = jnp.atleast_2d(pitch)
bp1, bp2 = self.bounce_points(pitch, num_well)
result = self._integrate(bp1, bp2, pitch, integrand, f)
z1, z2 = self.bounce_points(pitch, num_well)
result = self._integrate(z1, z2, pitch, integrand, f)
errorif(weight is not None, NotImplementedError)
return result

Check warning on line 513 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L509-L513

Added lines #L509 - L513 were not covered by tests

def _integrate(self, bp1, bp2, pitch, integrand, f):
assert bp1.ndim == 3
assert bp1.shape == bp2.shape
def _integrate(self, z1, z2, pitch, integrand, f):
assert z1.ndim == 3
assert z1.shape == z2.shape
assert pitch.ndim == 2
W = bp1.shape[-1] # number of wells
W = z1.shape[-1] # number of wells
shape = (pitch.shape[0], self._L, W, self._x.size)

Check warning on line 520 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L516-L520

Added lines #L516 - L520 were not covered by tests

# quadrature points parameterized by ζ for each pitch and flux surface
Q_zeta = flatten_matrix(

Check warning on line 523 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L523

Added line #L523 was not covered by tests
bijection_from_disc(self._x, bp1[..., jnp.newaxis], bp2[..., jnp.newaxis])
bijection_from_disc(self._x, z1[..., jnp.newaxis], z2[..., jnp.newaxis])
)
# quadrature points in (θ, ζ) coordinates
Q = jnp.stack([self._T.eval1d(Q_zeta), Q_zeta], axis=-1)

Check warning on line 527 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L527

Added line #L527 was not covered by tests
Expand Down Expand Up @@ -728,11 +748,11 @@ def bounce_points(self, pitch, num_well=None):
Returns
-------
bp1, bp2 : (jnp.ndarray, jnp.ndarray)
z1, z2 : (jnp.ndarray, jnp.ndarray)
Shape (P, L * M, num_well).
ζ coordinates of bounce points.
The pairs ``bp1`` and ``bp2`` form left and right integration boundaries,
respectively, for the bounce integrals.
ζ coordinates of bounce points. The points are grouped and ordered such
that the straight line path between the intersects in ``z1`` and ``z2``
resides in the epigraph of |B|.
If there were less than ``num_wells`` wells detected along a field line,
then the last axis, which enumerates bounce points for a particular field
Expand All @@ -747,16 +767,16 @@ def bounce_points(self, pitch, num_well=None):
num_well=num_well,
)

def check_bounce_points(self, bp1, bp2, pitch, plot=True, **kwargs):
def check_bounce_points(self, z1, z2, pitch, plot=True, **kwargs):
"""Check that bounce points are computed correctly.
Parameters
----------
bp1, bp2 : (jnp.ndarray, jnp.ndarray)
z1, z2 : (jnp.ndarray, jnp.ndarray)
Shape (P, L * M, num_well).
ζ coordinates of bounce points.
The pairs ``bp1`` and ``bp2`` form left and right integration boundaries,
respectively, for the bounce integrals.
ζ coordinates of bounce points. The points are grouped and ordered such
that the straight line path between the intersects in ``z1`` and ``z2``
resides in the epigraph of |B|.
pitch : jnp.ndarray
Shape must broadcast with (P, L * M).
λ values to evaluate the bounce integral at each field line. λ(ρ,α) is
Expand All @@ -770,8 +790,8 @@ def check_bounce_points(self, bp1, bp2, pitch, plot=True, **kwargs):
"""
_check_bounce_points(

Check warning on line 792 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L792

Added line #L792 was not covered by tests
bp1=bp1,
bp2=bp2,
z1=z1,
z2=z2,
pitch=jnp.atleast_2d(pitch),
knots=self._zeta,
B=self.B,
Expand Down Expand Up @@ -848,12 +868,12 @@ def integrate(
"""
pitch = jnp.atleast_2d(pitch)
bp1, bp2 = self.bounce_points(pitch, num_well)
z1, z2 = self.bounce_points(pitch, num_well)
result = bounce_quadrature(

Check warning on line 872 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L870-L872

Added lines #L870 - L872 were not covered by tests
x=self._x,
w=self._w,
bp1=bp1,
bp2=bp2,
z1=z1,
z2=z2,
pitch=pitch,
integrand=integrand,
f=f,
Expand All @@ -866,8 +886,8 @@ def integrate(
if weight is not None:
result *= interp_to_argmin_B_soft(

Check warning on line 887 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L886-L887

Added lines #L886 - L887 were not covered by tests
g=weight,
bp1=bp1,
bp2=bp2,
z1=z1,
z2=z2,
knots=self._zeta,
B=self.B,
dB_dz=self._dB_dz,
Expand Down
Loading

0 comments on commit 04f87a3

Please sign in to comment.