diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py
index b86823c527..97d3ead2d3 100644
--- a/flax/nnx/transforms/autodiff.py
+++ b/flax/nnx/transforms/autodiff.py
@@ -233,14 +233,18 @@ def grad(
   tp.Callable[..., tp.Any]
   | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
 ):
-  """Lifted version of ``jax.grad`` that can handle Modules / graph nodes as
-  arguments.
+  """A reference-aware version of `jax.grad `_
+  that can handle :class:`flax.nnx.Module`'s / graph nodes as arguments.
 
-  The differentiable state of each graph node is defined by the `wrt` filter,
-  which by default is set to `nnx.Param`. Internally the ``State`` of
-  graph nodes is extracted, filtered according to `wrt` filter, and
-  passed to the underlying ``jax.grad`` function. The gradients
-  of graph nodes are of type ``State``.
+  Creates a function that evaluates the gradient of a function ``f``.
+
+  The differentiable state of each graph node is defined by the ``wrt`` filter,
+  which by default is set to :class:`flax.nnx.Param`.
+  Internally, the :class:`flax.nnx.State` of graph nodes is extracted, filtered
+  according to the ``wrt`` filter, and passed to the underlying ``jax.grad`` function.
+  The gradients of graph nodes are of type ``nnx.State``.
+
+  Learn more in `Flax NNX vs JAX Transformations `_.
 
   Example::
 
@@ -255,6 +259,7 @@ def grad(
     >>> grad_fn = nnx.grad(loss_fn)
     ...
     >>> grads = grad_fn(m, x, y)
+    ...
     >>> jax.tree.map(jnp.shape, grads)
     State({
       'bias': VariableState(
@@ -268,28 +273,28 @@ def grad(
     })
 
   Args:
-    fun: Function to be differentiated. Its arguments at positions specified by
+    f: A function to be differentiated. Its arguments at positions specified by
       ``argnums`` should be arrays, scalars, graph nodes or standard Python
       containers. Argument arrays in the positions specified by ``argnums`` must
-      be of inexact (i.e., floating-point or complex) type. It should return a
+      be of inexact (i.e., floating-point or complex) type. Function ``f`` should return a
       scalar (which includes arrays with shape ``()`` but not arrays with shape
-      ``(1,)`` etc.)
+      ``(1,)`` etc).
     argnums: Optional, integer or sequence of integers. Specifies which
-      positional argument(s) to differentiate with respect to (default 0).
-    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
+      positional argument(s) to differentiate with respect to (default ``0``).
+    has_aux: Optional, bool. Indicates whether function ``f`` returns a pair where the
       first element is considered the output of the mathematical function to be
-      differentiated and the second element is auxiliary data. Default False.
-    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
-      holomorphic. If True, inputs and outputs must be complex. Default False.
+      differentiated and the second element is auxiliary data. Default ``False``.
+    holomorphic: Optional, bool. Indicates whether function ``f`` is promised to be
+      holomorphic. If ``True``, inputs and outputs must be complex. Default ``False``.
     allow_int: Optional, bool. Whether to allow differentiating with respect
       to integer valued inputs. The gradient of an integer input will
-      have a trivial vector-space dtype (float0). Default False.
+      have a trivial vector-space dtype (``float0``). Default ``False``.
     reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
-      ``fun`` implicitly broadcasts a value over that axis, the backward pass
+      function ``f`` implicitly broadcasts a value over that axis, the backward pass
       will perform a ``psum`` of the corresponding gradient. Otherwise, the
       gradient will be per-example over named axes. For example, if ``'batch'``
-      is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
-      function that computes the total gradient while ``grad(f)`` will create
+      is a named batch axis, ``flax.nnx.grad(f, reduce_axes=('batch',))`` will create a
+      function that computes the total gradient, while ``flax.nnx.grad(f)`` will create
       one that computes the per-example gradient.
   """