Skip to content

Commit

Permalink
Make pitch optional argument for plot function
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Aug 30, 2024
1 parent 8edc317 commit 75c13fd
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 29 deletions.
50 changes: 25 additions & 25 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def points(self, pitch_inv, num_well=None):
----------
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to evaluate the bounce integral at each field line. 1/λ(ρ,α) is
1/λ values to compute the bounce points at each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
num_well : int or None
Expand Down Expand Up @@ -251,7 +251,7 @@ def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs):
epigraph of |B|.
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to evaluate the bounce integral at each field line. 1/λ(ρ,α) is
1/λ values to compute the bounce points at each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
plot : bool
Expand Down Expand Up @@ -298,7 +298,7 @@ def integrate(
----------
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to evaluate the bounce integral at each field line. 1/λ(ρ,α) is
1/λ values to compute the bounce integrals of each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
integrand : callable
Expand Down Expand Up @@ -376,18 +376,18 @@ def integrate(
assert result.shape[-1] == setdefault(num_well, np.prod(self._dB_dz.shape[-2:]))
return result

def plot(self, pitch_inv, m, l, **kwargs):
def plot(self, m, l, pitch_inv=None, **kwargs):
"""Plot the field line and bounce points of the given pitch angles.
Parameters
----------
pitch_inv : jnp.ndarray
Shape (P, ).
1/λ values to evaluate the bounce integral at the field line
specified by Clebsch coordinate α(m), ρ(l).
m, l : int, int
Indices into the nodes of the grid supplied to make this object.
``alpha, rho = grid.meshgrid_reshape(grid.nodes[:, :2], "arz")[m, l, 0]``.
``alpha,rho=grid.meshgrid_reshape(grid.nodes[:,:2],"arz")[m,l,0]``.
pitch_inv : jnp.ndarray
Shape (P, ).
Optional, 1/λ values whose corresponding bounce points on the field line
specified by Clebsch coordinate α(m), ρ(l) will be plotted.
kwargs
Keyword arguments into ``desc/integrals/bounce_utils.py::plot_ppoly``.
Expand All @@ -397,22 +397,22 @@ def plot(self, pitch_inv, m, l, **kwargs):
Matplotlib (fig, ax) tuple.
"""
pitch_inv = jnp.atleast_1d(jnp.squeeze(pitch_inv))
errorif(
pitch_inv.ndim != 1,
msg=f"Got pitch_inv.ndim={pitch_inv.ndim}, but expected 1.",
)
z1, z2 = bounce_points(
pitch_inv[:, jnp.newaxis, jnp.newaxis],
self._zeta,
self.B[m, l],
self._dB_dz[m, l],
)
if pitch_inv is not None:
pitch_inv = jnp.atleast_1d(jnp.squeeze(pitch_inv))
errorif(
pitch_inv.ndim != 1,
msg=f"Got pitch_inv.ndim={pitch_inv.ndim}, but expected 1.",
)
z1, z2 = bounce_points(
pitch_inv[:, jnp.newaxis, jnp.newaxis],
self._zeta,
self.B[m, l],
self._dB_dz[m, l],
)
kwargs["z1"] = z1
kwargs["z2"] = z2
kwargs["k"] = pitch_inv
fig, ax = plot_ppoly(
ppoly=PPoly(self.B[m, l].T, self._zeta),
z1=z1,
z2=z2,
k=pitch_inv,
**_set_default_plot_kwargs(kwargs),
PPoly(self.B[m, l].T, self._zeta), **_set_default_plot_kwargs(kwargs)
)
return fig, ax
6 changes: 3 additions & 3 deletions desc/integrals/bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _check_spline_shape(knots, g, dg_dz, pitch_inv=None):
last axis enumerates the polynomials that compose a particular spline.
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to evaluate the bounce integral at each field line. 1/λ(ρ,α) is
1/λ values to compute the bounce integrals of each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
Expand Down Expand Up @@ -127,7 +127,7 @@ def bounce_points(
----------
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to evaluate the bounce integral at each field line. 1/λ(ρ,α) is
1/λ values to compute the bounce points at each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
knots : jnp.ndarray
Expand Down Expand Up @@ -321,7 +321,7 @@ def bounce_quadrature(
epigraph of |B|.
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to evaluate the bounce integral at each field line. 1/λ(ρ,α) is
1/λ values to evaluate the bounce integrals of each field line. 1/λ(ρ,α) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that corresponds to that field line.
integrand : callable
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ def test_bounce1d_checks(self):
print("(α, ρ):", nodes[m, l, 0])

# 7. Plotting
fig, ax = bounce.plot(pitch_inv[..., l], m, l, include_legend=False, show=False)
fig, ax = bounce.plot(m, l, pitch_inv[..., l], include_legend=False, show=False)
return fig

@pytest.mark.unit
Expand Down

0 comments on commit 75c13fd

Please sign in to comment.