Skip to content

Commit

Permalink
Add degenerate row check to factorize linear constraints (#1300)
Browse files Browse the repository at this point in the history
Resolves #1297 

Though what I found is that having degenerate constraints (like fixing
[0,0,0] of Rb_mn with FixBoundaryR while also fixing that mode + others
with another FixBoundaryR constraint) is only an issue when doing
`lsq-auglag` and `lsq-exact` (when ForceBalance is not a constraint),
not for `proximal-lsq-exact`... basically for when we aren't removing
the eq DOFs from the system to solve. I did not dig into the real root
but just found that removing the degenerate rows of the constraint
matrix `A` worked to fix it, probably something to do with how the
particular solution is assigned when there are degenerate rows

@YigitElma there is likely a better fix than the one I chose which was
simply to remove degenerate rows before we get into the while loop which
finds the simple constraints and adjusts A,b accordingly, if you want
feel free to make changes here
  • Loading branch information
dpanici authored Oct 16, 2024
2 parents 12d9772 + ba4fbbc commit ebb3a69
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
31 changes: 30 additions & 1 deletion desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,36 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
A = A[:, cols]
assert A.shape[1] == xp.size

# check for degenerate rows and delete if necessary
# augment A with b so that it only deletes actual degenerate constraints
# which are duplicate rows of A that also have duplicate entries of b,
# if the entries of b aren't the same then the constraints are actually
# incompatible and so we will leave those to be caught later.
A_augmented = np.hstack([A, np.reshape(b, (A.shape[0], 1))])
row_idx_to_delete = np.array([], dtype=int)
for row_idx in range(A_augmented.shape[0]):
# find all rows equal to this row
rows_equal_to_this_row = np.where(
np.all(A_augmented[row_idx, :] == A_augmented, axis=1)
)[0]
# find the rows equal to this row that are not this row
rows_equal_to_this_row_but_not_this_row = rows_equal_to_this_row[
rows_equal_to_this_row != row_idx
]
# if there are rows equal to this row that aren't this row, AND this particular
# row has not already been detected as a duplicate of an earlier one and slated
# for deletion, add the duplicate row indices to the array of
# rows to be deleted
if (
rows_equal_to_this_row_but_not_this_row.size
and row_idx not in row_idx_to_delete
):
row_idx_to_delete = np.append(row_idx_to_delete, rows_equal_to_this_row[1:])
# delete the affected rows, and also the corresponding rows of b
A_augmented = np.delete(A_augmented, row_idx_to_delete, axis=0)
A = A_augmented[:, :-1]
b = np.atleast_1d(A_augmented[:, -1].squeeze())

# will store the global index of the unfixed rows, idx
indices_row = np.arange(A.shape[0])
indices_idx = np.arange(A.shape[1])
Expand Down Expand Up @@ -161,7 +191,6 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
Z = np.eye(A.shape[1])
xp = put(xp, unfixed_idx, A_inv @ b)
xp = put(xp, fixed_idx, ((1 / D) * xp)[fixed_idx])

# cast to jnp arrays
xp = jnp.asarray(xp)
A = jnp.asarray(A)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def test_1d_optimization():
constraints = (
ForceBalance(eq=eq),
FixBoundaryR(eq=eq),
FixBoundaryR(eq=eq, modes=[0, 0, 0]), # add a degenerate constraint to confirm
# proximal-lsq-exact not affected by GH #1297
FixBoundaryZ(eq=eq, modes=eq.surface.Z_basis.modes[0:-1, :]),
FixPressure(eq=eq),
FixIota(eq=eq),
Expand Down Expand Up @@ -644,6 +646,9 @@ def test_multiobject_optimization_al():
FixParameters(surf, {"R_lmn": np.array([0]), "Z_lmn": np.array([3])}),
FixParameters(eq, {"Psi": True, "i_l": True}),
FixBoundaryR(eq, modes=[[0, 0, 0]]),
FixBoundaryR(
eq=eq, modes=[0, 0, 0]
), # add a degenerate constraint to test fix of GH #1297 for lsq-auglag
PlasmaVesselDistance(surface=surf, eq=eq, target=1),
)

Expand Down
3 changes: 3 additions & 0 deletions tests/test_linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ def test_fixed_mode_solve():
FixIota(eq=eq),
FixPsi(eq=eq),
FixBoundaryR(eq=eq),
FixBoundaryR(
eq=eq, modes=[0, 0, 0]
), # add a degenerate constraint to test fix of GH #1297 for lsq-exact
FixBoundaryZ(eq=eq),
fixR,
fixZ,
Expand Down

0 comments on commit ebb3a69

Please sign in to comment.