Commit

Merge branch 'master' into dp/poincare

YigitElma authored Oct 4, 2024
2 parents eb831bf + e35436d commit 5a087dd
Showing 23 changed files with 519 additions and 145 deletions.
177 changes: 177 additions & 0 deletions desc/batching.py
@@ -1,8 +1,24 @@
"""Utility functions for the ``batched_vectorize`` function."""

import functools
from functools import partial
from typing import Callable, Optional

from jax._src.api import (
_check_input_dtype_jacfwd,
_check_input_dtype_jacrev,
_check_output_dtype_jacfwd,
_check_output_dtype_jacrev,
_jacfwd_unravel,
_jacrev_unravel,
_jvp,
_std_basis,
_vjp,
)
from jax._src.api_util import _ensure_index, argnums_partial, check_callable
from jax._src.tree_util import tree_map, tree_structure, tree_transpose
from jax._src.util import wraps

from desc.backend import jax, jnp

if jax.__version_info__ >= (0, 4, 16):
@@ -320,3 +336,164 @@ def wrapped(*args, **kwargs):
return jnp.expand_dims(result, axis=dims_to_expand)

return wrapped
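The collapsed region above defines ``vmap_chunked``, which the new Jacobian transforms below rely on. For context, a minimal sketch of the chunking idea, assuming nothing about DESC's actual implementation (which also handles remainders, pytrees, and lax control flow more carefully):

import jax
import jax.numpy as jnp

def naive_vmap_chunked(f, chunk_size):
    """Map f over the leading axis in chunks to bound peak memory."""
    def wrapped(x):
        n = x.shape[0]
        # vmap each slice separately, then stitch the results back together
        outs = [jax.vmap(f)(x[i : i + chunk_size]) for i in range(0, n, chunk_size)]
        return jnp.concatenate(outs)
    return wrapped

# 10 sequential vmaps of 100 rows instead of one vmap of 1000
y = naive_vmap_chunked(lambda v: v @ v, 100)(jnp.ones((1000, 8)))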


# The following section of this code is derived from JAX
# https://github.com/jax-ml/jax/blob/ff0a98a2aef958df156ca149809cf532efbbcaf4/
# jax/_src/api.py
#
# The original copyright notice is as follows
# Copyright 2018 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");


def jacfwd_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.

Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using forward-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)
argnums = _ensure_index(argnums)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
if not has_aux:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args)
y, jac = vmap_chunked(pushfwd, chunk_size=chunk_size)(_std_basis(dyn_args))
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
else:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args, has_aux=True)
y, jac, aux = vmap_chunked(pushfwd, chunk_size=chunk_size)(
_std_basis(dyn_args)
)
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
aux = tree_map(lambda x: x[0], aux)
tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun
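A hedged usage sketch of the transform defined above (the function, shapes, and chunk size here are illustrative only):

import jax.numpy as jnp
from desc.batching import jacfwd_chunked

def f(x):
    return jnp.sin(x) * jnp.sum(x)

x = jnp.linspace(0.0, 1.0, 1000)
# Push the 1000 input basis vectors through JVPs in batches of 100:
# same result as jax.jacfwd(f)(x), lower peak memory.
jac = jacfwd_chunked(f, chunk_size=100)(x)
assert jac.shape == (1000, 1000)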


def jacrev_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
allow_int=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
allow_int: Optional, bool.
Whether to allow differentiating with respect to integer valued inputs. The
gradient of an integer input will have a trivial vector-space dtype (float0).
Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.

Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
if not has_aux:
y, pullback = _vjp(f_partial, *dyn_args)
else:
y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
jac = vmap_chunked(pullback, chunk_size=chunk_size)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
jac_tree = tree_transpose(
tree_structure(example_args), tree_structure(y), jac_tree
)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun
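The reverse-mode counterpart chunks over the output basis (rows) rather than the input basis (columns), so it helps when outputs outnumber inputs. A sketch under the same assumptions as above:

import jax.numpy as jnp
from desc.batching import jacrev_chunked

def g(x):
    return jnp.outer(x, x).ravel()  # R^50 -> R^2500

x = jnp.ones(50)
# Pull the 2500 output basis vectors back through VJPs in batches of 500.
jac = jacrev_chunked(g, chunk_size=500)(x)
assert jac.shape == (2500, 50)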
39 changes: 23 additions & 16 deletions desc/derivatives.py
@@ -6,10 +6,13 @@
from termcolor import colored

from desc.backend import jnp, put, use_jax
from desc.utils import ensure_tuple

if use_jax:
import jax

from desc.batching import jacfwd_chunked, jacrev_chunked


class _Derivative(ABC):
"""_Derivative is an abstract base class for derivative matrix calculations.
@@ -123,11 +126,11 @@ class AutoDiffDerivative(_Derivative):
"""

def __init__(self, fun, argnum=0, mode="fwd", **kwargs):
def __init__(self, fun, argnum=0, mode="fwd", chunk_size=None, **kwargs):

self._fun = fun
self._argnum = argnum

self._chunk_size = chunk_size
self._set_mode(mode)

def compute(self, *args, **kwargs):
@@ -205,7 +208,7 @@ def compute_jvp(cls, fun, argnum, v, *args, **kwargs):
"""
_ = kwargs.pop("rel_step", None) # unused by autodiff
argnum = (argnum,) if jnp.isscalar(argnum) else tuple(argnum)
- v = (v,) if not isinstance(v, (tuple, list)) else v
+ v = ensure_tuple(v)

def _fun(*x):
_args = list(args)
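The `v = ensure_tuple(v)` replacements throughout this file collapse the repeated `(v,) if not isinstance(v, (tuple, list)) else v` idiom into one helper. A minimal sketch of what such a helper looks like; the actual implementation lives in desc.utils and may differ:

def ensure_tuple(x):
    """Return x as a tuple: pass tuples through, convert lists, wrap scalars."""
    if isinstance(x, tuple):
        return x
    if isinstance(x, list):
        return tuple(x)
    return (x,)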
@@ -241,14 +244,14 @@ def compute_jvp2(cls, fun, argnum1, argnum2, v1, v2, *args, **kwargs):
"""
if np.isscalar(argnum1):
- v1 = (v1,) if not isinstance(v1, (tuple, list)) else v1
+ v1 = ensure_tuple(v1)
argnum1 = (argnum1,)
else:
v1 = tuple(v1)

if np.isscalar(argnum2):
argnum2 = (argnum2 + 1,)
- v2 = (v2,) if not isinstance(v2, (tuple, list)) else v2
+ v2 = ensure_tuple(v2)
else:
argnum2 = tuple([i + 1 for i in argnum2])
v2 = tuple(v2)
@@ -284,21 +287,21 @@ def compute_jvp3(cls, fun, argnum1, argnum2, argnum3, v1, v2, v3, *args, **kwargs):
"""
if np.isscalar(argnum1):
- v1 = (v1,) if not isinstance(v1, (tuple, list)) else v1
+ v1 = ensure_tuple(v1)
argnum1 = (argnum1,)
else:
v1 = tuple(v1)

if np.isscalar(argnum2):
argnum2 = (argnum2 + 1,)
- v2 = (v2,) if not isinstance(v2, (tuple, list)) else v2
+ v2 = ensure_tuple(v2)
else:
argnum2 = tuple([i + 1 for i in argnum2])
v2 = tuple(v2)

if np.isscalar(argnum3):
argnum3 = (argnum3 + 2,)
- v3 = (v3,) if not isinstance(v3, (tuple, list)) else v3
+ v3 = ensure_tuple(v3)
else:
argnum3 = tuple([i + 2 for i in argnum3])
v3 = tuple(v3)
@@ -323,9 +326,13 @@ def _set_mode(self, mode) -> None:

self._mode = mode
if self._mode == "fwd":
- self._compute = jax.jacfwd(self._fun, self._argnum)
+ self._compute = jacfwd_chunked(
+     self._fun, self._argnum, chunk_size=self._chunk_size
+ )
elif self._mode == "rev":
- self._compute = jax.jacrev(self._fun, self._argnum)
+ self._compute = jacrev_chunked(
+     self._fun, self._argnum, chunk_size=self._chunk_size
+ )
elif self._mode == "grad":
self._compute = jax.grad(self._fun, self._argnum)
elif self._mode == "hess":
@@ -512,7 +519,7 @@ def compute_jvp(cls, fun, argnum, v, *args, **kwargs):
argnum = (argnum,)
else:
nargs = len(argnum)
- v = (v,) if not isinstance(v, tuple) else v
+ v = ensure_tuple(v)

f = np.array(
[
@@ -549,14 +556,14 @@ def compute_jvp2(cls, fun, argnum1, argnum2, v1, v2, *args, **kwargs):
"""
if np.isscalar(argnum1):
- v1 = (v1,) if not isinstance(v1, tuple) else v1
+ v1 = ensure_tuple(v1)
argnum1 = (argnum1,)
else:
v1 = tuple(v1)

if np.isscalar(argnum2):
argnum2 = (argnum2 + 1,)
- v2 = (v2,) if not isinstance(v2, tuple) else v2
+ v2 = ensure_tuple(v2)
else:
argnum2 = tuple([i + 1 for i in argnum2])
v2 = tuple(v2)
@@ -592,21 +599,21 @@ def compute_jvp3(cls, fun, argnum1, argnum2, argnum3, v1, v2, v3, *args, **kwargs):
"""
if np.isscalar(argnum1):
- v1 = (v1,) if not isinstance(v1, tuple) else v1
+ v1 = ensure_tuple(v1)
argnum1 = (argnum1,)
else:
v1 = tuple(v1)

if np.isscalar(argnum2):
argnum2 = (argnum2 + 1,)
- v2 = (v2,) if not isinstance(v2, tuple) else v2
+ v2 = ensure_tuple(v2)
else:
argnum2 = tuple([i + 1 for i in argnum2])
v2 = tuple(v2)

if np.isscalar(argnum3):
argnum3 = (argnum3 + 2,)
- v3 = (v3,) if not isinstance(v3, tuple) else v3
+ v3 = ensure_tuple(v3)
else:
argnum3 = tuple([i + 2 for i in argnum3])
v3 = tuple(v3)
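Net effect of the desc/derivatives.py changes: chunk_size now threads from the AutoDiffDerivative constructor into the chunked Jacobian transforms. A hedged usage sketch based only on the signature shown in this diff:

import jax.numpy as jnp
from desc.derivatives import AutoDiffDerivative

def f(x):
    return jnp.tanh(x) ** 2

# Forward-mode Jacobian, 200 basis vectors per vmap batch;
# chunk_size=None preserves the previous single-shot behavior.
df = AutoDiffDerivative(f, argnum=0, mode="fwd", chunk_size=200)
J = df.compute(jnp.ones(1000))  # shape (1000, 1000)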
14 changes: 8 additions & 6 deletions desc/objectives/_free_boundary.py
@@ -12,7 +12,7 @@
from desc.integrals import DFTInterpolator, FFTInterpolator, virtual_casing_biot_savart
from desc.nestor import Nestor
from desc.objectives.objective_funs import _Objective
- from desc.utils import PRINT_WIDTH, Timer, errorif, warnif
+ from desc.utils import PRINT_WIDTH, Timer, errorif, parse_argname_change, warnif

from .normalization import compute_scaling_factors

@@ -63,7 +63,7 @@ class VacuumBoundaryError(_Objective):
"auto" selects forward or reverse mode based on the size of the input and output
of the objective. Has no effect on self.grad or self.hess which always use
reverse mode and forward over reverse mode respectively.
- grid : Grid, optional
+ eval_grid : Grid, optional
Collocation grid containing the nodes to evaluate error at. Should be at rho=1.
Defaults to ``LinearGrid(M=eq.M_grid, N=eq.N_grid)``
field_grid : Grid, optional
@@ -104,15 +104,17 @@ def __init__(
normalize_target=True,
loss_function=None,
deriv_mode="auto",
- grid=None,
+ eval_grid=None,
field_grid=None,
field_fixed=False,
name="Vacuum boundary error",
jac_chunk_size=None,
**kwargs,
):
eval_grid = parse_argname_change(eval_grid, kwargs, "grid", "eval_grid")
if target is None and bounds is None:
target = 0
- self._grid = grid
+ self._eval_grid = eval_grid
self._eq = eq
self._field = field
self._field_grid = field_grid
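parse_argname_change is DESC's shim for renamed keyword arguments: it lets the deprecated grid= keyword keep working while warning the caller. A minimal sketch of the pattern; the real helper in desc.utils may differ in wording and error handling:

import warnings

def parse_argname_change(arg, kwargs, oldname, newname):
    """Return the value for newname, accepting the deprecated oldname from kwargs."""
    if oldname in kwargs:
        warnings.warn(
            f"argument {oldname!r} has been renamed to {newname!r}; "
            "the old name will stop working in a future release.",
            DeprecationWarning,
        )
        arg = kwargs.pop(oldname)
    return arg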
@@ -146,12 +148,12 @@
"""
eq = self.things[0]
- if self._grid is None:
+ if self._eval_grid is None:
grid = LinearGrid(
rho=np.array([1.0]), M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=False
)
else:
- grid = self._grid
+ grid = self._eval_grid

pres = np.max(np.abs(eq.compute("p")["p"]))
curr = np.max(np.abs(eq.compute("current")["current"]))
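A hedged end-to-end sketch of the renamed argument, assuming DESC's usual public entry points (desc.examples.get, desc.magnetic_fields.ToroidalMagneticField); the equilibrium and field here are illustrative stand-ins:

import numpy as np
from desc.examples import get
from desc.grid import LinearGrid
from desc.magnetic_fields import ToroidalMagneticField
from desc.objectives import VacuumBoundaryError

eq = get("DSHAPE")  # bundled example equilibrium
field = ToroidalMagneticField(B0=1.0, R0=3.5)
# Same grid build() constructs when eval_grid is None:
grid = LinearGrid(rho=np.array([1.0]), M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=False)
obj = VacuumBoundaryError(eq=eq, field=field, eval_grid=grid)
# Passing the old keyword still works but emits a deprecation warning:
# obj = VacuumBoundaryError(eq=eq, field=field, grid=grid)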
(Diff for the remaining 20 changed files not shown.)
