From 7188140c436b4c2e9b838d53b886a523ea71eebd Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Fri, 21 Jul 2023 16:23:19 -0400 Subject: [PATCH 1/2] Move update_target method to new subclass just for fixed constraints --- desc/objectives/linear_objectives.py | 194 +++++++------------------- desc/objectives/objective_funs.py | 18 --- desc/optimize/_constraint_wrappers.py | 6 +- desc/perturbations.py | 6 +- 4 files changed, 56 insertions(+), 168 deletions(-) diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index 51cff4a2f4..1b3a0b31d4 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -19,6 +19,26 @@ from .objective_funs import _Objective +class _FixedObjective(_Objective): + + _fixed = True + _linear = True + _scalar = False + + def update_target(self, eq): + """Update target values using an Equilibrium. + + Parameters + ---------- + eq : Equilibrium + Equilibrium that will be optimized to satisfy the Objective. + + """ + self.target = np.atleast_1d(getattr(eq, self._target_arg, self.target)) + if self._use_jit: + self.jit() + + class BoundaryRSelfConsistency(_Objective): """Ensure that the boundary and interior surfaces are self consistent. @@ -331,7 +351,7 @@ def compute(self, *args, **kwargs): return f -class FixBoundaryR(_Objective): +class FixBoundaryR(_FixedObjective): """Boundary condition on the R boundary parameters. Parameters @@ -369,9 +389,7 @@ class FixBoundaryR(_Objective): `basis.modes` which may be different from the order that was passed in. """ - _scalar = False - _linear = True - _fixed = False + _target_arg = "Rb_lmn" _units = "(m)" _print_value_fmt = "R boundary error: {:10.3e} " @@ -477,13 +495,8 @@ def compute(self, *args, **kwargs): params, _ = self._parse_args(*args, **kwargs) return jnp.dot(self._A, params["Rb_lmn"]) - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Rb_lmn" - -class FixBoundaryZ(_Objective): +class FixBoundaryZ(_FixedObjective): """Boundary condition on the Z boundary parameters. Parameters @@ -521,9 +534,7 @@ class FixBoundaryZ(_Objective): `basis.modes` which may be different from the order that was passed in. """ - _scalar = False - _linear = True - _fixed = False + _target_arg = "Zb_lmn" _units = "(m)" _print_value_fmt = "Z boundary error: {:10.3e} " @@ -629,11 +640,6 @@ def compute(self, *args, **kwargs): params, _ = self._parse_args(*args, **kwargs) return jnp.dot(self._A, params["Zb_lmn"]) - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Zb_lmn" - class FixLambdaGauge(_Objective): """Fixes gauge freedom for lambda: lambda(theta=0,zeta=0)=0. @@ -795,13 +801,8 @@ def compute(self, L_lmn, **kwargs): fixed_params = L_lmn[self._idx] return fixed_params - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "L_lmn" - -class FixAxisR(_Objective): +class FixAxisR(_FixedObjective): """Fixes magnetic axis R coefficients. Parameters @@ -831,9 +832,7 @@ class FixAxisR(_Objective): """ - _scalar = False - _linear = True - _fixed = False + _target_arg = "Ra_n" _units = "(m)" _print_value_fmt = "R axis error: {:10.3e} " @@ -949,13 +948,8 @@ def compute(self, Ra_n, **kwargs): f = jnp.dot(self._A, Ra_n) return f - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Ra_n" - -class FixAxisZ(_Objective): +class FixAxisZ(_FixedObjective): """Fixes magnetic axis Z coefficients. Parameters @@ -985,9 +979,7 @@ class FixAxisZ(_Objective): """ - _scalar = False - _linear = True - _fixed = False + _target_arg = "Za_n" _units = "(m)" _print_value_fmt = "Z axis error: {:10.3e} " @@ -1103,13 +1095,8 @@ def compute(self, Za_n, **kwargs): f = jnp.dot(self._A, Za_n) return f - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Za_n" - -class FixModeR(_Objective): +class FixModeR(_FixedObjective): """Fixes Fourier-Zernike R coefficients. Parameters @@ -1141,9 +1128,7 @@ class FixModeR(_Objective): """ - _scalar = False - _linear = True - _fixed = True + _target_arg = "R_lmn" _units = "(m)" _print_value_fmt = "Fixed-R modes error: {:10.3e} " @@ -1248,13 +1233,8 @@ def compute(self, R_lmn, **kwargs): fixed_params = R_lmn[self._idx] return fixed_params - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "R_lmn" - -class FixModeZ(_Objective): +class FixModeZ(_FixedObjective): """Fixes Fourier-Zernike Z coefficients. Parameters @@ -1286,9 +1266,7 @@ class FixModeZ(_Objective): """ - _scalar = False - _linear = True - _fixed = True + _target_arg = "Z_lmn" _units = "(m)" _print_value_fmt = "Fixed-Z modes error: {:10.3e} " @@ -1393,13 +1371,8 @@ def compute(self, Z_lmn, **kwargs): fixed_params = Z_lmn[self._idx] return fixed_params - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Z_lmn" - -class FixSumModesR(_Objective): +class FixSumModesR(_FixedObjective): """Fixes a linear sum of Fourier-Zernike R coefficients. Parameters @@ -1437,9 +1410,8 @@ class FixSumModesR(_Objective): """ - _scalar = False - _linear = True - _fixed = False + _target_arg = "R_lmn" + _fixed = False # not "diagonal", since its fixing a sum _units = "(m)" _print_value_fmt = "Fixed-R sum modes error: {:10.3e} " @@ -1558,13 +1530,8 @@ def compute(self, R_lmn, **kwargs): f = jnp.dot(self._A, R_lmn) return f - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "R_lmn" - -class FixSumModesZ(_Objective): +class FixSumModesZ(_FixedObjective): """Fixes a linear sum of Fourier-Zernike Z coefficients. Parameters @@ -1602,9 +1569,8 @@ class FixSumModesZ(_Objective): """ - _scalar = False - _linear = True - _fixed = False + _target_arg = "Z_lmn" + _fixed = False # not "diagonal", since its fixing a sum _units = "(m)" _print_value_fmt = "Fixed-Z sum modes error: {:10.3e} " @@ -1724,13 +1690,8 @@ def compute(self, Z_lmn, **kwargs): f = jnp.dot(self._A, Z_lmn) return f - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Z_lmn" - -class _FixProfile(_Objective, ABC): +class _FixProfile(_FixedObjective, ABC): """Fixes profile coefficients (or values, for SplineProfile). Parameters @@ -1767,9 +1728,6 @@ class _FixProfile(_Objective, ABC): """ - _scalar = False - _linear = True - _fixed = True _print_value_fmt = "Fix-profile error: {:10.3e} " def __init__( @@ -1868,9 +1826,7 @@ class FixPressure(_FixProfile): """ - _scalar = False - _linear = True - _fixed = True + _target_arg = "p_l" _units = "(Pa)" _print_value_fmt = "Fixed-pressure profile error: {:10.3e} " @@ -1940,11 +1896,6 @@ def compute(self, p_l, **kwargs): """ return p_l[self._idx] - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "p_l" - class FixIota(_FixProfile): """Fixes rotational transform coefficients. @@ -1983,9 +1934,7 @@ class FixIota(_FixProfile): """ - _scalar = False - _linear = True - _fixed = True + _target_arg = "i_l" _units = "(dimensionless)" _print_value_fmt = "Fixed-iota profile error: {:10.3e} " @@ -2052,11 +2001,6 @@ def compute(self, i_l, **kwargs): """ return i_l[self._idx] - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "i_l" - class FixCurrent(_FixProfile): """Fixes toroidal current profile coefficients. @@ -2093,9 +2037,7 @@ class FixCurrent(_FixProfile): """ - _scalar = False - _linear = True - _fixed = True + _target_arg = "c_l" _units = "(A)" _print_value_fmt = "Fixed-current profile error: {:10.3e} " @@ -2165,11 +2107,6 @@ def compute(self, c_l, **kwargs): """ return c_l[self._idx] - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "c_l" - class FixElectronTemperature(_FixProfile): """Fixes electron temperature profile coefficients. @@ -2206,9 +2143,7 @@ class FixElectronTemperature(_FixProfile): """ - _scalar = False - _linear = True - _fixed = True + _target_arg = "Te_l" _units = "(eV)" _print_value_fmt = "Fixed-electron-temperature profile error: {:10.3e} " @@ -2278,11 +2213,6 @@ def compute(self, Te_l, **kwargs): """ return Te_l[self._idx] - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Te_l" - class FixElectronDensity(_FixProfile): """Fixes electron density profile coefficients. @@ -2319,9 +2249,7 @@ class FixElectronDensity(_FixProfile): """ - _scalar = False - _linear = True - _fixed = True + _target_arg = "ne_l" _units = "(m^-3)" _print_value_fmt = "Fixed-electron-density profile error: {:10.3e} " @@ -2391,11 +2319,6 @@ def compute(self, ne_l, **kwargs): """ return ne_l[self._idx] - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "ne_l" - class FixIonTemperature(_FixProfile): """Fixes ion temperature profile coefficients. @@ -2432,9 +2355,7 @@ class FixIonTemperature(_FixProfile): """ - _scalar = False - _linear = True - _fixed = True + _target_arg = "Ti_l" _units = "(eV)" _print_value_fmt = "Fixed-ion-temperature profile error: {:10.3e} " @@ -2504,11 +2425,6 @@ def compute(self, Ti_l, **kwargs): """ return Ti_l[self._idx] - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Ti_l" - class FixAtomicNumber(_FixProfile): """Fixes effective atomic number profile coefficients. @@ -2547,9 +2463,7 @@ class FixAtomicNumber(_FixProfile): """ - _scalar = False - _linear = True - _fixed = True + _target_arg = "Zeff_l" _units = "(dimensionless)" _print_value_fmt = "Fixed-atomic-number profile error: {:10.3e} " @@ -2616,13 +2530,8 @@ def compute(self, Zeff_l, **kwargs): """ return Zeff_l[self._idx] - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Zeff_l" - -class FixPsi(_Objective): +class FixPsi(_FixedObjective): """Fixes total toroidal magnetic flux within the last closed flux surface. Parameters @@ -2646,9 +2555,7 @@ class FixPsi(_Objective): """ - _scalar = True - _linear = True - _fixed = True + _target_arg = "Psi" _units = "(Wb)" _print_value_fmt = "Fixed-Psi error: {:10.3e} " @@ -2713,8 +2620,3 @@ def compute(self, Psi, **kwargs): """ return Psi - - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "Psi" diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 79e0a05648..bc030c5571 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -928,19 +928,6 @@ def _check_dimensions(self): if not is_broadcastable((self.dim_f,), self.weight.shape): raise ValueError("len(weight) != dim_f") - def update_target(self, eq): - """Update target values using an Equilibrium. - - Parameters - ---------- - eq : Equilibrium - Equilibrium that will be optimized to satisfy the Objective. - - """ - self.target = np.atleast_1d(getattr(eq, self.target_arg, self.target)) - if self._use_jit: - self.jit() - @abstractmethod def build(self, eq=None, use_jit=True, verbose=1): """Build constant arrays.""" @@ -1102,11 +1089,6 @@ def args(self): """list: Names (str) of arguments to the compute functions.""" return self._args - @property - def target_arg(self): - """str: Name of argument corresponding to the target.""" - return "" - @property def dimensions(self): """dict: Dimensions of the argument given by the dict keys.""" diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 4c23272522..269d6a993d 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -670,14 +670,16 @@ def _update_equilibrium(self, x, store=False): if val.size: setattr(self._eq, arg, val) for con in self._linear_constraints: - con.update_target(self._eq) + if hasattr(con, "update_target"): + con.update_target(self._eq) else: for arg in arg_order: val = self.history[arg][-1].copy() if val.size: setattr(self._eq, arg, val) for con in self._linear_constraints: - con.update_target(self._eq) + if hasattr(con, "update_target"): + con.update_target(self._eq) return xopt, xeq def compute_scaled(self, x, constants=None): diff --git a/desc/perturbations.py b/desc/perturbations.py index f65953a67f..148512ab14 100644 --- a/desc/perturbations.py +++ b/desc/perturbations.py @@ -399,7 +399,8 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces for key, value in deltas.items(): setattr(eq_new, key, getattr(eq_new, key) + value) for constraint in constraints: - constraint.update_target(eq_new) + if hasattr(constraint, "update_target"): + constraint.update_target(eq_new) xp, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints( constraints, objective.args ) @@ -815,7 +816,8 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces setattr(eq_new, key, getattr(eq_new, key) + dc[idx0 : idx0 + len(value)]) idx0 += len(value) for constraint in constraints: - constraint.update_target(eq_new) + if hasattr(constraint, "update_target"): + constraint.update_target(eq_new) xp, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints( constraints, objective_f.args ) From 45ca19b7ce32df18a791dd7feb966a4503742180 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 24 Jul 2023 17:18:25 -0400 Subject: [PATCH 2/2] Only warn about unequal NFP when it matters --- desc/transform.py | 1 + tests/test_transform.py | 46 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/desc/transform.py b/desc/transform.py index 2cbf064f66..63453d6e35 100644 --- a/desc/transform.py +++ b/desc/transform.py @@ -59,6 +59,7 @@ def __init__( if ( not np.all(self.grid.nodes[:, 2] == 0) + and self.basis.N != 0 and not (self.grid.NFP == self.basis.NFP) and grid.node_pattern != "custom" ): diff --git a/tests/test_transform.py b/tests/test_transform.py index 1f5572584e..240f63a5db 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -625,3 +625,49 @@ def bar(x): x = np.random.random(basis.num_modes) np.testing.assert_allclose(foo(x, transform), transform.transform(x)) np.testing.assert_allclose(bar(x), transform.transform(x)) + + +@pytest.mark.unit +def test_NFP_warning(): + """Make sure we only warn about basis/grid NFP in cases where it matters.""" + rho = np.linspace(0, 1, 20) + g01 = LinearGrid(rho=rho, L=5, N=0, NFP=1) + g02 = LinearGrid(rho=rho, L=5, N=0, NFP=2) + g21 = LinearGrid(rho=rho, L=5, N=5, NFP=1) + g22 = LinearGrid(rho=rho, L=5, N=5, NFP=2) + b01 = FourierZernikeBasis(L=2, M=2, N=0, NFP=1) + b02 = FourierZernikeBasis(L=2, M=2, N=0, NFP=2) + b21 = FourierZernikeBasis(L=2, M=2, N=2, NFP=1) + b22 = FourierZernikeBasis(L=2, M=2, N=2, NFP=2) + + # No toroidal nodes, shouldn't warn + _ = Transform(g01, b01) + _ = Transform(g01, b02) + _ = Transform(g01, b21) + _ = Transform(g01, b22) + + # No toroidal nodes, shouldn't warn + _ = Transform(g02, b01) + _ = Transform(g02, b02) + _ = Transform(g02, b21) + _ = Transform(g02, b22) + + # toroidal nodes but no toroidal modes, no warning + _ = Transform(g21, b01) + # toroidal nodes but no toroidal modes, no warning + _ = Transform(g21, b02) + # toroidal nodes and modes, but equal nfp, no warning + _ = Transform(g21, b21) + # toroidal modes and nodes and unequal NFP -> warning + with pytest.warns(UserWarning): + _ = Transform(g21, b22) + + # no toroidal modes, no warning + _ = Transform(g22, b01) + # no toroidal modes, no warning + _ = Transform(g22, b02) + # toroidal modes and nodes and unequal NFP -> warning + with pytest.warns(UserWarning): + _ = Transform(g22, b21) + # toroidal nodes and modes, but equal nfp, no warning + _ = Transform(g22, b22)