Skip to content

Commit

Permalink
simplify fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici committed Aug 22, 2024
1 parent b7f343d commit 7b6b5a3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 52 deletions.
68 changes: 18 additions & 50 deletions desc/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,12 +2031,8 @@ def plot_boundaries(
):
"""Plot stellarator boundaries at multiple toroidal coordinates.
NOTE: supplied objects must have either all the same NFP, or if
there are differing NFPs, the non-axisymmetric objects must have the
same NFP and the rest of the objects must be axisymmetric. i.e
can plot a tokamak and an NFP=2 stellarator, but cannot plot
a NFP=2 and NFP=3 stellarator as there is some ambiguity on the
choice of phi
NOTE: If attempting to plot objects with differing NFP, `phi` must
be given explicitly.
Parameters
----------
Expand Down Expand Up @@ -2093,29 +2089,17 @@ def plot_boundaries(
fig, ax = plot_boundaries((eq1, eq2, eq3))
"""
NFPs = np.array([thing.NFP for thing in eqs])
Ns = np.array([thing.N for thing in eqs])
if not np.allclose(NFPs, NFPs[0]) and np.any(Ns == 0):
# if all NFPs are not equal, maybe there are some axisymmetric
# objects. We can try to change those to match the NFP of the first
# of the nonaxisymmetric objects
eqs = [thing.copy() for thing in eqs] # make copy so we dont modify originals
NFP_nonax = int(NFPs[NFPs > 1][0])
[
thing.change_resolution(NFP=NFP_nonax if thing.N == 0 else thing.NFP)
for thing in eqs
]
# if after above, the NFPs are still not all equal, means there are multiple
# nonaxisymmetric objects with differing NFPs, which it is not clear
# how to choose the phis for by default, so we will throw an error.
# if NFPs are not all equal, means there are
# objects with differing NFPs, which it is not clear
# how to choose the phis for by default, so we will throw an error
# unless phi was given.
phi = parse_argname_change(phi, kwargs, "zeta", "phi")
errorif(

Check warning on line 2097 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L2097

Added line #L2097 was not covered by tests
not np.allclose([thing.NFP for thing in eqs], eqs[0].NFP),
not np.allclose([thing.NFP for thing in eqs], eqs[0].NFP) and phi is None,
ValueError,
"supplied objects must have the same number of field periods, "
"or if there are differing field periods, the ones which differ must be"
" axisymmetric.",
"or if there are differing field periods, `phi` must be given explicitly.",
)
phi = parse_argname_change(phi, kwargs, "zeta", "phi")

figsize = kwargs.pop("figsize", None)
cmap = kwargs.pop("cmap", "rainbow")
Expand Down Expand Up @@ -2228,12 +2212,8 @@ def plot_comparison(
):
"""Plot comparison between flux surfaces of multiple equilibria.
NOTE: supplied objects must have either all the same NFP, or if
there are differing NFPs, the non-axisymmetric objects must have the
same NFP and the rest of the objects must be axisymmetric. i.e
can plot a tokamak and an NFP=2 stellarator, but cannot plot
a NFP=2 and NFP=3 stellarator as there is some ambiguity on the
choice of phi
NOTE: If attempting to plot objects with differing NFP, `phi` must
be given explicitly.
Parameters
----------
Expand Down Expand Up @@ -2303,29 +2283,17 @@ def plot_comparison(
)
"""
NFPs = np.array([thing.NFP for thing in eqs])
Ns = np.array([thing.N for thing in eqs])
if not np.allclose(NFPs, NFPs[0]) and np.any(Ns == 0):
# if all NFPs are not equal, maybe there are some axisymmetric
# objects. We can try to change those to match the NFP of the first
# of the nonaxisymmetric objects
eqs = [thing.copy() for thing in eqs] # make copy so we dont modify originals
NFP_nonax = int(NFPs[NFPs > 1][0])
[
thing.change_resolution(NFP=NFP_nonax if thing.N == 0 else thing.NFP)
for thing in eqs
]
# if after above, the NFPs are still not all equal, means there are multiple
# nonaxisymmetric objects with differing NFPs, which it is not clear
# how to choose the phis for by default, so we will throw an error.
# if NFPs are not all equal, means there are
# objects with differing NFPs, which it is not clear
# how to choose the phis for by default, so we will throw an error
# unless phi was given.
phi = parse_argname_change(phi, kwargs, "zeta", "phi")
errorif(
not np.allclose([thing.NFP for thing in eqs], eqs[0].NFP),
not np.allclose([thing.NFP for thing in eqs], eqs[0].NFP) and phi is None,
ValueError,
"supplied objects must have the same number of field periods, "
"or if there are differing field periods, the ones which differ must be"
" axisymmetric.",
"or if there are differing field periods, `phi` must be given explicitly.",
)
phi = parse_argname_change(phi, kwargs, "zeta", "phi")
color = parse_argname_change(color, kwargs, "colors", "color")
ls = parse_argname_change(ls, kwargs, "linestyles", "ls")
lw = parse_argname_change(lw, kwargs, "lws", "lw")
Expand Down
12 changes: 10 additions & 2 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,11 @@ def test_plot_boundaries(self):
eq4 = get("ESTELL")
with pytest.raises(ValueError, match="differing field periods"):
fig, ax = plot_boundaries([eq3, eq4], theta=0)
fig, ax, data = plot_boundaries((eq1, eq2, eq3), return_data=True)

This comment has been minimized.

Copy link
@missing-user

missing-user Sep 30, 2024

Contributor

The default behavior of plot_boundaries is phi=4 which triggers this code branch:

if isinstance(phi, numbers.Integral):
        phi = np.linspace(
            0, 2 * np.pi / eqs[-1].NFP, phi + 1
        )  # +1 to include pi and 2pi

This produces a phi array with 4+1 elements, of which only the first 4 are drawn. The new function call has 4 phi elements -> only 3 slices are drawn.

fig, ax, data = plot_boundaries(
(eq1, eq2, eq3),
phi=np.linspace(0, 2 * np.pi / eq3.NFP, 4, endpoint=False),
return_data=True,
)
assert "R" in data.keys()
assert "Z" in data.keys()
assert len(data["R"]) == 3
Expand Down Expand Up @@ -563,7 +567,11 @@ def test_plot_comparison_different_NFPs(self):
eq_nonax2 = get("ESTELL")
with pytest.raises(ValueError, match="differing field periods"):
fig, ax = plot_comparison([eq_nonax, eq_nonax2], theta=0)
fig, ax = plot_comparison([eq, eq_nonax], theta=0)
fig, ax = plot_comparison(
[eq, eq_nonax],
phi=np.linspace(0, 2 * np.pi / eq_nonax.NFP, 6, endpoint=False),
theta=0,
)
return fig


Expand Down

0 comments on commit 7b6b5a3

Please sign in to comment.