diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 6d440dbec6..515358269c 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -932,21 +932,21 @@ def split( def split( self, node: A, *filters: filterlib.Filter ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: - """Splits a graph node into a :class:`flax.nnx.GraphDef` and one or more :flax:`flax.nnx.Variable` - :class:`flax.nnx.State`'s. - - It is a way to represent an :class:`flax.nnx.Module` by: 1) a static ``nnx.GraphDef`` that captures - its Pythonic static information; and 2) one or more ``nnx.Variable`` ``nnx.State``'(s) that capture - its ``jax.Array``'s in the form of JAX pytrees. - - A :class:`flax.nnx.State` is a ``Mapping`` from strings or integers to :class:`flax.nnx.Variable`'s, - ``jax.Array``'s or nested :class:`flax.nnx.State`'s. - - ``nnx.GraphDef`` contains all the static information needed to reconstruct a :class:`flax.nnx.Module` - graph, and it is analogous to JAX’s ``PyTreeDef``. - - :func:`flax.nnx.split` is used in conjunction with :func:`flax.nnx.merge` to - switch seamlessly between stateful and stateless representations of the graph. + """Splits a graph node into a :class:`flax.nnx.GraphDef` and one or more + :flax:`flax.nnx.Variable` :class:`flax.nnx.State`'s. + + It is a way to represent a :class:`flax.nnx.Module` by: 1) a static ``nnx.GraphDef`` + that captures its Pythonic static information; and 2) one or more ``nnx.Variable`` + ``nnx.State``'(s) that capture its ``jax.Array``'s in the form of JAX pytrees. + + An ``nnx.State`` is a ``Mapping`` from strings or integers to ``nnx.Variable``'s, + JAX arrays (``jax.Array``) or nested ``nnx.State``'s. + + ``nnx.GraphDef`` contains all the static information needed to reconstruct a + ``nnx.Module`` graph, and it is analogous to JAX’s ``PyTreeDef``. + + ``nnx.split`` is used in conjunction with :func:`flax.nnx.merge` to switch seamlessly + between stateful and stateless representations of the graph. Example usage::