Skip to content

Commit

Permalink
Throw error for when Curve is used for coil objectives (#1141)
Browse files Browse the repository at this point in the history
Resolves #1020 

The given example in #1020 uses a `FourierPlanarCurve` which actually
results in a list of numpy arrays being created after going through
`tree_flatten`. This PR gives a better error message, but I'm not sure
if we can explicitly say that a `Curve` was inputted.

If there is a `Curve` object in a `CoilSet` then there are already
checks in the `CoilSet` class that throw an error.
  • Loading branch information
dpanici authored Aug 28, 2024
2 parents 301a894 + f438cef commit 04e6e56
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
6 changes: 6 additions & 0 deletions desc/objectives/_coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ def _prune_coilset_tree(coilset):

# get individual coils from coilset
coils, structure = tree_flatten(coil, is_leaf=_is_single_coil)
for c in coils:
errorif(
not isinstance(c, _Coil),
TypeError,
f"Expected object of type Coil, got {type(c)}",
)
self._num_coils = len(coils)

# map grid to list of length coils
Expand Down
9 changes: 8 additions & 1 deletion tests/test_objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from desc.compute import get_transforms
from desc.equilibrium import Equilibrium
from desc.examples import get
from desc.geometry import FourierRZToroidalSurface, FourierXYZCurve
from desc.geometry import FourierPlanarCurve, FourierRZToroidalSurface, FourierXYZCurve
from desc.grid import ConcentricGrid, LinearGrid, QuadratureGrid
from desc.io import load
from desc.magnetic_fields import (
Expand Down Expand Up @@ -869,6 +869,13 @@ def test(coil, grid=None):
test(mixed_coils)
test(nested_coils, grid=grid)

def test_coil_type_error(self):
"""Tests error when objective is not passed a coil."""
curve = FourierPlanarCurve(r_n=2, basis="rpz")
obj = CoilLength(curve)
with pytest.raises(TypeError):
obj.build()

@pytest.mark.unit
def test_coil_min_distance(self):
"""Tests minimum distance between coils in a coilset."""
Expand Down

0 comments on commit 04e6e56

Please sign in to comment.