diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index ffaa507d3..161c3f057 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -92,6 +92,36 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa A = A[:, cols] assert A.shape[1] == xp.size + # check for degenerate rows and delete if necessary + # augment A with b so that it only deletes actual degenerate constraints + # which are duplicate rows of A that also have duplicate entries of b, + # if the entries of b aren't the same then the constraints are actually + # incompatible and so we will leave those to be caught later. + A_augmented = np.hstack([A, np.reshape(b, (A.shape[0], 1))]) + row_idx_to_delete = np.array([], dtype=int) + for row_idx in range(A_augmented.shape[0]): + # find all rows equal to this row + rows_equal_to_this_row = np.where( + np.all(A_augmented[row_idx, :] == A_augmented, axis=1) + )[0] + # find the rows equal to this row that are not this row + rows_equal_to_this_row_but_not_this_row = rows_equal_to_this_row[ + rows_equal_to_this_row != row_idx + ] + # if there are rows equal to this row that aren't this row, AND this particular + # row has not already been detected as a duplicate of an earlier one and slated + # for deletion, add the duplicate row indices to the array of + # rows to be deleted + if ( + rows_equal_to_this_row_but_not_this_row.size + and row_idx not in row_idx_to_delete + ): + row_idx_to_delete = np.append(row_idx_to_delete, rows_equal_to_this_row[1:]) + # delete the affected rows, and also the corresponding rows of b + A_augmented = np.delete(A_augmented, row_idx_to_delete, axis=0) + A = A_augmented[:, :-1] + b = np.atleast_1d(A_augmented[:, -1].squeeze()) + # will store the global index of the unfixed rows, idx indices_row = np.arange(A.shape[0]) indices_idx = np.arange(A.shape[1]) @@ -161,7 +191,6 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa Z = np.eye(A.shape[1]) xp = put(xp, unfixed_idx, A_inv @ b) xp = put(xp, fixed_idx, ((1 / D) * xp)[fixed_idx]) - # cast to jnp arrays xp = jnp.asarray(xp) A = jnp.asarray(A) diff --git a/tests/test_examples.py b/tests/test_examples.py index 93c700b36..95d46a994 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -178,6 +178,8 @@ def test_1d_optimization(): constraints = ( ForceBalance(eq=eq), FixBoundaryR(eq=eq), + FixBoundaryR(eq=eq, modes=[0, 0, 0]), # add a degenerate constraint to confirm + # proximal-lsq-exact not affected by GH #1297 FixBoundaryZ(eq=eq, modes=eq.surface.Z_basis.modes[0:-1, :]), FixPressure(eq=eq), FixIota(eq=eq), @@ -644,6 +646,9 @@ def test_multiobject_optimization_al(): FixParameters(surf, {"R_lmn": np.array([0]), "Z_lmn": np.array([3])}), FixParameters(eq, {"Psi": True, "i_l": True}), FixBoundaryR(eq, modes=[[0, 0, 0]]), + FixBoundaryR( + eq=eq, modes=[0, 0, 0] + ), # add a degenerate constraint to test fix of GH #1297 for lsq-auglag PlasmaVesselDistance(surface=surf, eq=eq, target=1), ) diff --git a/tests/test_linear_objectives.py b/tests/test_linear_objectives.py index 982e98d5d..d0c17f8e3 100644 --- a/tests/test_linear_objectives.py +++ b/tests/test_linear_objectives.py @@ -205,6 +205,9 @@ def test_fixed_mode_solve(): FixIota(eq=eq), FixPsi(eq=eq), FixBoundaryR(eq=eq), + FixBoundaryR( + eq=eq, modes=[0, 0, 0] + ), # add a degenerate constraint to test fix of GH #1297 for lsq-exact FixBoundaryZ(eq=eq), fixR, fixZ,