
Commit

Update flax.nnx.grad docstring in autodiff.py
8bitmp3 committed Dec 16, 2024
1 parent d31f290 commit 7e7946a
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions flax/nnx/transforms/autodiff.py
@@ -233,14 +233,18 @@ def grad(
tp.Callable[..., tp.Any]
| tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
"""Lifted version of ``jax.grad`` that can handle Modules / graph nodes as
arguments.
"""A reference-aware version of `jax.grad <https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad>`_
that can handle :class:`flax.nnx.Module`'s / graph nodes as arguments.
The differentiable state of each graph node is defined by the `wrt` filter,
which by default is set to `nnx.Param`. Internally the ``State`` of
graph nodes is extracted, filtered according to `wrt` filter, and
passed to the underlying ``jax.grad`` function. The gradients
of graph nodes are of type ``State``.
Creates a function that evaluates the gradient of a function ``f``.
The differentiable state of each graph node is defined by the ``wrt`` filter,
which by default is set to :class:`flax.nnx.Param`.
Internally, the :class:`flax.nnx.State` of graph nodes is extracted, filtered
according to the ``wrt`` filter, and passed to the underlying ``jax.grad`` function.
The gradients of graph nodes are of type ``nnx.State``.
Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
Example::
@@ -255,6 +259,7 @@ def grad(
>>> grad_fn = nnx.grad(loss_fn)
...
>>> grads = grad_fn(m, x, y)
...
>>> jax.tree.map(jnp.shape, grads)
State({
'bias': VariableState(
@@ -268,28 +273,28 @@ def grad(
})
Args:
- fun: Function to be differentiated. Its arguments at positions specified by
+ f: A function to be differentiated. Its arguments at positions specified by
``argnums`` should be arrays, scalars, graph nodes or standard Python
containers. Argument arrays in the positions specified by ``argnums`` must
- be of inexact (i.e., floating-point or complex) type. It should return a
+ be of inexact (i.e., floating-point or complex) type. Function ``f`` should return a
scalar (which includes arrays with shape ``()`` but not arrays with shape
- ``(1,)`` etc.)
+ ``(1,)`` etc).
argnums: Optional, integer or sequence of integers. Specifies which
- positional argument(s) to differentiate with respect to (default 0).
- has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
+ positional argument(s) to differentiate with respect to (default ``0``).
+ has_aux: Optional, bool. Indicates whether function ``f`` returns a pair where the
first element is considered the output of the mathematical function to be
- differentiated and the second element is auxiliary data. Default False.
- holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
- holomorphic. If True, inputs and outputs must be complex. Default False.
+ differentiated and the second element is auxiliary data. Default ``False``.
+ holomorphic: Optional, bool. Indicates whether function ``f`` is promised to be
+ holomorphic. If ``True``, inputs and outputs must be complex. Default ``False``.
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
- have a trivial vector-space dtype (float0). Default False.
+ have a trivial vector-space dtype (``float0``). Default ``False``.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
- ``fun`` implicitly broadcasts a value over that axis, the backward pass
+ function ``f`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
gradient will be per-example over named axes. For example, if ``'batch'``
- is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
- function that computes the total gradient while ``grad(f)`` will create
+ is a named batch axis, ``flax.nnx.grad(f, reduce_axes=('batch',))`` will create a
+ function that computes the total gradient, while ``flax.nnx.grad(f)`` will create
one that computes the per-example gradient.
"""

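For context, the docstring example in this diff boils down to the usage pattern below. This is a minimal sketch, not part of the commit: it assumes a recent flax release where nnx.Linear, nnx.Rngs, and nnx.grad are available, and it mirrors the gradient shapes shown in the docstring output.

  # Minimal usage sketch of flax.nnx.grad (assumption: recent flax release;
  # not part of this commit).
  import jax
  import jax.numpy as jnp
  from flax import nnx

  m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))  # graph node carrying nnx.Param state
  x = jnp.ones((1, 2))
  y = jnp.ones((1, 3))

  def loss_fn(m, x, y):
    # Must return a scalar, as the Args section above requires.
    return jnp.mean((m(x) - y) ** 2)

  # By default, differentiates w.r.t. argument 0 (the module) and the
  # nnx.Param state inside it (the default ``wrt`` filter).
  grad_fn = nnx.grad(loss_fn)
  grads = grad_fn(m, x, y)  # gradients are returned as an nnx.State

  # Same inspection as in the docstring example: 'bias' has shape (3,)
  # and 'kernel' has shape (2, 3).
  print(jax.tree.map(jnp.shape, grads))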

