Skip to content

Commit

Permalink
Merge branch 'master' into dp/plot-3d-hide-things
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Oct 2, 2024
2 parents 3b8ec45 + b2b536c commit 459e7d0
Show file tree
Hide file tree
Showing 24 changed files with 517 additions and 200 deletions.
177 changes: 177 additions & 0 deletions desc/batching.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
"""Utility functions for the ``batched_vectorize`` function."""

import functools
from functools import partial
from typing import Callable, Optional

from jax._src.api import (
_check_input_dtype_jacfwd,
_check_input_dtype_jacrev,
_check_output_dtype_jacfwd,
_check_output_dtype_jacrev,
_jacfwd_unravel,
_jacrev_unravel,
_jvp,
_std_basis,
_vjp,
)
from jax._src.api_util import _ensure_index, argnums_partial, check_callable
from jax._src.tree_util import tree_map, tree_structure, tree_transpose
from jax._src.util import wraps

from desc.backend import jax, jnp

if jax.__version_info__ >= (0, 4, 16):
Expand Down Expand Up @@ -320,3 +336,164 @@ def wrapped(*args, **kwargs):
return jnp.expand_dims(result, axis=dims_to_expand)

return wrapped


# The following section of this code is derived from JAX
# https://github.com/jax-ml/jax/blob/ff0a98a2aef958df156ca149809cf532efbbcaf4/
# jax/_src/api.py
#
# The original copyright notice is as follows
# Copyright 2018 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");


def jacfwd_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.
Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using forward-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)
argnums = _ensure_index(argnums)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
if not has_aux:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args)
y, jac = vmap_chunked(pushfwd, chunk_size=chunk_size)(_std_basis(dyn_args))
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
else:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args, has_aux=True)
y, jac, aux = vmap_chunked(pushfwd, chunk_size=chunk_size)(
_std_basis(dyn_args)
)
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
aux = tree_map(lambda x: x[0], aux)
tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun


def jacrev_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
allow_int=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
allow_int: Optional, bool.
Whether to allow differentiating with respect to integer valued inputs. The
gradient of an integer input will have a trivial vector-space dtype (float0).
Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.
Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
if not has_aux:
y, pullback = _vjp(f_partial, *dyn_args)
else:
y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
jac = vmap_chunked(pullback, chunk_size=chunk_size)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
jac_tree = tree_transpose(
tree_structure(example_args), tree_structure(y), jac_tree
)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun
73 changes: 56 additions & 17 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def __repr__(self):
+ " (name={}, current={})".format(self.name, self.current)
)

def to_FourierXYZ(self, N=10, grid=None, s=None, name=""):
def to_FourierXYZ(self, N=10, grid=None, s=None, name="", **kwargs):
"""Convert coil to FourierXYZCoil representation.
Parameters
Expand Down Expand Up @@ -479,7 +479,7 @@ def to_FourierXYZ(self, N=10, grid=None, s=None, name=""):
self.current, coords, N=N, s=s, basis="xyz", name=name
)

def to_SplineXYZ(self, knots=None, grid=None, method="cubic", name=""):
def to_SplineXYZ(self, knots=None, grid=None, method="cubic", name="", **kwargs):
"""Convert coil to SplineXYZCoil.
Parameters
Expand Down Expand Up @@ -519,7 +519,7 @@ def to_SplineXYZ(self, knots=None, grid=None, method="cubic", name=""):
self.current, coords, knots=knots, method=method, name=name, basis="xyz"
)

def to_FourierRZ(self, N=10, grid=None, NFP=None, sym=False, name=""):
def to_FourierRZ(self, N=10, grid=None, NFP=None, sym=False, name="", **kwargs):
"""Convert Coil to FourierRZCoil representation.
Note that some types of coils may not be representable in this basis.
Expand Down Expand Up @@ -553,7 +553,7 @@ def to_FourierRZ(self, N=10, grid=None, NFP=None, sym=False, name=""):
self.current, coords, N=N, NFP=NFP, basis="xyz", sym=sym, name=name
)

def to_FourierPlanar(self, N=10, grid=None, basis="xyz", name=""):
def to_FourierPlanar(self, N=10, grid=None, basis="xyz", name="", **kwargs):
"""Convert Coil to FourierPlanarCoil representation.
Note that some types of coils may not be representable in this basis.
Expand Down Expand Up @@ -1705,7 +1705,7 @@ def from_symmetry(cls, coils, NFP=1, sym=False):
return cls(*coilset)

@classmethod
def from_makegrid_coilfile(cls, coil_file, method="cubic"):
def from_makegrid_coilfile(cls, coil_file, method="cubic", check_intersection=True):
"""Create a CoilSet of SplineXYZCoils from a MAKEGRID-formatted coil txtfile.
If the MAKEGRID contains more than one coil group (denoted by the number listed
Expand All @@ -1730,6 +1730,8 @@ def from_makegrid_coilfile(cls, coil_file, method="cubic"):
the data, and will not introduce new extrema in the interpolated points
- ``'monotonic-0'``: same as `'monotonic'` but with 0 first derivatives at
both endpoints
check_intersection : bool
whether to check the resulting coilsets for intersecting coils.
"""
coils = [] # list of SplineXYZCoils, ignoring coil groups
Expand Down Expand Up @@ -1805,7 +1807,7 @@ def from_makegrid_coilfile(cls, coil_file, method="cubic"):
)

try:
return cls(*coils)
return cls(*coils, check_intersection=check_intersection)
except ValueError as e:
errorif(
True,
Expand Down Expand Up @@ -2458,7 +2460,7 @@ def compute_magnetic_vector_potential(
return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A")

def to_FourierPlanar(
self, N=10, grid=None, basis="xyz", name="", check_intersection=False
self, N=10, grid=None, basis="xyz", name="", check_intersection=True
):
"""Convert all coils to FourierPlanarCoil representation.
Expand Down Expand Up @@ -2487,7 +2489,12 @@ def to_FourierPlanar(
minor radius r in a plane specified by a center position and normal vector.
"""
coils = [coil.to_FourierPlanar(N=N, grid=grid, basis=basis) for coil in self]
coils = [
coil.to_FourierPlanar(
N=N, grid=grid, basis=basis, check_intersection=check_intersection
)
for coil in self
]
return self.__class__(*coils, name=name, check_intersection=check_intersection)

def to_FourierRZ(
Expand Down Expand Up @@ -2520,7 +2527,12 @@ def to_FourierRZ(
New representation of the coilset parameterized by a Fourier series for R,Z.
"""
coils = [coil.to_FourierRZ(N=N, grid=grid, NFP=NFP, sym=sym) for coil in self]
coils = [
coil.to_FourierRZ(
N=N, grid=grid, NFP=NFP, sym=sym, check_intersection=check_intersection
)
for coil in self
]
return self.__class__(*coils, name=name, check_intersection=check_intersection)

def to_FourierXYZ(self, N=10, grid=None, s=None, name="", check_intersection=True):
Expand Down Expand Up @@ -2548,7 +2560,10 @@ def to_FourierXYZ(self, N=10, grid=None, s=None, name="", check_intersection=Tru
X,Y,Z.
"""
coils = [coil.to_FourierXYZ(N, grid, s) for coil in self]
coils = [
coil.to_FourierXYZ(N, grid, s, check_intersection=check_intersection)
for coil in self
]
return self.__class__(*coils, name=name, check_intersection=check_intersection)

def to_SplineXYZ(
Expand Down Expand Up @@ -2586,7 +2601,12 @@ def to_SplineXYZ(
New representation of the coilset parameterized by a spline for X,Y,Z.
"""
coils = [coil.to_SplineXYZ(knots, grid, method) for coil in self]
coils = [
coil.to_SplineXYZ(
knots, grid, method, check_intersection=check_intersection
)
for coil in self
]
return self.__class__(*coils, name=name, check_intersection=check_intersection)

def __add__(self, other):
Expand All @@ -2610,7 +2630,7 @@ def insert(self, i, new_item):

@classmethod
def from_makegrid_coilfile( # noqa: C901 - FIXME: simplify this
cls, coil_file, method="cubic", ignore_groups=False
cls, coil_file, method="cubic", ignore_groups=False, check_intersection=True
):
"""Create a MixedCoilSet of SplineXYZCoils from a MAKEGRID coil txtfile.
Expand Down Expand Up @@ -2645,6 +2665,9 @@ def from_makegrid_coilfile( # noqa: C901 - FIXME: simplify this
single coilgroup. If there is only a single group, however, this will not
return a nested coilset, but just a single coilset for that group. if True,
return the coils as just a single MixedCoilSet.
check_intersection : bool
whether to check the resulting coilsets for intersecting coils.
"""
coils = {} # dict of list of SplineXYZCoils, one list per coilgroup
Expand Down Expand Up @@ -2741,18 +2764,34 @@ def flatten_coils(coilset):
# nested coilset
groupinds = list(coils.keys())
if len(groupinds) == 1:
return cls(*coils[groupinds[0]], name=groupnames[0])
return cls(
*coils[groupinds[0]],
name=groupnames[0],
check_intersection=check_intersection,
)

# if not, possibly return a nested coilset, containing one coilset per coilgroup
coilsets = [] # list of coilsets, so we can attempt to use CoilSet for each one
for groupname, groupind in zip(groupnames, groupinds):
try:
# try making the coilgroup use a CoilSet
coilsets.append(CoilSet(*coils[groupind], name=groupname))
coilsets.append(
CoilSet(
*coils[groupind],
name=groupname,
check_intersection=check_intersection,
)
)
except ValueError: # can't load as a CoilSet if any of the coils have
# different length of knots, so load as MixedCoilSet instead
coilsets.append(cls(*coils[groupind], name=groupname))
cset = cls(*coilsets)
coilsets.append(
cls(
*coils[groupind],
name=groupname,
check_intersection=check_intersection,
)
)
cset = cls(*coilsets, check_intersection=check_intersection)
if ignore_groups:
cset = cls(*flatten_coils(cset))
cset = cls(*flatten_coils(cset), check_intersection=check_intersection)
return cset
Loading

0 comments on commit 459e7d0

Please sign in to comment.