Skip to content

Commit

Permalink
Simplify boozer mode plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Aug 26, 2024
1 parent 2085630 commit 36bb102
Showing 1 changed file with 95 additions and 122 deletions.
217 changes: 95 additions & 122 deletions desc/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2576,7 +2576,7 @@ def plot_boozer_modes( # noqa: C901
elif np.isscalar(rho) and rho > 1:
rho = np.linspace(1, 0, num=rho, endpoint=False)

B_mn = np.array([[]])
rho = np.sort(rho)

Check warning on line 2579 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L2579

Added line #L2579 was not covered by tests
M_booz = kwargs.pop("M_booz", 2 * eq.M)
N_booz = kwargs.pop("N_booz", 2 * eq.N)
linestyle = kwargs.pop("ls", "-")
Expand All @@ -2594,16 +2594,15 @@ def plot_boozer_modes( # noqa: C901
else:
matrix, modes = ptolemy_linear_transform(basis.modes)

for i, r in enumerate(rho):
grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=np.array(r))
transforms = get_transforms(
"|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
data = eq.compute("|B|_mn", grid=grid, transforms=transforms)
b_mn = np.atleast_2d(matrix @ data["|B|_mn"])
B_mn = np.vstack((B_mn, b_mn)) if B_mn.size else b_mn
grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=rho)
transforms = get_transforms(

Check warning on line 2598 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L2597-L2598

Added lines #L2597 - L2598 were not covered by tests
"|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
data = eq.compute("|B|_mn", grid=grid, transforms=transforms)
B_mn = data["|B|_mn"].reshape((len(rho), -1))
B_mn = np.atleast_2d(matrix @ B_mn.T).T

Check warning on line 2605 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L2601-L2605

Added lines #L2601 - L2605 were not covered by tests

zidx = np.where((modes[:, 1:] == np.array([[0, 0]])).all(axis=1))[0]
if norm:
Expand Down Expand Up @@ -2972,6 +2971,7 @@ def plot_qs_error( # noqa: 16 fxn too complex
rho = np.linspace(1, 0, num=20, endpoint=False)
elif np.isscalar(rho) and rho > 1:
rho = np.linspace(1, 0, num=rho, endpoint=False)
rho = np.sort(rho)

Check warning on line 2974 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L2974

Added line #L2974 was not covered by tests

fig, ax = _format_ax(ax, figsize=kwargs.pop("figsize", None))

Expand All @@ -2989,119 +2989,92 @@ def plot_qs_error( # noqa: 16 fxn too complex
R0 = data["R0"]
B0 = np.mean(data["|B|"] * data["sqrt(g)"]) / np.mean(data["sqrt(g)"])

f_B = np.array([])
f_C = np.array([])
f_T = np.array([])
plot_data = {}
for i, r in enumerate(rho):
grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=np.array(r))
if fB:
transforms = get_transforms(
"|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz
)
if i == 0: # only need to do this once for the first rho surface
matrix, modes, idx = ptolemy_linear_transform(
transforms["B"].basis.modes,
helicity=helicity,
NFP=transforms["B"].basis.NFP,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
data = eq.compute(
["|B|_mn", "B modes"], grid=grid, transforms=transforms
)
B_mn = matrix @ data["|B|_mn"]
f_b = np.sqrt(np.sum(B_mn[idx] ** 2)) / np.sqrt(np.sum(B_mn**2))
f_B = np.append(f_B, f_b)
if fC:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
data = eq.compute("f_C", grid=grid, helicity=helicity)
f_c = (
np.mean(np.abs(data["f_C"]) * data["sqrt(g)"])
/ np.mean(data["sqrt(g)"])
/ B0**3
)
f_C = np.append(f_C, f_c)
if fT:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
data = eq.compute("f_T", grid=grid)
f_t = (
np.mean(np.abs(data["f_T"]) * data["sqrt(g)"])
/ np.mean(data["sqrt(g)"])
* R0**2
/ B0**4
)
f_T = np.append(f_T, f_t)
plot_data = {"rho": rho}

Check warning on line 2992 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L2992

Added line #L2992 was not covered by tests

plot_data["f_B"] = f_B
plot_data["f_C"] = f_C
plot_data["f_T"] = f_T
plot_data["rho"] = rho
grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=rho)
names = []
if fB:
names += ["|B|_mn"]
transforms = get_transforms(

Check warning on line 2998 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L2994-L2998

Added lines #L2994 - L2998 were not covered by tests
"|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz
)
matrix, modes, idx = ptolemy_linear_transform(

Check warning on line 3001 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3001

Added line #L3001 was not covered by tests
transforms["B"].basis.modes,
helicity=helicity,
NFP=transforms["B"].basis.NFP,
)
if fC or fT:
names += ["sqrt(g)"]
if fC:
names += ["f_C"]
if fT:
names += ["f_T"]

Check warning on line 3011 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3006-L3011

Added lines #L3006 - L3011 were not covered by tests

if log:
if fB:
ax.semilogy(
rho,
f_B,
ls=ls[0 % len(ls)],
c=colors[0 % len(colors)],
marker=markers[0 % len(markers)],
label=labels[0 % len(labels)],
lw=lw[0 % len(lw)],
)
if fC:
ax.semilogy(
rho,
f_C,
ls=ls[1 % len(ls)],
c=colors[1 % len(colors)],
marker=markers[1 % len(markers)],
label=labels[1 % len(labels)],
lw=lw[1 % len(lw)],
)
if fT:
ax.semilogy(
rho,
f_T,
ls=ls[2 % len(ls)],
c=colors[2 % len(colors)],
marker=markers[2 % len(markers)],
label=labels[2 % len(labels)],
lw=lw[2 % len(lw)],
)
else:
if fB:
ax.plot(
rho,
f_B,
ls=ls[0 % len(ls)],
c=colors[0 % len(colors)],
marker=markers[0 % len(markers)],
label=labels[0 % len(labels)],
lw=lw[0 % len(lw)],
)
if fC:
ax.plot(
rho,
f_C,
ls=ls[1 % len(ls)],
c=colors[1 % len(colors)],
marker=markers[1 % len(markers)],
label=labels[1 % len(labels)],
lw=lw[1 % len(lw)],
)
if fT:
ax.plot(
rho,
f_T,
ls=ls[2 % len(ls)],
c=colors[2 % len(colors)],
marker=markers[2 % len(markers)],
label=labels[2 % len(labels)],
lw=lw[2 % len(lw)],
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
data = eq.compute(

Check warning on line 3015 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3013-L3015

Added lines #L3013 - L3015 were not covered by tests
names, grid=grid, M_booz=M_booz, N_booz=N_booz, helicity=helicity
)

if fB:
B_mn = data["|B|_mn"].reshape((len(rho), -1))
B_mn = (matrix @ B_mn.T).T
f_B = np.sqrt(np.sum(B_mn[:, idx] ** 2, axis=-1)) / np.sqrt(

Check warning on line 3022 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3019-L3022

Added lines #L3019 - L3022 were not covered by tests
np.sum(B_mn**2, axis=-1)
)
plot_data["f_B"] = f_B
if fC:
sqrtg = grid.meshgrid_reshape(data["sqrt(g)"], "rtz")
f_C = grid.meshgrid_reshape(data["f_C"], "rtz")
f_C = (

Check warning on line 3029 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3025-L3029

Added lines #L3025 - L3029 were not covered by tests
np.mean(np.abs(f_C) * sqrtg, axis=(1, 2))
/ np.mean(sqrtg, axis=(1, 2))
/ B0**3
)
plot_data["f_C"] = f_C
if fT:
sqrtg = grid.meshgrid_reshape(data["sqrt(g)"], "rtz")
f_T = grid.meshgrid_reshape(data["f_T"], "rtz")
f_T = (

Check warning on line 3038 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3034-L3038

Added lines #L3034 - L3038 were not covered by tests
np.mean(np.abs(f_T) * sqrtg, axis=(1, 2))
/ np.mean(sqrtg, axis=(1, 2))
* R0**2
/ B0**4
)
plot_data["f_T"] = f_T

Check warning on line 3044 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3044

Added line #L3044 was not covered by tests

plot_op = ax.semilogy if log else ax.plot

Check warning on line 3046 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3046

Added line #L3046 was not covered by tests

if fB:
plot_op(

Check warning on line 3049 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3048-L3049

Added lines #L3048 - L3049 were not covered by tests
rho,
f_B,
ls=ls[0 % len(ls)],
c=colors[0 % len(colors)],
marker=markers[0 % len(markers)],
label=labels[0 % len(labels)],
lw=lw[0 % len(lw)],
)
if fC:
plot_op(

Check warning on line 3059 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3058-L3059

Added lines #L3058 - L3059 were not covered by tests
rho,
f_C,
ls=ls[1 % len(ls)],
c=colors[1 % len(colors)],
marker=markers[1 % len(markers)],
label=labels[1 % len(labels)],
lw=lw[1 % len(lw)],
)
if fT:
plot_op(

Check warning on line 3069 in desc/plotting.py

View check run for this annotation

Codecov / codecov/patch

desc/plotting.py#L3068-L3069

Added lines #L3068 - L3069 were not covered by tests
rho,
f_T,
ls=ls[2 % len(ls)],
c=colors[2 % len(colors)],
marker=markers[2 % len(markers)],
label=labels[2 % len(labels)],
lw=lw[2 % len(lw)],
)

ax.set_xlabel(_AXIS_LABELS_RTZ[0], fontsize=xlabel_fontsize)
if ylabel:
Expand Down

0 comments on commit 36bb102

Please sign in to comment.