
Commit

Merge branch 'master' into fix-plotting-subplots
f0uriest authored Sep 30, 2024
2 parents 93b8b61 + c2bb798 commit 930939b
Showing 7 changed files with 363 additions and 117 deletions.
177 changes: 177 additions & 0 deletions desc/batching.py
@@ -1,8 +1,24 @@
"""Utility functions for the ``batched_vectorize`` function."""

import functools
from functools import partial
from typing import Callable, Optional

from jax._src.api import (
_check_input_dtype_jacfwd,
_check_input_dtype_jacrev,
_check_output_dtype_jacfwd,
_check_output_dtype_jacrev,
_jacfwd_unravel,
_jacrev_unravel,
_jvp,
_std_basis,
_vjp,
)
from jax._src.api_util import _ensure_index, argnums_partial, check_callable
from jax._src.tree_util import tree_map, tree_structure, tree_transpose
from jax._src.util import wraps

from desc.backend import jax, jnp

if jax.__version_info__ >= (0, 4, 16):
@@ -320,3 +336,164 @@ def wrapped(*args, **kwargs):
return jnp.expand_dims(result, axis=dims_to_expand)

return wrapped
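
The collapsed hunk above ends the wrapper returned by ``vmap_chunked``. For orientation, a minimal usage sketch under stated assumptions: ``vmap_chunked`` is defined at module level in this file (as the call sites further down suggest) and maps over the leading axis like ``jax.vmap``, evaluating at most ``chunk_size`` elements at a time. The example itself is not part of the commit.

# Hedged sketch: chunked vectorization of a simple function over a
# large batch to bound peak memory.
from desc.backend import jnp
from desc.batching import vmap_chunked

def f(x):
    return jnp.sin(x) ** 2

xs = jnp.linspace(0.0, 1.0, 10_000)
ys = vmap_chunked(f, chunk_size=100)(xs)  # same values as jax.vmap(f)(xs)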


# The following section of this code is derived from JAX
# https://github.com/jax-ml/jax/blob/ff0a98a2aef958df156ca149809cf532efbbcaf4/
# jax/_src/api.py
#
# The original copyright notice is as follows
# Copyright 2018 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");


def jacfwd_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.

    Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using forward-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)
argnums = _ensure_index(argnums)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
if not has_aux:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args)
y, jac = vmap_chunked(pushfwd, chunk_size=chunk_size)(_std_basis(dyn_args))
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
else:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args, has_aux=True)
y, jac, aux = vmap_chunked(pushfwd, chunk_size=chunk_size)(
_std_basis(dyn_args)
)
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
aux = tree_map(lambda x: x[0], aux)
tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun
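
For illustration, a minimal usage sketch of the new function (the example is not part of the commit; only the signature documented above and the import path shown in desc/derivatives.py below are assumed):

from desc.backend import jnp
from desc.batching import jacfwd_chunked

def residual(x):
    return jnp.array([x[0] ** 2, x[0] * x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])
# Forward-mode Jacobian, pushing forward one standard basis vector at a time.
J = jacfwd_chunked(residual, chunk_size=1)(x)
assert J.shape == (3, 2)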


def jacrev_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
allow_int=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
allow_int: Optional, bool.
Whether to allow differentiating with respect to integer valued inputs. The
gradient of an integer input will have a trivial vector-space dtype (float0).
Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.

    Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
if not has_aux:
y, pullback = _vjp(f_partial, *dyn_args)
else:
y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
jac = vmap_chunked(pullback, chunk_size=chunk_size)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
jac_tree = tree_transpose(
tree_structure(example_args), tree_structure(y), jac_tree
)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun
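
A matching sketch for the reverse-mode variant, again not part of the commit. Here chunking batches the pullback over output basis vectors, so reverse mode pays off when the function has many inputs and few outputs, while ``jacfwd_chunked`` is preferable in the opposite case:

from desc.backend import jnp
from desc.batching import jacrev_chunked

def residual(x):
    return jnp.array([x[0] ** 2, x[0] * x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])
# Reverse-mode Jacobian, pulling back one output basis vector at a time.
J = jacrev_chunked(residual, chunk_size=1)(x)
assert J.shape == (3, 2)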
39 changes: 23 additions & 16 deletions desc/derivatives.py
@@ -6,10 +6,13 @@
from termcolor import colored

from desc.backend import jnp, put, use_jax
+from desc.utils import ensure_tuple

if use_jax:
import jax

+    from desc.batching import jacfwd_chunked, jacrev_chunked


class _Derivative(ABC):
"""_Derivative is an abstract base class for derivative matrix calculations.
@@ -123,11 +126,11 @@ class AutoDiffDerivative(_Derivative):
"""

-    def __init__(self, fun, argnum=0, mode="fwd", **kwargs):
+    def __init__(self, fun, argnum=0, mode="fwd", chunk_size=None, **kwargs):

self._fun = fun
self._argnum = argnum

+        self._chunk_size = chunk_size
self._set_mode(mode)

def compute(self, *args, **kwargs):
@@ -205,7 +208,7 @@ def compute_jvp(cls, fun, argnum, v, *args, **kwargs):
"""
_ = kwargs.pop("rel_step", None) # unused by autodiff
argnum = (argnum,) if jnp.isscalar(argnum) else tuple(argnum)
-        v = (v,) if not isinstance(v, (tuple, list)) else v
+        v = ensure_tuple(v)

def _fun(*x):
_args = list(args)
@@ -241,14 +244,14 @@ def compute_jvp2(cls, fun, argnum1, argnum2, v1, v2, *args, **kwargs):
"""
if np.isscalar(argnum1):
-            v1 = (v1,) if not isinstance(v1, (tuple, list)) else v1
+            v1 = ensure_tuple(v1)
argnum1 = (argnum1,)
else:
v1 = tuple(v1)

if np.isscalar(argnum2):
argnum2 = (argnum2 + 1,)
-            v2 = (v2,) if not isinstance(v2, (tuple, list)) else v2
+            v2 = ensure_tuple(v2)
else:
argnum2 = tuple([i + 1 for i in argnum2])
v2 = tuple(v2)
@@ -284,21 +287,21 @@ def compute_jvp3(cls, fun, argnum1, argnum2, argnum3, v1, v2, v3, *args, **kwargs):
"""
if np.isscalar(argnum1):
-            v1 = (v1,) if not isinstance(v1, (tuple, list)) else v1
+            v1 = ensure_tuple(v1)
argnum1 = (argnum1,)
else:
v1 = tuple(v1)

if np.isscalar(argnum2):
argnum2 = (argnum2 + 1,)
-            v2 = (v2,) if not isinstance(v2, (tuple, list)) else v2
+            v2 = ensure_tuple(v2)
else:
argnum2 = tuple([i + 1 for i in argnum2])
v2 = tuple(v2)

if np.isscalar(argnum3):
argnum3 = (argnum3 + 2,)
-            v3 = (v3,) if not isinstance(v3, (tuple, list)) else v3
+            v3 = ensure_tuple(v3)
else:
argnum3 = tuple([i + 2 for i in argnum3])
v3 = tuple(v3)
@@ -323,9 +326,13 @@ def _set_mode(self, mode) -> None:

self._mode = mode
if self._mode == "fwd":
-            self._compute = jax.jacfwd(self._fun, self._argnum)
+            self._compute = jacfwd_chunked(
+                self._fun, self._argnum, chunk_size=self._chunk_size
+            )
elif self._mode == "rev":
-            self._compute = jax.jacrev(self._fun, self._argnum)
+            self._compute = jacrev_chunked(
+                self._fun, self._argnum, chunk_size=self._chunk_size
+            )
elif self._mode == "grad":
self._compute = jax.grad(self._fun, self._argnum)
elif self._mode == "hess":
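
With ``chunk_size`` now threaded from ``__init__`` through ``_set_mode``, a hedged usage sketch of the wrapper (assuming ``compute`` dispatches to the Jacobian function selected here; the example is not part of the commit):

from desc.backend import jnp
from desc.derivatives import AutoDiffDerivative

def fun(x):
    return jnp.sin(x) ** 2  # R^500 -> R^500, elementwise

x = jnp.linspace(0.0, 1.0, 500)
# Forward-mode Jacobian built 50 basis vectors at a time rather than all 500
# at once, trading some speed for a lower peak memory footprint.
jac = AutoDiffDerivative(fun, argnum=0, mode="fwd", chunk_size=50)
J = jac.compute(x)  # shape (500, 500)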
@@ -512,7 +519,7 @@ def compute_jvp(cls, fun, argnum, v, *args, **kwargs):
argnum = (argnum,)
else:
nargs = len(argnum)
-        v = (v,) if not isinstance(v, tuple) else v
+        v = ensure_tuple(v)

f = np.array(
[
@@ -549,14 +556,14 @@ def compute_jvp2(cls, fun, argnum1, argnum2, v1, v2, *args, **kwargs):
"""
if np.isscalar(argnum1):
-            v1 = (v1,) if not isinstance(v1, tuple) else v1
+            v1 = ensure_tuple(v1)
argnum1 = (argnum1,)
else:
v1 = tuple(v1)

if np.isscalar(argnum2):
argnum2 = (argnum2 + 1,)
-            v2 = (v2,) if not isinstance(v2, tuple) else v2
+            v2 = ensure_tuple(v2)
else:
argnum2 = tuple([i + 1 for i in argnum2])
v2 = tuple(v2)
@@ -592,21 +599,21 @@ def compute_jvp3(cls, fun, argnum1, argnum2, argnum3, v1, v2, v3, *args, **kwargs):
"""
if np.isscalar(argnum1):
-            v1 = (v1,) if not isinstance(v1, tuple) else v1
+            v1 = ensure_tuple(v1)
argnum1 = (argnum1,)
else:
v1 = tuple(v1)

if np.isscalar(argnum2):
argnum2 = (argnum2 + 1,)
-            v2 = (v2,) if not isinstance(v2, tuple) else v2
+            v2 = ensure_tuple(v2)
else:
argnum2 = tuple([i + 1 for i in argnum2])
v2 = tuple(v2)

if np.isscalar(argnum3):
argnum3 = (argnum3 + 2,)
-            v3 = (v3,) if not isinstance(v3, tuple) else v3
+            v3 = ensure_tuple(v3)
else:
argnum3 = tuple([i + 2 for i in argnum3])
v3 = tuple(v3)
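
For reference, a hedged sketch of the directional-derivative entry point these hunks touch. The class name ``FiniteDiffDerivative`` is assumed from context (the hunk headers above do not show it), and only the signature ``compute_jvp(cls, fun, argnum, v, *args, **kwargs)`` visible in the diff is relied on:

import numpy as np

from desc.derivatives import FiniteDiffDerivative  # class name assumed

def f(x, y):
    return np.sin(x) * y

x = np.array([0.5, 1.5])
y = np.array([2.0, 3.0])
v = np.ones_like(x)
# Finite-difference directional derivative of f w.r.t. argument 0 along v,
# i.e. (df/dx) @ v evaluated at (x, y); here equal to cos(x) * y * v.
df = FiniteDiffDerivative.compute_jvp(f, 0, v, x, y)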
(The diff for the remaining 5 changed files is not shown here.)
