Skip to content

Commit

Permalink
Add options to plot 3d without any axis visible (#1289)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Oct 3, 2024
2 parents b2b536c + 459e7d0 commit e35436d
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 8 deletions.
68 changes: 61 additions & 7 deletions desc/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,15 @@ def plot_3d(
* ``cmap``: string denoting colormap to use.
* ``levels``: array of data values where ticks on colorbar should be placed.
* ``alpha``: float in [0,1.0], the transparency of the plotted surface
* ``showscale``: Bool, whether or not to show the colorbar. True by default
* ``showscale``: Bool, whether or not to show the colorbar. True by default.
* ``showgrid``: Bool, whether or not to show the coordinate grid lines.
True by default.
* ``showticklabels``: Bool, whether or not to show the coordinate tick labels.
True by default.
* ``showaxislabels``: Bool, whether or not to show the coordinate axis labels.
True by default.
* ``zeroline``: Bool, whether or not to show the zero coordinate axis lines.
True by default.
* ``field``: MagneticField, a magnetic field with which to calculate Bn on
the surface, must be provided if Bn is entered as the variable to plot.
* ``field_grid``: MagneticField, a Grid to pass to the field as a source grid
Expand Down Expand Up @@ -956,6 +964,10 @@ def plot_3d(
data = data.reshape((grid.num_theta, grid.num_rho, grid.num_zeta), order="F")

label = r"$\mathbf{B} \cdot \hat{n} ~(\mathrm{T})$"
showgrid = kwargs.pop("showgrid", True)
zeroline = kwargs.pop("zeroline", True)
showticklabels = kwargs.pop("showticklabels", True)
showaxislabels = kwargs.pop("showaxislabels", True)
errorif(
len(kwargs) != 0,
ValueError,
Expand Down Expand Up @@ -1024,30 +1036,48 @@ def plot_3d(
if fig is None:
fig = go.Figure()
fig.add_trace(meshdata)
xaxis_title = (
LatexNodes2Text().latex_to_text(_AXIS_LABELS_XYZ[0]) if showaxislabels else ""
)
yaxis_title = (
LatexNodes2Text().latex_to_text(_AXIS_LABELS_XYZ[1]) if showaxislabels else ""
)
zaxis_title = (
LatexNodes2Text().latex_to_text(_AXIS_LABELS_XYZ[2]) if showaxislabels else ""
)

fig.update_layout(
scene=dict(
xaxis_title=LatexNodes2Text().latex_to_text(_AXIS_LABELS_XYZ[0]),
yaxis_title=LatexNodes2Text().latex_to_text(_AXIS_LABELS_XYZ[1]),
zaxis_title=LatexNodes2Text().latex_to_text(_AXIS_LABELS_XYZ[2]),
xaxis_title=xaxis_title,
yaxis_title=yaxis_title,
zaxis_title=zaxis_title,
aspectmode="data",
xaxis=dict(
backgroundcolor="white",
gridcolor="darkgrey",
showbackground=False,
zerolinecolor="darkgrey",
showgrid=showgrid,
zeroline=zeroline,
showticklabels=showticklabels,
),
yaxis=dict(
backgroundcolor="white",
gridcolor="darkgrey",
showbackground=False,
zerolinecolor="darkgrey",
showgrid=showgrid,
zeroline=zeroline,
showticklabels=showticklabels,
),
zaxis=dict(
backgroundcolor="white",
gridcolor="darkgrey",
showbackground=False,
zerolinecolor="darkgrey",
showgrid=showgrid,
zeroline=zeroline,
showticklabels=showticklabels,
),
),
width=figsize[0] * dpi,
Expand Down Expand Up @@ -2414,6 +2444,14 @@ def plot_coils(coils, grid=None, fig=None, return_data=False, **kwargs):
* ``lw``: float, linewidth of plotted coils
* ``ls``: str, linestyle of plotted coils
* ``color``: str, color of plotted coils
* ``showgrid``: Bool, whether or not to show the coordinate grid lines.
True by default.
* ``showticklabels``: Bool, whether or not to show the coordinate tick labels.
True by default.
* ``showaxislabels``: Bool, whether or not to show the coordinate axis labels.
True by default.
* ``zeroline``: Bool, whether or not to show the zero coordinate axis lines.
True by default.
Returns
-------
Expand All @@ -2428,6 +2466,10 @@ def plot_coils(coils, grid=None, fig=None, return_data=False, **kwargs):
figsize = kwargs.pop("figsize", (10, 10))
color = kwargs.pop("color", "black")
unique = kwargs.pop("unique", False)
showgrid = kwargs.pop("showgrid", True)
zeroline = kwargs.pop("zeroline", True)
showticklabels = kwargs.pop("showticklabels", True)
showaxislabels = kwargs.pop("showaxislabels", True)
errorif(
len(kwargs) != 0,
ValueError,
Expand Down Expand Up @@ -2495,28 +2537,40 @@ def flatten_coils(coilset):
)

fig.add_trace(trace)
xaxis_title = "X (m)" if showaxislabels else ""
yaxis_title = "Y (m)" if showaxislabels else ""
zaxis_title = "Z (m)" if showaxislabels else ""
fig.update_layout(
scene=dict(
xaxis_title="X (m)",
yaxis_title="Y (m)",
zaxis_title="Z (m)",
xaxis_title=xaxis_title,
yaxis_title=yaxis_title,
zaxis_title=zaxis_title,
xaxis=dict(
backgroundcolor="white",
gridcolor="darkgrey",
showbackground=False,
zerolinecolor="darkgrey",
showgrid=showgrid,
zeroline=zeroline,
showticklabels=showticklabels,
),
yaxis=dict(
backgroundcolor="white",
gridcolor="darkgrey",
showbackground=False,
zerolinecolor="darkgrey",
showgrid=showgrid,
zeroline=zeroline,
showticklabels=showticklabels,
),
zaxis=dict(
backgroundcolor="white",
gridcolor="darkgrey",
showbackground=False,
zerolinecolor="darkgrey",
showgrid=showgrid,
zeroline=zeroline,
showticklabels=showticklabels,
),
aspectmode="data",
),
Expand Down
44 changes: 43 additions & 1 deletion tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,15 @@ def test_3d_rt(self):
def test_plot_3d_surface(self):
"""Test 3d plotting of surface object."""
surf = FourierRZToroidalSurface()
fig = plot_3d(surf, "curvature_H_rho")
fig = plot_3d(
surf,
"curvature_H_rho",
showgrid=False,
showscale=False,
zeroline=False,
showticklabels=False,
showaxislabels=False,
)
return fig

@pytest.mark.unit
Expand Down Expand Up @@ -851,6 +859,40 @@ def flatten_coils(coilset):
return fig


@pytest.mark.unit
def test_plot_coils_no_grid():
"""Test 3d plotting of coils with currents without any gridlines."""
N = 48
NFP = 4
I = 1
coil = FourierXYZCoil()
coil.rotate(angle=np.pi / N)
coils = CoilSet.linspaced_angular(coil, I, [0, 0, 1], np.pi / NFP, N // NFP // 2)
with pytest.raises(ValueError, match="Expected `coils`"):
plot_coils("not a coil")
fig, data = plot_coils(
coils,
unique=True,
return_data=True,
showgrid=False,
zeroline=False,
showticklabels=False,
showaxislabels=False,
)

def flatten_coils(coilset):
if hasattr(coilset, "__len__"):
return [a for i in coilset for a in flatten_coils(i)]
else:
return [coilset]

coil_list = flatten_coils(coils)
for string in ["X", "Y", "Z"]:
assert string in data.keys()
assert len(data[string]) == len(coil_list)
return fig


@pytest.mark.unit
@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d)
def test_plot_b_mag():
Expand Down

0 comments on commit e35436d

Please sign in to comment.