Skip to content

Commit

Permalink
signed curvature (#1085)
Browse files Browse the repository at this point in the history
Resolves #1039 

Also fixes a bug in `FourierRZCurve.compute(["x_s", "x_ss", "x_sss"])`. 

Need to add tests for the following: 

- [x] new compute quantity `"center"`
- [x] `FourierPlanarCoil` with both `"xyz"` and `"rpz"` basis
- [x] update `CoilCurvature` objective tests
  • Loading branch information
ddudt authored Jul 16, 2024
2 parents ef8e9b6 + fc5d105 commit a02cf04
Show file tree
Hide file tree
Showing 14 changed files with 277 additions and 104 deletions.
7 changes: 6 additions & 1 deletion desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def current(self, new):
assert jnp.isscalar(new) or new.size == 1
self._current = float(np.squeeze(new))

@property
def num_coils(self):
"""int: Number of coils."""
return 1

def _compute_position(self, params=None, grid=None, **kwargs):
"""Compute coil positions accounting for stellarator symmetry.
Expand Down Expand Up @@ -1665,7 +1670,7 @@ def __init__(self, *coils, name="", check_intersection=True):
@property
def num_coils(self):
"""int: Number of coils."""
return sum([c.num_coils if hasattr(c, "num_coils") else 1 for c in self])
return sum([c.num_coils for c in self])

def compute(
self,
Expand Down
175 changes: 140 additions & 35 deletions desc/compute/_curve.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from interpax import interp1d

from desc.backend import jnp
from desc.backend import jnp, sign

from .data_index import register_compute_fun
from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
Expand Down Expand Up @@ -146,6 +146,34 @@ def _Z_Curve(params, transforms, profiles, data, **kwargs):
return data


@register_compute_fun(
name="center",
label="\\langle\\mathbf{x}\\rangle",
units="m",
units_long="meters",
description="Centroid of the curve",
dim=3,
params=["center", "rotmat", "shift"],
transforms={},
profiles=[],
coordinates="s",
data=["x"],
parameterization="desc.geometry.curve.FourierPlanarCurve",
basis_in="{'rpz', 'xyz'}: Basis for input params vectors, Default 'xyz'",
)
def _center_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
# convert to xyz
if kwargs.get("basis_in", "xyz").lower() == "rpz":
center = rpz2xyz(params["center"])
else:
center = params["center"]
# displacement and rotation
center = jnp.matmul(center, params["rotmat"].reshape((3, 3)).T) + params["shift"]
# convert back to rpz
data["center"] = xyz2rpz(center) * jnp.ones_like(data["x"])
return data


@register_compute_fun(
name="x",
label="\\mathbf{x}",
Expand Down Expand Up @@ -325,6 +353,35 @@ def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
return data


@register_compute_fun(
name="center",
label="\\langle\\mathbf{x}\\rangle",
units="m",
units_long="meters",
description="Centroid of the curve",
dim=3,
params=["R_n", "Z_n", "rotmat", "shift"],
transforms={"R": [[0, 0, 0]], "Z": [[0, 0, 0]]},
profiles=[],
coordinates="s",
data=["x"],
parameterization="desc.geometry.curve.FourierRZCurve",
)
def _center_FourierRZCurve(params, transforms, profiles, data, **kwargs):
idx_Rc = transforms["R"].basis.get_idx(N=1, error=False)
idx_Rs = transforms["R"].basis.get_idx(N=-1, error=False)
idx_Z = transforms["Z"].basis.get_idx(N=0, error=False)
X0 = params["R_n"][idx_Rc] / 2 if isinstance(idx_Rc, int) else 0
Y0 = params["R_n"][idx_Rs] / 2 if isinstance(idx_Rs, int) else 0
Z0 = params["Z_n"][idx_Z] if isinstance(idx_Z, int) else 0
center = jnp.array([X0, Y0, Z0])
# displacement and rotation
center = jnp.matmul(center, params["rotmat"].reshape((3, 3)).T) + params["shift"]
# convert back to rpz
data["center"] = xyz2rpz(center) * jnp.ones_like(data["x"])
return data


@register_compute_fun(
name="x",
label="\\mathbf{x}",
Expand Down Expand Up @@ -366,20 +423,19 @@ def _x_FourierRZCurve(params, transforms, profiles, data, **kwargs):
transforms={"R": [[0, 0, 0], [0, 0, 1]], "Z": [[0, 0, 1]], "grid": []},
profiles=[],
coordinates="s",
data=[],
data=["phi"],
parameterization="desc.geometry.curve.FourierRZCurve",
)
def _x_s_FourierRZCurve(params, transforms, profiles, data, **kwargs):
R0 = transforms["R"].transform(params["R_n"], dz=0)
dR = transforms["R"].transform(params["R_n"], dz=1)
dZ = transforms["Z"].transform(params["Z_n"], dz=1)
dphi = R0
coords = jnp.stack([dR, dphi, dZ], axis=1)
# convert to xyz for displacement and rotation
coords = jnp.stack([dR, R0, dZ], axis=1)
# convert to xyz for rotation using phi=s
coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2])
coords = coords @ params["rotmat"].reshape((3, 3)).T
# convert back to rpz
coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2])
# convert back to rpz using real phi to account for displacement
coords = xyz2rpz_vec(coords, phi=data["phi"])
data["x_s"] = coords
return data

Expand All @@ -395,24 +451,20 @@ def _x_s_FourierRZCurve(params, transforms, profiles, data, **kwargs):
transforms={"R": [[0, 0, 0], [0, 0, 1], [0, 0, 2]], "Z": [[0, 0, 2]], "grid": []},
profiles=[],
coordinates="s",
data=[],
data=["phi"],
parameterization="desc.geometry.curve.FourierRZCurve",
)
def _x_ss_FourierRZCurve(params, transforms, profiles, data, **kwargs):
R0 = transforms["R"].transform(params["R_n"], dz=0)
dR = transforms["R"].transform(params["R_n"], dz=1)
d1R = transforms["R"].transform(params["R_n"], dz=1)
d2R = transforms["R"].transform(params["R_n"], dz=2)
d2Z = transforms["Z"].transform(params["Z_n"], dz=2)
R = d2R - R0
Z = d2Z
# 2nd derivative wrt phi = 0
phi = 2 * dR
coords = jnp.stack([R, phi, Z], axis=1)
# convert to xyz for displacement and rotation
coords = jnp.stack([d2R - R0, 2 * d1R, d2Z], axis=1)
# convert to xyz for rotation using phi=s
coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2])
coords = coords @ params["rotmat"].reshape((3, 3)).T
# convert back to rpz
coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2])
# convert back to rpz using real phi to account for displacement
coords = xyz2rpz_vec(coords, phi=data["phi"])
data["x_ss"] = coords
return data

Expand All @@ -432,28 +484,54 @@ def _x_ss_FourierRZCurve(params, transforms, profiles, data, **kwargs):
},
profiles=[],
coordinates="s",
data=[],
data=["phi"],
parameterization="desc.geometry.curve.FourierRZCurve",
)
def _x_sss_FourierRZCurve(params, transforms, profiles, data, **kwargs):
R0 = transforms["R"].transform(params["R_n"], dz=0)
dR = transforms["R"].transform(params["R_n"], dz=1)
d1R = transforms["R"].transform(params["R_n"], dz=1)
d2R = transforms["R"].transform(params["R_n"], dz=2)
d3R = transforms["R"].transform(params["R_n"], dz=3)
d3Z = transforms["Z"].transform(params["Z_n"], dz=3)
R = d3R - 3 * dR
Z = d3Z
phi = 3 * d2R - R0
coords = jnp.stack([R, phi, Z], axis=1)
# convert to xyz for displacement and rotation
coords = jnp.stack([d3R - 3 * d1R, 3 * d2R - R0, d3Z], axis=1)
# convert to xyz for rotation using phi=s
coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2])
coords = coords @ params["rotmat"].reshape((3, 3)).T
# convert back to rpz
coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2])
# convert back to rpz using real phi to account for displacement
coords = xyz2rpz_vec(coords, phi=data["phi"])
data["x_sss"] = coords
return data


@register_compute_fun(
name="center",
label="\\langle\\mathbf{x}\\rangle",
units="m",
units_long="meters",
description="Centroid of the curve",
dim=3,
params=["X_n", "Y_n", "Z_n", "rotmat", "shift"],
transforms={"X": [[0, 0, 0]], "Y": [[0, 0, 0]], "Z": [[0, 0, 0]]},
profiles=[],
coordinates="s",
data=["x"],
parameterization="desc.geometry.curve.FourierXYZCurve",
)
def _center_FourierXYZCurve(params, transforms, profiles, data, **kwargs):
idx_X = transforms["X"].basis.get_idx(N=0, error=False)
idx_Y = transforms["Y"].basis.get_idx(N=0, error=False)
idx_Z = transforms["Z"].basis.get_idx(N=0, error=False)
X0 = params["X_n"][idx_X] if isinstance(idx_X, int) else 0
Y0 = params["Y_n"][idx_Y] if isinstance(idx_Y, int) else 0
Z0 = params["Z_n"][idx_Z] if isinstance(idx_Z, int) else 0
center = jnp.array([X0, Y0, Z0])
# displacement and rotation
center = jnp.matmul(center, params["rotmat"].reshape((3, 3)).T) + params["shift"]
# convert to rpz
data["center"] = xyz2rpz(center) * jnp.ones_like(data["x"])
return data


@register_compute_fun(
name="x",
label="\\mathbf{x}",
Expand Down Expand Up @@ -568,6 +646,31 @@ def _x_sss_FourierXYZCurve(params, transforms, profiles, data, **kwargs):
return data


@register_compute_fun(
name="center",
label="\\langle\\mathbf{x}\\rangle",
units="m",
units_long="meters",
description="Centroid of the curve",
dim=3,
params=["X", "Y", "Z", "rotmat", "shift"],
transforms={},
profiles=[],
coordinates="s",
data=["x"],
parameterization="desc.geometry.curve.SplineXYZCurve",
)
def _center_SplineXYZCurve(params, transforms, profiles, data, **kwargs):
# center is average of xyz knots
xyz = jnp.stack([params["X"], params["Y"], params["Z"]], axis=1)
center = jnp.mean(xyz, axis=0)
# displacement and rotation
center = jnp.matmul(center, params["rotmat"].reshape((3, 3)).T) + params["shift"]
# convert to rpz
data["center"] = xyz2rpz(center) * jnp.ones_like(data["x"])
return data


@register_compute_fun(
name="x",
label="\\mathbf{x}",
Expand Down Expand Up @@ -794,13 +897,12 @@ def _frenet_tangent(params, transforms, profiles, data, **kwargs):
transforms={},
profiles=[],
coordinates="s",
data=["x_ss"],
data=["x_s", "x_ss"],
parameterization="desc.geometry.core.Curve",
)
def _frenet_normal(params, transforms, profiles, data, **kwargs):
data["frenet_normal"] = (
data["x_ss"] / jnp.linalg.norm(data["x_ss"], axis=-1)[:, None]
)
normal = cross(data["x_s"], cross(data["x_ss"], data["x_s"]))
data["frenet_normal"] = normal / jnp.linalg.norm(normal, axis=-1)[:, None]
return data


Expand Down Expand Up @@ -830,20 +932,23 @@ def _frenet_binormal(params, transforms, profiles, data, **kwargs):
label="\\kappa",
units="m^{-1}",
units_long="Inverse meters",
description="Scalar curvature of the curve",
description="Scalar curvature of the curve, with the sign denoting the convexity/"
+ "concavity relative to the center of the curve (a circle has positive curvature)",
dim=1,
params=[],
transforms={},
profiles=[],
coordinates="s",
data=["x_s", "x_ss"],
data=["center", "x", "x_s", "x_ss", "frenet_normal"],
parameterization="desc.geometry.core.Curve",
)
def _curvature(params, transforms, profiles, data, **kwargs):
# magnitude of curvature
dxn = jnp.linalg.norm(data["x_s"], axis=-1)[:, jnp.newaxis]
data["curvature"] = jnp.linalg.norm(
cross(data["x_s"], data["x_ss"]) / dxn**3, axis=-1
)
curvature = jnp.linalg.norm(cross(data["x_s"], data["x_ss"]) / dxn**3, axis=-1)
# sign of curvature (positive = "convex", negative = "concave")
r = data["center"] - data["x"]
data["curvature"] = curvature * sign(dot(r, data["frenet_normal"]))
return data


Expand Down
2 changes: 1 addition & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def compute(parameterization, names, params, transforms, profiles, data=None, **
"Tensor quantities cannot be converted to Cartesian coordinates.",
)
if data_index[p][name]["dim"] == 3: # only convert vector data
if name == "x":
if name in ["x", "center"]:
data[name] = rpz2xyz(data[name])
else:
data[name] = rpz2xyz_vec(data[name], phi=data["phi"])
Expand Down
2 changes: 1 addition & 1 deletion desc/geometry/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ class FourierPlanarCurve(Curve):
_io_attrs_ = Curve._io_attrs_ + ["_r_n", "_center", "_normal", "_r_basis", "_basis"]

# Reference frame is centered at the origin with normal in the +Z direction.
# The curve is computed in this frame and then shifted/rotated to the correct frame.
# Curve is computed in reference frame, then displaced/rotated to the desired frame.
def __init__(
self,
center=[10, 0, 0],
Expand Down
Loading

0 comments on commit a02cf04

Please sign in to comment.