Skip to content

Commit

Permalink
Merge branch 'master' into single-stage-example
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Oct 16, 2024
2 parents 7be9729 + 12d9772 commit 3ebb6f5
Show file tree
Hide file tree
Showing 41 changed files with 1,283 additions and 1,896 deletions.
177 changes: 177 additions & 0 deletions desc/batching.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
"""Utility functions for the ``batched_vectorize`` function."""

import functools
from functools import partial
from typing import Callable, Optional

from jax._src.api import (
_check_input_dtype_jacfwd,
_check_input_dtype_jacrev,
_check_output_dtype_jacfwd,
_check_output_dtype_jacrev,
_jacfwd_unravel,
_jacrev_unravel,
_jvp,
_std_basis,
_vjp,
)
from jax._src.api_util import _ensure_index, argnums_partial, check_callable
from jax._src.tree_util import tree_map, tree_structure, tree_transpose
from jax._src.util import wraps

from desc.backend import jax, jnp

if jax.__version_info__ >= (0, 4, 16):
Expand Down Expand Up @@ -320,3 +336,164 @@ def wrapped(*args, **kwargs):
return jnp.expand_dims(result, axis=dims_to_expand)

return wrapped


# The following section of this code is derived from JAX
# https://github.com/jax-ml/jax/blob/ff0a98a2aef958df156ca149809cf532efbbcaf4/
# jax/_src/api.py
#
# The original copyright notice is as follows
# Copyright 2018 The JAX Authors.
# Licensed under the Apache License, Version 2.0 (the "License");


def jacfwd_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.
Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using forward-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)
argnums = _ensure_index(argnums)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
if not has_aux:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args)
y, jac = vmap_chunked(pushfwd, chunk_size=chunk_size)(_std_basis(dyn_args))
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
else:
pushfwd: Callable = partial(_jvp, f_partial, dyn_args, has_aux=True)
y, jac, aux = vmap_chunked(pushfwd, chunk_size=chunk_size)(
_std_basis(dyn_args)
)
y = tree_map(lambda x: x[0], y)
jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac)
aux = tree_map(lambda x: x[0], aux)
tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun


def jacrev_chunked(
fun,
argnums=0,
has_aux=False,
holomorphic=False,
allow_int=False,
*,
chunk_size=None,
):
"""Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.
Parameters
----------
fun: callable
Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers.
Specifies which positional argument(s) to differentiate with respect to
(default ``0``).
has_aux: Optional, bool.
Indicates whether ``fun`` returns a pair where the first element is considered
the output of the mathematical function to be differentiated and the second
element is auxiliary data. Default False.
holomorphic: Optional, bool.
Indicates whether ``fun`` is promised to be holomorphic. Default False.
allow_int: Optional, bool.
Whether to allow differentiating with respect to integer valued inputs. The
gradient of an integer input will have a trivial vector-space dtype (float0).
Default False.
chunk_size: int
The size of the batches to pass to vmap. If None, defaults to the largest
possible chunk_size.
Returns
-------
jac: callable
A function with the same arguments as ``fun``, that evaluates the Jacobian of
``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True
then a pair of (jacobian, auxiliary_data) is returned.
"""
check_callable(fun)

docstr = (
"Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}."
)

@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
if not has_aux:
y, pullback = _vjp(f_partial, *dyn_args)
else:
y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
jac = vmap_chunked(pullback, chunk_size=chunk_size)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
jac_tree = tree_transpose(
tree_structure(example_args), tree_structure(y), jac_tree
)
if not has_aux:
return jac_tree
else:
return jac_tree, aux

return jacfun
Loading

0 comments on commit 3ebb6f5

Please sign in to comment.