From 1c134525c145fe7ede5b3995d15cc7628acd5ef6 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Fri, 27 Sep 2024 18:41:14 -0400 Subject: [PATCH 1/6] Add jacfwd_chunked and jacrev_chunked --- desc/batching.py | 406 +++++++++++++++++++++++++++++++++++++- desc/derivatives.py | 14 +- tests/test_derivatives.py | 12 +- 3 files changed, 423 insertions(+), 9 deletions(-) diff --git a/desc/batching.py b/desc/batching.py index 7b2a18f7b..c215ad55d 100644 --- a/desc/batching.py +++ b/desc/batching.py @@ -1,7 +1,30 @@ """Utility functions for the ``batched_vectorize`` function.""" import functools -from typing import Callable, Optional +from functools import partial +from typing import Any, Callable, Optional, Sequence + +import numpy as np +from jax._src import core, dispatch, dtypes +from jax._src.api_util import ( + _ensure_index, + argnums_partial, + check_callable, + flatten_fun_nokwargs, + flatten_fun_nokwargs2, + shaped_abstractify, +) +from jax._src.interpreters import ad +from jax._src.lax import lax as lax_internal +from jax._src.tree_util import ( + Partial, + tree_flatten, + tree_map, + tree_structure, + tree_transpose, + tree_unflatten, +) +from jax._src.util import safe_map, wraps from desc.backend import jax, jnp @@ -17,6 +40,8 @@ _parse_input_dimensions, ) +_dtype = partial(dtypes.dtype, canonicalize=True) + # The following section of this code is derived from the NetKet project # https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/ # netket/jax/_chunk_utils.py @@ -320,3 +345,382 @@ def wrapped(*args, **kwargs): return jnp.expand_dims(result, axis=dims_to_expand) return wrapped + + +# The following section of this code is derived from JAX +# https://github.com/jax-ml/jax/blob/ff0a98a2aef958df156ca149809cf532efbbcaf4/ +# jax/_src/api.py +# +# The original copyright notice is as follows +# Copyright 2018 The JAX Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); + + +def jacfwd_chunked( + fun: Callable, + argnums: int | Sequence[int] = 0, + has_aux: bool = False, + *, + chunk_size=None, +) -> Callable: + """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. + + Parameters + ---------- + fun: callable + Function whose Jacobian is to be computed. + argnums: Optional, integer or sequence of integers. + Specifies which positional argument(s) to differentiate with respect to + (default ``0``). + has_aux: Optional, bool. + Indicates whether ``fun`` returns a pair where the first element is considered + the output of the mathematical function to be differentiated and the second + element is auxiliary data. Default False. + chunk_size: int + The size of the batches to pass to vmap. If None, defaults to the largest + possible chunk_size. + + Returns + ------- + jac: callable + A function with the same arguments as ``fun``, that evaluates the Jacobian of + ``fun`` using forward-mode automatic differentiation. If ``has_aux`` is True + then a pair of (jacobian, auxiliary_data) is returned. + + """ + check_callable(fun) + argnums = _ensure_index(argnums) + + docstr = ( + "Jacobian of {fun} with respect to positional argument(s) " + "{argnums}. Takes the same arguments as {fun} but returns the " + "jacobian of the output with respect to the arguments at " + "positions {argnums}." 
+ ) + + @wraps(fun, docstr=docstr, argnums=argnums) + def jacfun(*args, **kwargs): + f = lu.wrap_init(fun, kwargs) + f_partial, dyn_args = argnums_partial( + f, argnums, args, require_static_args_hashable=False + ) + tree_map(_check_input_dtype_jacfwd, dyn_args) + if not has_aux: + pushfwd: Callable = partial(_jvp, f_partial, dyn_args) + y, jac = vmap_chunked(pushfwd, chunk_size=chunk_size)(_std_basis(dyn_args)) + y = tree_map(lambda x: x[0], y) + jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac) + else: + pushfwd: Callable = partial(_jvp, f_partial, dyn_args, has_aux=True) + y, jac, aux = vmap_chunked(pushfwd, chunk_size=chunk_size)( + _std_basis(dyn_args) + ) + y = tree_map(lambda x: x[0], y) + jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac) + aux = tree_map(lambda x: x[0], aux) + + example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args + jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac) + if not has_aux: + return jac_tree + else: + return jac_tree, aux + + return jacfun + + +def jacrev_chunked( + fun: Callable, + argnums: int | Sequence[int] = 0, + has_aux: bool = False, + *, + chunk_size=None, +) -> Callable: + """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. + + Parameters + ---------- + fun: callable + Function whose Jacobian is to be computed. + argnums: Optional, integer or sequence of integers. + Specifies which positional argument(s) to differentiate with respect to + (default ``0``). + has_aux: Optional, bool. + Indicates whether ``fun`` returns a pair where the first element is considered + the output of the mathematical function to be differentiated and the second + element is auxiliary data. Default False. + chunk_size: int + The size of the batches to pass to vmap. If None, defaults to the largest + possible chunk_size. + + Returns + ------- + jac: callable + A function with the same arguments as ``fun``, that evaluates the Jacobian of + ``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True + then a pair of (jacobian, auxiliary_data) is returned. + + """ + check_callable(fun) + + docstr = ( + "Jacobian of {fun} with respect to positional argument(s) " + "{argnums}. Takes the same arguments as {fun} but returns the " + "jacobian of the output with respect to the arguments at " + "positions {argnums}." 
+ ) + + @wraps(fun, docstr=docstr, argnums=argnums) + def jacfun(*args, **kwargs): + f = lu.wrap_init(fun, kwargs) + f_partial, dyn_args = argnums_partial( + f, argnums, args, require_static_args_hashable=False + ) + tree_map(_check_input_dtype_jacrev, dyn_args) + if not has_aux: + y, pullback = _vjp(f_partial, *dyn_args) + else: + y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True) + tree_map(_check_output_dtype_jacrev, y) + jac = vmap_chunked(pullback, chunk_size=chunk_size)(_std_basis(y)) + jac = jac[0] if isinstance(argnums, int) else jac + example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args + jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac) + jac_tree = tree_transpose( + tree_structure(example_args), tree_structure(y), jac_tree + ) + if not has_aux: + return jac_tree + else: + return jac_tree, aux + + return jacfun + + +def _check_input_dtype_jacrev(x): + dispatch.check_arg(x) + aval = core.get_aval(x) + if ( + dtypes.issubdtype(aval.dtype, dtypes.extended) + or dtypes.issubdtype(aval.dtype, np.integer) + or dtypes.issubdtype(aval.dtype, np.bool_) + ): + raise TypeError( + f"jacrev_chunked requires real- or complex-valued inputs (input dtype " + f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. " + "If you want to use Boolean- or integer-valued inputs, use vjp " + "or set allow_int to True." + ) + elif not dtypes.issubdtype(aval.dtype, np.inexact): + raise TypeError( + f"jacrev_chunked requires numerical-valued inputs (input dtype that is a " + f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}." + ) + + +def _check_output_dtype_jacrev(x): + aval = core.get_aval(x) + if dtypes.issubdtype(aval.dtype, dtypes.extended): + raise TypeError(f"jacrev_chunked with output element type {aval.dtype.name}") + elif dtypes.issubdtype(aval.dtype, np.complexfloating): + raise TypeError( + f"jacrev_chunked requires real-valued outputs (output dtype that is " + f"a sub-dtype of np.floating), but got {aval.dtype.name}. " + "For holomorphic differentiation, pass holomorphic=True. " + "For differentiation of non-holomorphic functions involving complex " + "outputs, use jax.vjp directly." + ) + elif not dtypes.issubdtype(aval.dtype, np.floating): + raise TypeError( + f"jacrev_chunked requires real-valued outputs (output dtype that is " + f"a sub-dtype of np.floating), but got {aval.dtype.name}. " + "For differentiation of functions with integer outputs, use " + "jax.vjp directly." + ) + + +def _check_input_dtype_jacfwd(x: Any) -> None: + dispatch.check_arg(x) + aval = core.get_aval(x) + if dtypes.issubdtype(aval.dtype, dtypes.extended): + raise TypeError(f"jacfwd with input element type {aval.dtype.name}") + elif not dtypes.issubdtype(aval.dtype, np.floating): + raise TypeError( + "jacfwd requires real-valued inputs (input dtype that is " + f"a sub-dtype of np.floating), but got {aval.dtype.name}. " + "For holomorphic differentiation, pass holomorphic=True. " + "For differentiation of non-holomorphic functions involving " + "complex inputs or integer inputs, use jax.jvp directly." 
+ ) + + +def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr): + return _unravel_array_into_pytree(input_pytree, -1, output_pytree_leaf, arr) + + +def _jacrev_unravel(output_pytree, input_pytree_leaf, arr): + return _unravel_array_into_pytree(output_pytree, 0, input_pytree_leaf, arr) + + +def _possible_downcast(x, example): + if dtypes.issubdtype(x.dtype, np.complexfloating) and not dtypes.issubdtype( + _dtype(example), np.complexfloating + ): + x = x.real + dtype = None if example is None else _dtype(example) + weak_type = None if example is None else dtypes.is_weakly_typed(example) + return lax_internal._convert_element_type(x, dtype, weak_type) + + +def _std_basis(pytree): + leaves, _ = tree_flatten(pytree) + ndim = sum(safe_map(np.size, leaves)) + dtype = dtypes.result_type(*leaves) + flat_basis = jnp.eye(ndim, dtype=dtype) + return _unravel_array_into_pytree(pytree, 1, None, flat_basis) + + +def _unravel_array_into_pytree(pytree, axis, example, arr): + """Unravel an array into a PyTree with a given structure. + + Parameters + ---------- + pytree: The pytree that provides the structure. + axis: The parameter axis is either -1, 0, or 1. It controls the + resulting shapes. + example: If specified, cast the components to the matching dtype/weak_type, + or else use the pytree leaf type if example is None. + arr: The array to be unraveled. + """ + leaves, treedef = tree_flatten(pytree) + axis = axis % arr.ndim + shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis + 1 :] for l in leaves] + parts = _split(arr, np.cumsum(safe_map(np.size, leaves[:-1])), axis) + reshaped_parts = [ + _possible_downcast(np.reshape(x, shape), leaf if example is None else example) + for x, shape, leaf in zip(parts, shapes, leaves) + ] + return tree_unflatten(treedef, reshaped_parts) + + +def _split(x, indices, axis): + if isinstance(x, np.ndarray): + return np.split(x, indices, axis) + else: + return x._split(indices, axis) + + +def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): + """Variant of jvp() that takes an lu.WrappedFun.""" + if not isinstance(primals, (tuple, list)) or not isinstance( + tangents, (tuple, list) + ): + raise TypeError( + "primal and tangent arguments to jax.jvp must be tuples or lists; " + f"found {type(primals).__name__} and {type(tangents).__name__}." + ) + + ps_flat, tree_def = tree_flatten(primals) + ts_flat, tree_def_2 = tree_flatten(tangents) + if tree_def != tree_def_2: + raise TypeError( + "primal and tangent arguments to jax.jvp must have the same tree " + f"structure; primals have tree structure {tree_def} whereas tangents have " + f"tree structure {tree_def_2}." + ) + for p, t in zip(ps_flat, ts_flat): + if core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t): + raise TypeError( + "primal and tangent arguments to jax.jvp do not match; " + "dtypes must be equal, or in case of int/bool primal dtype " + "the tangent dtype must be float0." + f"Got primal dtype {_dtype(p)} and so expected tangent dtype " + f"{core.primal_dtype_to_tangent_dtype(_dtype(p))}, but got " + f"tangent dtype {_dtype(t)} instead." 
+ ) + if np.shape(p) != np.shape(t): + raise ValueError( + "jvp called with different primal and tangent shapes;" + f"Got primal shape {np.shape(p)} and tangent shape as {np.shape(t)}" + ) + + if not has_aux: + flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def) + out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat) + out_tree = out_tree() + return ( + tree_unflatten(out_tree, out_primals), + tree_unflatten(out_tree, out_tangents), + ) + else: + flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, tree_def) + jvp_fun, aux = ad.jvp(flat_fun, has_aux=True) + out_primals, out_tangents = jvp_fun.call_wrapped(ps_flat, ts_flat) + out_tree, aux_tree = out_aux_trees() + return ( + tree_unflatten(out_tree, out_primals), + tree_unflatten(out_tree, out_tangents), + tree_unflatten(aux_tree, aux()), + ) + + +def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): + """Variant of vjp() that takes an lu.WrappedFun.""" + primals_flat, in_tree = tree_flatten(primals) + for arg in primals_flat: + dispatch.check_arg(arg) + if not has_aux: + flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + out_primals, vjp = ad.vjp(flat_fun, primals_flat) + out_tree = out_tree() + else: + flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree) + out_primals, vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True) + out_tree, aux_tree = out_aux_trees() + out_primal_avals = map(shaped_abstractify, out_primals) + out_primal_py = tree_unflatten(out_tree, out_primals) + vjp_py = Partial( + partial( + _vjp_pullback_wrapper, fun.__name__, out_primal_avals, (out_tree, in_tree) + ), + vjp, + ) + if not has_aux: + return out_primal_py, vjp_py + else: + return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) + + +def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): + (py_args,) = py_args_ + in_tree_expected, out_tree = io_tree + args, in_tree = tree_flatten(py_args) + if in_tree != in_tree_expected: + raise ValueError( + f"unexpected tree structure of argument to vjp function: " + f"got {in_tree}, but expected to match {in_tree_expected}" + ) + for arg, aval in zip(args, out_primal_avals): + ct_aval = shaped_abstractify(arg) + try: + ct_aval_expected = aval.to_tangent_type() + except AttributeError: + # https://github.com/jax-ml/jax/commit/018189491bde26fe9c7ade1213c5cbbad8bca1c6 + ct_aval_expected = aval.at_least_vspace() + if not core.typecompat( + ct_aval, ct_aval_expected + ) and not _temporary_dtype_exception(ct_aval, ct_aval_expected): + raise ValueError( + "unexpected JAX type (e.g. shape/dtype) for argument to vjp function: " + f"got {ct_aval.str_short()}, but expected " + f"{ct_aval_expected.str_short()} because the corresponding output " + f"of the function {name} had JAX type {aval.str_short()}" + ) + ans = fun(*args) + return tree_unflatten(out_tree, ans) + + +def _temporary_dtype_exception(a, a_) -> bool: + if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray): + return a.shape == a_.shape and a_.dtype == dtypes.float0 + return False diff --git a/desc/derivatives.py b/desc/derivatives.py index 684ea37d4..367dc1d63 100644 --- a/desc/derivatives.py +++ b/desc/derivatives.py @@ -10,6 +10,8 @@ if use_jax: import jax + from desc.batching import jacfwd_chunked, jacrev_chunked + class _Derivative(ABC): """_Derivative is an abstract base class for derivative matrix calculations. 
@@ -123,11 +125,11 @@ class AutoDiffDerivative(_Derivative): """ - def __init__(self, fun, argnum=0, mode="fwd", **kwargs): + def __init__(self, fun, argnum=0, mode="fwd", chunk_size=None, **kwargs): self._fun = fun self._argnum = argnum - + self._chunk_size = chunk_size self._set_mode(mode) def compute(self, *args, **kwargs): @@ -323,9 +325,13 @@ def _set_mode(self, mode) -> None: self._mode = mode if self._mode == "fwd": - self._compute = jax.jacfwd(self._fun, self._argnum) + self._compute = jacfwd_chunked( + self._fun, self._argnum, chunk_size=self._chunk_size + ) elif self._mode == "rev": - self._compute = jax.jacrev(self._fun, self._argnum) + self._compute = jacrev_chunked( + self._fun, self._argnum, chunk_size=self._chunk_size + ) elif self._mode == "grad": self._compute = jax.grad(self._fun, self._argnum) elif self._mode == "hess": diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 29c2f909d..e5f50eee5 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -77,13 +77,17 @@ def test_fun(x, y, a): y = np.array([60, 1, 100, 0.02]) a = -2 - jac_AD = AutoDiffDerivative(test_fun, argnum=0) - J_AD = jac_AD.compute(x, y, a) + jacf_AD = AutoDiffDerivative(test_fun, argnum=0, mode="fwd") + Jf_AD = jacf_AD.compute(x, y, a) + jacr_AD = AutoDiffDerivative(test_fun, argnum=0, mode="rev") + Jr_AD = jacr_AD.compute(x, y, a) - jac_FD = AutoDiffDerivative(test_fun, argnum=0) + jac_FD = FiniteDiffDerivative(test_fun, argnum=0) J_FD = jac_FD.compute(x, y, a) - np.testing.assert_allclose(J_FD, J_AD, atol=1e-8) + np.testing.assert_allclose(Jf_AD, Jr_AD, atol=1e-8) + np.testing.assert_allclose(J_FD, Jf_AD, rtol=1e-2) + np.testing.assert_allclose(J_FD, Jr_AD, rtol=1e-2) @pytest.mark.unit def test_fd_hessian(self): From 1e0d099463e468e45fb7064793aa91a6f56b69ba Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Fri, 27 Sep 2024 18:42:05 -0400 Subject: [PATCH 2/6] Pass chunk size to _Objective jacobians --- desc/objectives/objective_funs.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 9a540d0b5..47368b6e6 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1167,25 +1167,34 @@ def hess(self, *args, **kwargs): def jac_scaled(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_scaled wrt x.""" argnums = tuple(range(len(self.things))) - return Derivative(self.compute_scaled, argnums, mode=self._deriv_mode)( - *args, **kwargs - ) + return Derivative( + self.compute_scaled, + argnums, + mode=self._deriv_mode, + chunk_size=self._jac_chunk_size, + )(*args, **kwargs) @jit def jac_scaled_error(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" argnums = tuple(range(len(self.things))) - return Derivative(self.compute_scaled_error, argnums, mode=self._deriv_mode)( - *args, **kwargs - ) + return Derivative( + self.compute_scaled_error, + argnums, + mode=self._deriv_mode, + chunk_size=self._jac_chunk_size, + )(*args, **kwargs) @jit def jac_unscaled(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_unscaled wrt x.""" argnums = tuple(range(len(self.things))) - return Derivative(self.compute_unscaled, argnums, mode=self._deriv_mode)( - *args, **kwargs - ) + return Derivative( + self.compute_unscaled, + argnums, + mode=self._deriv_mode, + chunk_size=self._jac_chunk_size, + )(*args, **kwargs) def _jvp(self, v, x, constants=None, op="compute_scaled"): v = v 
if isinstance(v, (tuple, list)) else (v,) From 8c7805a1ac511787770bc29330212e68d929f6c7 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Fri, 27 Sep 2024 18:42:59 -0400 Subject: [PATCH 3/6] Allow _Objective.jvp to work with reverse mode objectives --- desc/objectives/objective_funs.py | 31 +++++++++++++++++--------- tests/test_objective_funs.py | 37 +++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 47368b6e6..9b456f4fd 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -13,6 +13,7 @@ jit, jnp, tree_flatten, + tree_map, tree_unflatten, use_jax, ) @@ -1196,17 +1197,27 @@ def jac_unscaled(self, *args, **kwargs): chunk_size=self._jac_chunk_size, )(*args, **kwargs) - def _jvp(self, v, x, constants=None, op="compute_scaled"): + def _jvp(self, v, x, constants=None, op="scaled"): v = v if isinstance(v, (tuple, list)) else (v,) x = x if isinstance(x, (tuple, list)) else (x,) assert len(x) == len(v) - fun = lambda *x: getattr(self, op)(*x, constants=constants) - jvpfun = lambda *dx: Derivative.compute_jvp(fun, tuple(range(len(x))), dx, *x) - sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" - return batched_vectorize( - jvpfun, signature=sig, chunk_size=self._jac_chunk_size - )(*v) + if self._deriv_mode == "fwd": + fun = lambda *x: getattr(self, "compute_" + op)(*x, constants=constants) + jvpfun = lambda *dx: Derivative.compute_jvp( + fun, tuple(range(len(x))), dx, *x + ) + sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" + return batched_vectorize( + jvpfun, signature=sig, chunk_size=self._jac_chunk_size + )(*v) + else: # rev mode. We compute full jacobian and manually do mv. In this case + # the jacobian should be wide so this isn't very expensive. + jac = getattr(self, "jac_" + op)(*x, constants=constants) + # jac is a tuple, 1 array for each thing + Jv = tree_map(jnp.dot, jac, v) + # sum over different things + return jnp.sum(jnp.asarray(Jv), axis=0) @jit def jvp_scaled(self, v, x, constants=None): @@ -1222,7 +1233,7 @@ def jvp_scaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - return self._jvp(v, x, constants, "compute_scaled") + return self._jvp(v, x, constants, "scaled") @jit def jvp_scaled_error(self, v, x, constants=None): @@ -1238,7 +1249,7 @@ def jvp_scaled_error(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - return self._jvp(v, x, constants, "compute_scaled_error") + return self._jvp(v, x, constants, "scaled_error") @jit def jvp_unscaled(self, v, x, constants=None): @@ -1254,7 +1265,7 @@ def jvp_unscaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. 
""" - return self._jvp(v, x, constants, "compute_unscaled") + return self._jvp(v, x, constants, "unscaled") def print_value(self, args, args0=None, **kwargs): """Print the value of the objective.""" diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 68de417ea..f384b90aa 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -1365,6 +1365,43 @@ def test_derivative_modes(): np.testing.assert_allclose(H1, H3, atol=1e-10) +@pytest.mark.unit +def test_fwd_rev(): + """Test that forward and reverse mode jvps etc give same results.""" + eq = Equilibrium() + obj1 = MeanCurvature(eq, deriv_mode="fwd") + obj2 = MeanCurvature(eq, deriv_mode="rev") + obj1.build() + obj2.build() + + x = eq.pack_params(eq.params_dict) + J1 = obj1.jac_scaled(x) + J2 = obj2.jac_scaled(x) + np.testing.assert_allclose(J1, J2, atol=1e-14) + + jvp1 = obj1.jvp_scaled(x, jnp.ones_like(x)) + jvp2 = obj2.jvp_scaled(x, jnp.ones_like(x)) + np.testing.assert_allclose(jvp1, jvp2, atol=1e-14) + + surf = FourierRZToroidalSurface() + obj1 = PlasmaVesselDistance(eq, surf, deriv_mode="fwd") + obj2 = PlasmaVesselDistance(eq, surf, deriv_mode="rev") + obj1.build() + obj2.build() + + x1 = eq.pack_params(eq.params_dict) + x2 = surf.pack_params(surf.params_dict) + + J1a, J1b = obj1.jac_scaled(x1, x2) + J2a, J2b = obj2.jac_scaled(x1, x2) + np.testing.assert_allclose(J1a, J2a, atol=1e-14) + np.testing.assert_allclose(J1b, J2b, atol=1e-14) + + jvp1 = obj1.jvp_scaled((x1, x2), (jnp.ones_like(x1), jnp.ones_like(x2))) + jvp2 = obj2.jvp_scaled((x1, x2), (jnp.ones_like(x1), jnp.ones_like(x2))) + np.testing.assert_allclose(jvp1, jvp2, atol=1e-14) + + @pytest.mark.unit def test_getter_setter(): """Test getter and setter methods of Objectives.""" From cec34d77193985689a55229345f0692a2c09f2ab Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Fri, 27 Sep 2024 21:37:18 -0400 Subject: [PATCH 4/6] Use blocked logic for ObjectiveFunction.jvp --- desc/derivatives.py | 25 ++--- desc/objectives/objective_funs.py | 134 ++++++++++++-------------- desc/optimize/_constraint_wrappers.py | 6 -- desc/utils.py | 9 ++ tests/test_objective_funs.py | 6 ++ 5 files changed, 90 insertions(+), 90 deletions(-) diff --git a/desc/derivatives.py b/desc/derivatives.py index 367dc1d63..d851bf7b5 100644 --- a/desc/derivatives.py +++ b/desc/derivatives.py @@ -6,6 +6,7 @@ from termcolor import colored from desc.backend import jnp, put, use_jax +from desc.utils import ensure_tuple if use_jax: import jax @@ -207,7 +208,7 @@ def compute_jvp(cls, fun, argnum, v, *args, **kwargs): """ _ = kwargs.pop("rel_step", None) # unused by autodiff argnum = (argnum,) if jnp.isscalar(argnum) else tuple(argnum) - v = (v,) if not isinstance(v, (tuple, list)) else v + v = ensure_tuple(v) def _fun(*x): _args = list(args) @@ -243,14 +244,14 @@ def compute_jvp2(cls, fun, argnum1, argnum2, v1, v2, *args, **kwargs): """ if np.isscalar(argnum1): - v1 = (v1,) if not isinstance(v1, (tuple, list)) else v1 + v1 = ensure_tuple(v1) argnum1 = (argnum1,) else: v1 = tuple(v1) if np.isscalar(argnum2): argnum2 = (argnum2 + 1,) - v2 = (v2,) if not isinstance(v2, (tuple, list)) else v2 + v2 = ensure_tuple(v2) else: argnum2 = tuple([i + 1 for i in argnum2]) v2 = tuple(v2) @@ -286,21 +287,21 @@ def compute_jvp3(cls, fun, argnum1, argnum2, argnum3, v1, v2, v3, *args, **kwarg """ if np.isscalar(argnum1): - v1 = (v1,) if not isinstance(v1, (tuple, list)) else v1 + v1 = ensure_tuple(v1) argnum1 = (argnum1,) else: v1 = tuple(v1) if np.isscalar(argnum2): argnum2 = (argnum2 + 1,) - 
v2 = (v2,) if not isinstance(v2, (tuple, list)) else v2 + v2 = ensure_tuple(v2) else: argnum2 = tuple([i + 1 for i in argnum2]) v2 = tuple(v2) if np.isscalar(argnum3): argnum3 = (argnum3 + 2,) - v3 = (v3,) if not isinstance(v3, (tuple, list)) else v3 + v3 = ensure_tuple(v3) else: argnum3 = tuple([i + 2 for i in argnum3]) v3 = tuple(v3) @@ -518,7 +519,7 @@ def compute_jvp(cls, fun, argnum, v, *args, **kwargs): argnum = (argnum,) else: nargs = len(argnum) - v = (v,) if not isinstance(v, tuple) else v + v = ensure_tuple(v) f = np.array( [ @@ -555,14 +556,14 @@ def compute_jvp2(cls, fun, argnum1, argnum2, v1, v2, *args, **kwargs): """ if np.isscalar(argnum1): - v1 = (v1,) if not isinstance(v1, tuple) else v1 + v1 = ensure_tuple(v1) argnum1 = (argnum1,) else: v1 = tuple(v1) if np.isscalar(argnum2): argnum2 = (argnum2 + 1,) - v2 = (v2,) if not isinstance(v2, tuple) else v2 + v2 = ensure_tuple(v2) else: argnum2 = tuple([i + 1 for i in argnum2]) v2 = tuple(v2) @@ -598,21 +599,21 @@ def compute_jvp3(cls, fun, argnum1, argnum2, argnum3, v1, v2, v3, *args, **kwarg """ if np.isscalar(argnum1): - v1 = (v1,) if not isinstance(v1, tuple) else v1 + v1 = ensure_tuple(v1) argnum1 = (argnum1,) else: v1 = tuple(v1) if np.isscalar(argnum2): argnum2 = (argnum2 + 1,) - v2 = (v2,) if not isinstance(v2, tuple) else v2 + v2 = ensure_tuple(v2) else: argnum2 = tuple([i + 1 for i in argnum2]) v2 = tuple(v2) if np.isscalar(argnum3): argnum3 = (argnum3 + 2,) - v3 = (v3,) if not isinstance(v3, tuple) else v3 + v3 = ensure_tuple(v3) else: argnum3 = tuple([i + 2 for i in argnum3]) v3 = tuple(v3) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 9b456f4fd..fd08d1884 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -24,6 +24,7 @@ from desc.utils import ( PRINT_WIDTH, Timer, + ensure_tuple, errorif, flatten_list, is_broadcastable, @@ -501,15 +502,36 @@ def hess(self, x, constants=None): Derivative(self.compute_scalar, mode="hess")(x, constants).squeeze() ) - def _jac_blocked(self, op, x, constants=None): - # could also do something similar for grad and hess, but probably not - # worth it. grad is already super cheap to eval all at once, and blocked - # hess would only be block diag which may miss important interactions. + @jit + def jac_scaled(self, x, constants=None): + """Compute Jacobian matrix of self.compute_scaled wrt x.""" + v = jnp.eye(x.shape[0]) + return self.jvp_scaled(v, x, constants).T + + @jit + def jac_scaled_error(self, x, constants=None): + """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" + v = jnp.eye(x.shape[0]) + return self.jvp_scaled_error(v, x, constants).T + + @jit + def jac_unscaled(self, x, constants=None): + """Compute Jacobian matrix of self.compute_unscaled wrt x.""" + v = jnp.eye(x.shape[0]) + return self.jvp_unscaled(v, x, constants).T + + def _jvp_blocked(self, v, x, constants=None, op="scaled"): + v = ensure_tuple(v) + if len(v) > 1: + # using blocked for higher order derivatives is a pain, and only really + # is needed for perturbations. 
Just pass that to jvp_batched for now + return self._jvp_batched(v, x, constants, op) if constants is None: constants = self.constants xs_splits = np.cumsum([t.dim_x for t in self.things]) xs = jnp.split(x, xs_splits) + vs = jnp.split(v[0], xs_splits, axis=-1) J = [] assert len(self.objectives) == len(self.constants) # basic idea is we compute the jacobian of each objective wrt each thing @@ -519,63 +541,18 @@ def _jac_blocked(self, op, x, constants=None): # get the xs that go to that objective thing_idx = self._things_per_objective_idx[k] xi = [xs[i] for i in thing_idx] - Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt to just those things - Ji = [] # jac wrt all things - for i, thing in enumerate(self.things): - if i in thing_idx: # dfi/dxj != 0 - Ji += [Ji_[thing_idx.index(i)]] - else: # dfi/dxj == 0 - Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] - Ji = jnp.hstack(Ji) # something like [df1/dx1, df1/dx2, 0] - J += [Ji] - # something like [df1/dx1, df1/dx2, 0] - # [df2/dx1, 0, df2/dx3] # noqa:E800 - J = jnp.vstack(J) + vi = [vs[i] for i in thing_idx] + Ji_ = getattr(obj, "jvp_" + op)(vi, xi, constants=const) + J += [Ji_] + # this is the transpose of the jvp when v is a matrix, for consistency with + # jvp_batched + J = jnp.hstack(J) return J - @jit - def jac_scaled(self, x, constants=None): - """Compute Jacobian matrix of self.compute_scaled wrt x.""" - if constants is None: - constants = self.constants - - if self._deriv_mode == "batched": - J = Derivative(self.compute_scaled, mode="fwd")(x, constants) - if self._deriv_mode == "blocked": - J = self._jac_blocked("jac_scaled", x, constants) + def _jvp_batched(self, v, x, constants=None, op="scaled"): + v = ensure_tuple(v) - return jnp.atleast_2d(J.squeeze()) - - @jit - def jac_scaled_error(self, x, constants=None): - """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" - if constants is None: - constants = self.constants - - if self._deriv_mode == "batched": - J = Derivative(self.compute_scaled_error, mode="fwd")(x, constants) - if self._deriv_mode == "blocked": - J = self._jac_blocked("jac_scaled_error", x, constants) - - return jnp.atleast_2d(J.squeeze()) - - @jit - def jac_unscaled(self, x, constants=None): - """Compute Jacobian matrix of self.compute_unscaled wrt x.""" - if constants is None: - constants = self.constants - - if self._deriv_mode == "batched": - J = Derivative(self.compute_unscaled, mode="fwd")(x, constants) - if self._deriv_mode == "blocked": - J = self._jac_blocked("jac_unscaled", x, constants) - - return jnp.atleast_2d(J.squeeze()) - - def _jvp(self, v, x, constants=None, op="compute_scaled"): - v = v if isinstance(v, (tuple, list)) else (v,) - - fun = lambda x: getattr(self, op)(x, constants) + fun = lambda x: getattr(self, "compute_" + op)(x, constants) if len(v) == 1: jvpfun = lambda dx: Derivative.compute_jvp(fun, 0, dx, x) return batched_vectorize( @@ -613,7 +590,11 @@ def jvp_scaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - return self._jvp(v, x, constants, "compute_scaled") + if self._deriv_mode == "batched": + J = self._jvp_batched(v, x, constants, "scaled") + if self._deriv_mode == "blocked": + J = self._jvp_blocked(v, x, constants, "scaled") + return J @jit def jvp_scaled_error(self, v, x, constants=None): @@ -630,7 +611,11 @@ def jvp_scaled_error(self, v, x, constants=None): Constant parameters passed to sub-objectives. 
""" - return self._jvp(v, x, constants, "compute_scaled_error") + if self._deriv_mode == "batched": + J = self._jvp_batched(v, x, constants, "scaled_error") + if self._deriv_mode == "blocked": + J = self._jvp_blocked(v, x, constants, "scaled_error") + return J @jit def jvp_unscaled(self, v, x, constants=None): @@ -647,10 +632,14 @@ def jvp_unscaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - return self._jvp(v, x, constants, "compute_unscaled") + if self._deriv_mode == "batched": + J = self._jvp_batched(v, x, constants, "unscaled") + if self._deriv_mode == "blocked": + J = self._jvp_blocked(v, x, constants, "unscaled") + return J - def _vjp(self, v, x, constants=None, op="compute_scaled"): - fun = lambda x: getattr(self, op)(x, constants) + def _vjp(self, v, x, constants=None, op="scaled"): + fun = lambda x: getattr(self, "compute_" + op)(x, constants) return Derivative.compute_vjp(fun, 0, v, x) @jit @@ -667,7 +656,7 @@ def vjp_scaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - return self._vjp(v, x, constants, "compute_scaled") + return self._vjp(v, x, constants, "scaled") @jit def vjp_scaled_error(self, v, x, constants=None): @@ -683,7 +672,7 @@ def vjp_scaled_error(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - return self._vjp(v, x, constants, "compute_scaled_error") + return self._vjp(v, x, constants, "scaled_error") @jit def vjp_unscaled(self, v, x, constants=None): @@ -699,7 +688,7 @@ def vjp_unscaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - return self._vjp(v, x, constants, "compute_unscaled") + return self._vjp(v, x, constants, "unscaled") def compile(self, mode="auto", verbose=1): """Call the necessary functions to ensure the function is compiled. @@ -1198,8 +1187,8 @@ def jac_unscaled(self, *args, **kwargs): )(*args, **kwargs) def _jvp(self, v, x, constants=None, op="scaled"): - v = v if isinstance(v, (tuple, list)) else (v,) - x = x if isinstance(x, (tuple, list)) else (x,) + v = ensure_tuple(v) + x = ensure_tuple(x) assert len(x) == len(v) if self._deriv_mode == "fwd": @@ -1214,10 +1203,11 @@ def _jvp(self, v, x, constants=None, op="scaled"): else: # rev mode. We compute full jacobian and manually do mv. In this case # the jacobian should be wide so this isn't very expensive. jac = getattr(self, "jac_" + op)(*x, constants=constants) - # jac is a tuple, 1 array for each thing - Jv = tree_map(jnp.dot, jac, v) - # sum over different things - return jnp.sum(jnp.asarray(Jv), axis=0) + # jac is a tuple, 1 array for each thing. Transposes here and below make it + # equivalent to fwd mode above, which batches over the first axis + Jv = tree_map(lambda a, b: jnp.dot(a, b.T), jac, v) + # sum over different things. 
+ return jnp.sum(jnp.asarray(Jv), axis=0).T @jit def jvp_scaled(self, v, x, constants=None): diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 604609fc8..bbc5daf4e 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -289,12 +289,6 @@ def hess(self, x_reduced, constants=None): def _jac(self, x_reduced, constants=None, op="scaled"): x = self.recover(x_reduced) - if self._objective._deriv_mode == "blocked": - fun = getattr(self._objective, "jac_" + op) - return fun(x, constants)[:, self._unfixed_idx] @ ( - self._Z * self._D[self._unfixed_idx, None] - ) - v = self._unfixed_idx_mat df = getattr(self._objective, "jvp_" + op)(v.T, x, constants) return df.T diff --git a/desc/utils.py b/desc/utils.py index 203e29d3d..50c1db1b5 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -929,3 +929,12 @@ def tupleset(t, i, value): ) return res + + +def ensure_tuple(x): + """Returns x as a tuple of arrays.""" + if isinstance(x, tuple): + return x + if isinstance(x, list): + return tuple(x) + return (x,) diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index f384b90aa..cfbf09090 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -1343,6 +1343,7 @@ def test_derivative_modes(): assert obj1._jac_chunk_size > 0 obj3.build() x = obj1.x(eq, surf) + v = jnp.ones_like(x) g1 = obj1.grad(x) g2 = obj2.grad(x) g3 = obj3.grad(x) @@ -1363,6 +1364,11 @@ def test_derivative_modes(): H3 = obj3.hess(x) np.testing.assert_allclose(H1, H2, atol=1e-10) np.testing.assert_allclose(H1, H3, atol=1e-10) + j1 = obj1.jvp_scaled(v, x) + j2 = obj2.jvp_scaled(v, x) + j3 = obj3.jvp_scaled(v, x) + np.testing.assert_allclose(j1, j2, atol=1e-10) + np.testing.assert_allclose(j1, j3, atol=1e-10) @pytest.mark.unit From 0fae08210255b4afa4ef42e70c18b7059c56b2ed Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Fri, 27 Sep 2024 21:43:41 -0400 Subject: [PATCH 5/6] Adjust heuristic for choosing fwd over reverse mode AD --- desc/objectives/objective_funs.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index fd08d1884..9cdfb61c2 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -968,11 +968,13 @@ def __init__( def _set_derivatives(self): """Choose derivative mode based on size of inputs/outputs.""" if self._deriv_mode == "auto": - # choose based on shape of jacobian. fwd mode is more memory efficient - # so we prefer that unless the jacobian is really wide + # choose based on shape of jacobian. dim_x is usually an overestimate of + # the true number of DOFs because linear constraints remove some. 
Also + # fwd mode is more memory efficient so we prefer that unless the jacobian + # is really wide self._deriv_mode = ( "fwd" - if self.dim_f >= 0.5 * sum(t.dim_x for t in self.things) + if self.dim_f >= 0.2 * sum(t.dim_x for t in self.things) else "rev" ) From 9bc9aafee691a5855169f50a44df3024f78d5d68 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Fri, 27 Sep 2024 22:57:27 -0400 Subject: [PATCH 6/6] Remove some extra unneeded stuff in favor of importing from jax, remove type hints incompatible with python 3.9 --- desc/batching.py | 303 ++++++----------------------------------------- 1 file changed, 38 insertions(+), 265 deletions(-) diff --git a/desc/batching.py b/desc/batching.py index c215ad55d..129ef7ec3 100644 --- a/desc/batching.py +++ b/desc/batching.py @@ -2,29 +2,22 @@ import functools from functools import partial -from typing import Any, Callable, Optional, Sequence - -import numpy as np -from jax._src import core, dispatch, dtypes -from jax._src.api_util import ( - _ensure_index, - argnums_partial, - check_callable, - flatten_fun_nokwargs, - flatten_fun_nokwargs2, - shaped_abstractify, +from typing import Callable, Optional + +from jax._src.api import ( + _check_input_dtype_jacfwd, + _check_input_dtype_jacrev, + _check_output_dtype_jacfwd, + _check_output_dtype_jacrev, + _jacfwd_unravel, + _jacrev_unravel, + _jvp, + _std_basis, + _vjp, ) -from jax._src.interpreters import ad -from jax._src.lax import lax as lax_internal -from jax._src.tree_util import ( - Partial, - tree_flatten, - tree_map, - tree_structure, - tree_transpose, - tree_unflatten, -) -from jax._src.util import safe_map, wraps +from jax._src.api_util import _ensure_index, argnums_partial, check_callable +from jax._src.tree_util import tree_map, tree_structure, tree_transpose +from jax._src.util import wraps from desc.backend import jax, jnp @@ -40,8 +33,6 @@ _parse_input_dimensions, ) -_dtype = partial(dtypes.dtype, canonicalize=True) - # The following section of this code is derived from the NetKet project # https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/ # netket/jax/_chunk_utils.py @@ -357,12 +348,13 @@ def wrapped(*args, **kwargs): def jacfwd_chunked( - fun: Callable, - argnums: int | Sequence[int] = 0, - has_aux: bool = False, + fun, + argnums=0, + has_aux=False, + holomorphic=False, *, chunk_size=None, -) -> Callable: +): """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. Parameters @@ -376,6 +368,8 @@ def jacfwd_chunked( Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. + holomorphic: Optional, bool. + Indicates whether ``fun`` is promised to be holomorphic. Default False. chunk_size: int The size of the batches to pass to vmap. If None, defaults to the largest possible chunk_size. 
@@ -404,7 +398,7 @@ def jacfun(*args, **kwargs): f_partial, dyn_args = argnums_partial( f, argnums, args, require_static_args_hashable=False ) - tree_map(_check_input_dtype_jacfwd, dyn_args) + tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args) if not has_aux: pushfwd: Callable = partial(_jvp, f_partial, dyn_args) y, jac = vmap_chunked(pushfwd, chunk_size=chunk_size)(_std_basis(dyn_args)) @@ -418,7 +412,7 @@ def jacfun(*args, **kwargs): y = tree_map(lambda x: x[0], y) jac = tree_map(lambda x: jnp.moveaxis(x, 0, -1), jac) aux = tree_map(lambda x: x[0], aux) - + tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y) example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac) if not has_aux: @@ -430,12 +424,14 @@ def jacfun(*args, **kwargs): def jacrev_chunked( - fun: Callable, - argnums: int | Sequence[int] = 0, - has_aux: bool = False, + fun, + argnums=0, + has_aux=False, + holomorphic=False, + allow_int=False, *, chunk_size=None, -) -> Callable: +): """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD. Parameters @@ -449,6 +445,12 @@ def jacrev_chunked( Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. + holomorphic: Optional, bool. + Indicates whether ``fun`` is promised to be holomorphic. Default False. + allow_int: Optional, bool. + Whether to allow differentiating with respect to integer valued inputs. The + gradient of an integer input will have a trivial vector-space dtype (float0). + Default False. chunk_size: int The size of the batches to pass to vmap. If None, defaults to the largest possible chunk_size. @@ -476,12 +478,12 @@ def jacfun(*args, **kwargs): f_partial, dyn_args = argnums_partial( f, argnums, args, require_static_args_hashable=False ) - tree_map(_check_input_dtype_jacrev, dyn_args) + tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args) if not has_aux: y, pullback = _vjp(f_partial, *dyn_args) else: y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True) - tree_map(_check_output_dtype_jacrev, y) + tree_map(partial(_check_output_dtype_jacrev, holomorphic), y) jac = vmap_chunked(pullback, chunk_size=chunk_size)(_std_basis(y)) jac = jac[0] if isinstance(argnums, int) else jac example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args @@ -495,232 +497,3 @@ def jacfun(*args, **kwargs): return jac_tree, aux return jacfun - - -def _check_input_dtype_jacrev(x): - dispatch.check_arg(x) - aval = core.get_aval(x) - if ( - dtypes.issubdtype(aval.dtype, dtypes.extended) - or dtypes.issubdtype(aval.dtype, np.integer) - or dtypes.issubdtype(aval.dtype, np.bool_) - ): - raise TypeError( - f"jacrev_chunked requires real- or complex-valued inputs (input dtype " - f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. " - "If you want to use Boolean- or integer-valued inputs, use vjp " - "or set allow_int to True." - ) - elif not dtypes.issubdtype(aval.dtype, np.inexact): - raise TypeError( - f"jacrev_chunked requires numerical-valued inputs (input dtype that is a " - f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}." 
- ) - - -def _check_output_dtype_jacrev(x): - aval = core.get_aval(x) - if dtypes.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"jacrev_chunked with output element type {aval.dtype.name}") - elif dtypes.issubdtype(aval.dtype, np.complexfloating): - raise TypeError( - f"jacrev_chunked requires real-valued outputs (output dtype that is " - f"a sub-dtype of np.floating), but got {aval.dtype.name}. " - "For holomorphic differentiation, pass holomorphic=True. " - "For differentiation of non-holomorphic functions involving complex " - "outputs, use jax.vjp directly." - ) - elif not dtypes.issubdtype(aval.dtype, np.floating): - raise TypeError( - f"jacrev_chunked requires real-valued outputs (output dtype that is " - f"a sub-dtype of np.floating), but got {aval.dtype.name}. " - "For differentiation of functions with integer outputs, use " - "jax.vjp directly." - ) - - -def _check_input_dtype_jacfwd(x: Any) -> None: - dispatch.check_arg(x) - aval = core.get_aval(x) - if dtypes.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"jacfwd with input element type {aval.dtype.name}") - elif not dtypes.issubdtype(aval.dtype, np.floating): - raise TypeError( - "jacfwd requires real-valued inputs (input dtype that is " - f"a sub-dtype of np.floating), but got {aval.dtype.name}. " - "For holomorphic differentiation, pass holomorphic=True. " - "For differentiation of non-holomorphic functions involving " - "complex inputs or integer inputs, use jax.jvp directly." - ) - - -def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr): - return _unravel_array_into_pytree(input_pytree, -1, output_pytree_leaf, arr) - - -def _jacrev_unravel(output_pytree, input_pytree_leaf, arr): - return _unravel_array_into_pytree(output_pytree, 0, input_pytree_leaf, arr) - - -def _possible_downcast(x, example): - if dtypes.issubdtype(x.dtype, np.complexfloating) and not dtypes.issubdtype( - _dtype(example), np.complexfloating - ): - x = x.real - dtype = None if example is None else _dtype(example) - weak_type = None if example is None else dtypes.is_weakly_typed(example) - return lax_internal._convert_element_type(x, dtype, weak_type) - - -def _std_basis(pytree): - leaves, _ = tree_flatten(pytree) - ndim = sum(safe_map(np.size, leaves)) - dtype = dtypes.result_type(*leaves) - flat_basis = jnp.eye(ndim, dtype=dtype) - return _unravel_array_into_pytree(pytree, 1, None, flat_basis) - - -def _unravel_array_into_pytree(pytree, axis, example, arr): - """Unravel an array into a PyTree with a given structure. - - Parameters - ---------- - pytree: The pytree that provides the structure. - axis: The parameter axis is either -1, 0, or 1. It controls the - resulting shapes. - example: If specified, cast the components to the matching dtype/weak_type, - or else use the pytree leaf type if example is None. - arr: The array to be unraveled. 
- """ - leaves, treedef = tree_flatten(pytree) - axis = axis % arr.ndim - shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis + 1 :] for l in leaves] - parts = _split(arr, np.cumsum(safe_map(np.size, leaves[:-1])), axis) - reshaped_parts = [ - _possible_downcast(np.reshape(x, shape), leaf if example is None else example) - for x, shape, leaf in zip(parts, shapes, leaves) - ] - return tree_unflatten(treedef, reshaped_parts) - - -def _split(x, indices, axis): - if isinstance(x, np.ndarray): - return np.split(x, indices, axis) - else: - return x._split(indices, axis) - - -def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): - """Variant of jvp() that takes an lu.WrappedFun.""" - if not isinstance(primals, (tuple, list)) or not isinstance( - tangents, (tuple, list) - ): - raise TypeError( - "primal and tangent arguments to jax.jvp must be tuples or lists; " - f"found {type(primals).__name__} and {type(tangents).__name__}." - ) - - ps_flat, tree_def = tree_flatten(primals) - ts_flat, tree_def_2 = tree_flatten(tangents) - if tree_def != tree_def_2: - raise TypeError( - "primal and tangent arguments to jax.jvp must have the same tree " - f"structure; primals have tree structure {tree_def} whereas tangents have " - f"tree structure {tree_def_2}." - ) - for p, t in zip(ps_flat, ts_flat): - if core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t): - raise TypeError( - "primal and tangent arguments to jax.jvp do not match; " - "dtypes must be equal, or in case of int/bool primal dtype " - "the tangent dtype must be float0." - f"Got primal dtype {_dtype(p)} and so expected tangent dtype " - f"{core.primal_dtype_to_tangent_dtype(_dtype(p))}, but got " - f"tangent dtype {_dtype(t)} instead." - ) - if np.shape(p) != np.shape(t): - raise ValueError( - "jvp called with different primal and tangent shapes;" - f"Got primal shape {np.shape(p)} and tangent shape as {np.shape(t)}" - ) - - if not has_aux: - flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def) - out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat) - out_tree = out_tree() - return ( - tree_unflatten(out_tree, out_primals), - tree_unflatten(out_tree, out_tangents), - ) - else: - flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, tree_def) - jvp_fun, aux = ad.jvp(flat_fun, has_aux=True) - out_primals, out_tangents = jvp_fun.call_wrapped(ps_flat, ts_flat) - out_tree, aux_tree = out_aux_trees() - return ( - tree_unflatten(out_tree, out_primals), - tree_unflatten(out_tree, out_tangents), - tree_unflatten(aux_tree, aux()), - ) - - -def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): - """Variant of vjp() that takes an lu.WrappedFun.""" - primals_flat, in_tree = tree_flatten(primals) - for arg in primals_flat: - dispatch.check_arg(arg) - if not has_aux: - flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) - out_primals, vjp = ad.vjp(flat_fun, primals_flat) - out_tree = out_tree() - else: - flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree) - out_primals, vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True) - out_tree, aux_tree = out_aux_trees() - out_primal_avals = map(shaped_abstractify, out_primals) - out_primal_py = tree_unflatten(out_tree, out_primals) - vjp_py = Partial( - partial( - _vjp_pullback_wrapper, fun.__name__, out_primal_avals, (out_tree, in_tree) - ), - vjp, - ) - if not has_aux: - return out_primal_py, vjp_py - else: - return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) - - -def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): - 
(py_args,) = py_args_ - in_tree_expected, out_tree = io_tree - args, in_tree = tree_flatten(py_args) - if in_tree != in_tree_expected: - raise ValueError( - f"unexpected tree structure of argument to vjp function: " - f"got {in_tree}, but expected to match {in_tree_expected}" - ) - for arg, aval in zip(args, out_primal_avals): - ct_aval = shaped_abstractify(arg) - try: - ct_aval_expected = aval.to_tangent_type() - except AttributeError: - # https://github.com/jax-ml/jax/commit/018189491bde26fe9c7ade1213c5cbbad8bca1c6 - ct_aval_expected = aval.at_least_vspace() - if not core.typecompat( - ct_aval, ct_aval_expected - ) and not _temporary_dtype_exception(ct_aval, ct_aval_expected): - raise ValueError( - "unexpected JAX type (e.g. shape/dtype) for argument to vjp function: " - f"got {ct_aval.str_short()}, but expected " - f"{ct_aval_expected.str_short()} because the corresponding output " - f"of the function {name} had JAX type {aval.str_short()}" - ) - ans = fun(*args) - return tree_unflatten(out_tree, ans) - - -def _temporary_dtype_exception(a, a_) -> bool: - if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray): - return a.shape == a_.shape and a_.dtype == dtypes.float0 - return False
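
A minimal usage sketch of the chunked Jacobian helpers introduced in PATCH 1/6 (and simplified in PATCH 6/6). The toy ``residual`` function and input values below are illustrative placeholders, not part of DESC; only ``jacfwd_chunked``, ``jacrev_chunked``, the ``chunk_size`` keyword, and the ``AutoDiffDerivative`` plumbing come from the patches above.

    import jax.numpy as jnp

    from desc.batching import jacfwd_chunked, jacrev_chunked
    from desc.derivatives import AutoDiffDerivative

    def residual(x):
        # toy vector-valued function standing in for an objective's compute method
        return jnp.array([jnp.sum(x**2), jnp.prod(jnp.sin(x))])

    x = jnp.linspace(0.1, 1.0, 5)

    # forward mode: Jacobian assembled column-by-column, chunk_size columns per vmap batch
    Jf = jacfwd_chunked(residual, argnums=0, chunk_size=2)(x)
    # reverse mode: Jacobian assembled row-by-row, chunk_size rows per vmap batch
    Jr = jacrev_chunked(residual, argnums=0, chunk_size=2)(x)

    # the same chunk_size can be passed through the Derivative wrapper from PATCH 1/6
    Jd = AutoDiffDerivative(residual, argnum=0, mode="fwd", chunk_size=2).compute(x)

    # all three should agree with jax.jacfwd / jax.jacrev; chunking only bounds peak memory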