From 36bb102dcad1684e32c0220222339da7f9d885df Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 26 Aug 2024 17:10:30 -0400 Subject: [PATCH] Simplify boozer mode plotting --- desc/plotting.py | 217 +++++++++++++++++++++-------------------------- 1 file changed, 95 insertions(+), 122 deletions(-) diff --git a/desc/plotting.py b/desc/plotting.py index f43f408af..ea641dbb4 100644 --- a/desc/plotting.py +++ b/desc/plotting.py @@ -2576,7 +2576,7 @@ def plot_boozer_modes( # noqa: C901 elif np.isscalar(rho) and rho > 1: rho = np.linspace(1, 0, num=rho, endpoint=False) - B_mn = np.array([[]]) + rho = np.sort(rho) M_booz = kwargs.pop("M_booz", 2 * eq.M) N_booz = kwargs.pop("N_booz", 2 * eq.N) linestyle = kwargs.pop("ls", "-") @@ -2594,16 +2594,15 @@ def plot_boozer_modes( # noqa: C901 else: matrix, modes = ptolemy_linear_transform(basis.modes) - for i, r in enumerate(rho): - grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=np.array(r)) - transforms = get_transforms( - "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute("|B|_mn", grid=grid, transforms=transforms) - b_mn = np.atleast_2d(matrix @ data["|B|_mn"]) - B_mn = np.vstack((B_mn, b_mn)) if B_mn.size else b_mn + grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=rho) + transforms = get_transforms( + "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + data = eq.compute("|B|_mn", grid=grid, transforms=transforms) + B_mn = data["|B|_mn"].reshape((len(rho), -1)) + B_mn = np.atleast_2d(matrix @ B_mn.T).T zidx = np.where((modes[:, 1:] == np.array([[0, 0]])).all(axis=1))[0] if norm: @@ -2972,6 +2971,7 @@ def plot_qs_error( # noqa: 16 fxn too complex rho = np.linspace(1, 0, num=20, endpoint=False) elif np.isscalar(rho) and rho > 1: rho = np.linspace(1, 0, num=rho, endpoint=False) + rho = np.sort(rho) fig, ax = _format_ax(ax, figsize=kwargs.pop("figsize", None)) @@ -2989,119 +2989,92 @@ def plot_qs_error( # noqa: 16 fxn too complex R0 = data["R0"] B0 = np.mean(data["|B|"] * data["sqrt(g)"]) / np.mean(data["sqrt(g)"]) - f_B = np.array([]) - f_C = np.array([]) - f_T = np.array([]) - plot_data = {} - for i, r in enumerate(rho): - grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=np.array(r)) - if fB: - transforms = get_transforms( - "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz - ) - if i == 0: # only need to do this once for the first rho surface - matrix, modes, idx = ptolemy_linear_transform( - transforms["B"].basis.modes, - helicity=helicity, - NFP=transforms["B"].basis.NFP, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute( - ["|B|_mn", "B modes"], grid=grid, transforms=transforms - ) - B_mn = matrix @ data["|B|_mn"] - f_b = np.sqrt(np.sum(B_mn[idx] ** 2)) / np.sqrt(np.sum(B_mn**2)) - f_B = np.append(f_B, f_b) - if fC: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute("f_C", grid=grid, helicity=helicity) - f_c = ( - np.mean(np.abs(data["f_C"]) * data["sqrt(g)"]) - / np.mean(data["sqrt(g)"]) - / B0**3 - ) - f_C = np.append(f_C, f_c) - if fT: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute("f_T", grid=grid) - f_t = ( - np.mean(np.abs(data["f_T"]) * data["sqrt(g)"]) - / np.mean(data["sqrt(g)"]) - * R0**2 - / B0**4 - ) - f_T = np.append(f_T, f_t) + plot_data = {"rho": rho} - plot_data["f_B"] = f_B - plot_data["f_C"] = f_C - plot_data["f_T"] = f_T - plot_data["rho"] = rho + grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=rho) + names = [] + if fB: + names += ["|B|_mn"] + transforms = get_transforms( + "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz + ) + matrix, modes, idx = ptolemy_linear_transform( + transforms["B"].basis.modes, + helicity=helicity, + NFP=transforms["B"].basis.NFP, + ) + if fC or fT: + names += ["sqrt(g)"] + if fC: + names += ["f_C"] + if fT: + names += ["f_T"] - if log: - if fB: - ax.semilogy( - rho, - f_B, - ls=ls[0 % len(ls)], - c=colors[0 % len(colors)], - marker=markers[0 % len(markers)], - label=labels[0 % len(labels)], - lw=lw[0 % len(lw)], - ) - if fC: - ax.semilogy( - rho, - f_C, - ls=ls[1 % len(ls)], - c=colors[1 % len(colors)], - marker=markers[1 % len(markers)], - label=labels[1 % len(labels)], - lw=lw[1 % len(lw)], - ) - if fT: - ax.semilogy( - rho, - f_T, - ls=ls[2 % len(ls)], - c=colors[2 % len(colors)], - marker=markers[2 % len(markers)], - label=labels[2 % len(labels)], - lw=lw[2 % len(lw)], - ) - else: - if fB: - ax.plot( - rho, - f_B, - ls=ls[0 % len(ls)], - c=colors[0 % len(colors)], - marker=markers[0 % len(markers)], - label=labels[0 % len(labels)], - lw=lw[0 % len(lw)], - ) - if fC: - ax.plot( - rho, - f_C, - ls=ls[1 % len(ls)], - c=colors[1 % len(colors)], - marker=markers[1 % len(markers)], - label=labels[1 % len(labels)], - lw=lw[1 % len(lw)], - ) - if fT: - ax.plot( - rho, - f_T, - ls=ls[2 % len(ls)], - c=colors[2 % len(colors)], - marker=markers[2 % len(markers)], - label=labels[2 % len(labels)], - lw=lw[2 % len(lw)], - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + data = eq.compute( + names, grid=grid, M_booz=M_booz, N_booz=N_booz, helicity=helicity + ) + + if fB: + B_mn = data["|B|_mn"].reshape((len(rho), -1)) + B_mn = (matrix @ B_mn.T).T + f_B = np.sqrt(np.sum(B_mn[:, idx] ** 2, axis=-1)) / np.sqrt( + np.sum(B_mn**2, axis=-1) + ) + plot_data["f_B"] = f_B + if fC: + sqrtg = grid.meshgrid_reshape(data["sqrt(g)"], "rtz") + f_C = grid.meshgrid_reshape(data["f_C"], "rtz") + f_C = ( + np.mean(np.abs(f_C) * sqrtg, axis=(1, 2)) + / np.mean(sqrtg, axis=(1, 2)) + / B0**3 + ) + plot_data["f_C"] = f_C + if fT: + sqrtg = grid.meshgrid_reshape(data["sqrt(g)"], "rtz") + f_T = grid.meshgrid_reshape(data["f_T"], "rtz") + f_T = ( + np.mean(np.abs(f_T) * sqrtg, axis=(1, 2)) + / np.mean(sqrtg, axis=(1, 2)) + * R0**2 + / B0**4 + ) + plot_data["f_T"] = f_T + + plot_op = ax.semilogy if log else ax.plot + + if fB: + plot_op( + rho, + f_B, + ls=ls[0 % len(ls)], + c=colors[0 % len(colors)], + marker=markers[0 % len(markers)], + label=labels[0 % len(labels)], + lw=lw[0 % len(lw)], + ) + if fC: + plot_op( + rho, + f_C, + ls=ls[1 % len(ls)], + c=colors[1 % len(colors)], + marker=markers[1 % len(markers)], + label=labels[1 % len(labels)], + lw=lw[1 % len(lw)], + ) + if fT: + plot_op( + rho, + f_T, + ls=ls[2 % len(ls)], + c=colors[2 % len(colors)], + marker=markers[2 % len(markers)], + label=labels[2 % len(labels)], + lw=lw[2 % len(lw)], + ) ax.set_xlabel(_AXIS_LABELS_RTZ[0], fontsize=xlabel_fontsize) if ylabel: