Skip to content

Commit

Permalink
Update NNX split docs in graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 15, 2024
1 parent 2d9b14d commit d2cc57a
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,21 +932,21 @@ def split(
def split(
self, node: A, *filters: filterlib.Filter
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
"""Splits a graph node into a :class:`flax.nnx.GraphDef` and one or more :flax:`flax.nnx.Variable`
:class:`flax.nnx.State`'s.
It is a way to represent an :class:`flax.nnx.Module` by: 1) a static ``nnx.GraphDef`` that captures
its Pythonic static information; and 2) one or more ``nnx.Variable`` ``nnx.State``'(s) that capture
its ``jax.Array``'s in the form of JAX pytrees.
A :class:`flax.nnx.State` is a ``Mapping`` from strings or integers to :class:`flax.nnx.Variable`'s,
``jax.Array``'s or nested :class:`flax.nnx.State`'s.
``nnx.GraphDef`` contains all the static information needed to reconstruct a :class:`flax.nnx.Module`
graph, and it is analogous to JAX’s ``PyTreeDef``.
:func:`flax.nnx.split` is used in conjunction with :func:`flax.nnx.merge` to
switch seamlessly between stateful and stateless representations of the graph.
"""Splits a graph node into a :class:`flax.nnx.GraphDef` and one or more
:flax:`flax.nnx.Variable` :class:`flax.nnx.State`'s.
It is a way to represent a :class:`flax.nnx.Module` by: 1) a static ``nnx.GraphDef``
that captures its Pythonic static information; and 2) one or more ``nnx.Variable``
``nnx.State``'(s) that capture its ``jax.Array``'s in the form of JAX pytrees.
An ``nnx.State`` is a ``Mapping`` from strings or integers to ``nnx.Variable``'s,
JAX arrays (``jax.Array``) or nested ``nnx.State``'s.
``nnx.GraphDef`` contains all the static information needed to reconstruct a
``nnx.Module`` graph, and it is analogous to JAX’s ``PyTreeDef``.
``nnx.split`` is used in conjunction with :func:`flax.nnx.merge` to switch seamlessly
between stateful and stateless representations of the graph.
Example usage::
Expand Down

0 comments on commit d2cc57a

Please sign in to comment.