Skip to content

Commit

Permalink
Update NNX split docs in graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 authored Nov 28, 2024
1 parent 6bc9858 commit 2d9b14d
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,11 +932,21 @@ def split(
def split(
self, node: A, *filters: filterlib.Filter
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
"""Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
to JAX’s ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch
seamlessly between stateful and stateless representations of the graph.
"""Splits a graph node into a :class:`flax.nnx.GraphDef` and one or more :flax:`flax.nnx.Variable`
:class:`flax.nnx.State`'s.
It is a way to represent an :class:`flax.nnx.Module` by: 1) a static ``nnx.GraphDef`` that captures
its Pythonic static information; and 2) one or more ``nnx.Variable`` ``nnx.State``'(s) that capture
its ``jax.Array``'s in the form of JAX pytrees.
A :class:`flax.nnx.State` is a ``Mapping`` from strings or integers to :class:`flax.nnx.Variable`'s,
``jax.Array``'s or nested :class:`flax.nnx.State`'s.
``nnx.GraphDef`` contains all the static information needed to reconstruct a :class:`flax.nnx.Module`
graph, and it is analogous to JAX’s ``PyTreeDef``.
:func:`flax.nnx.split` is used in conjunction with :func:`flax.nnx.merge` to
switch seamlessly between stateful and stateless representations of the graph.
Example usage::
Expand Down Expand Up @@ -989,11 +999,13 @@ def split(
})
Arguments:
node: graph node to split.
*filters: some optional filters to group the state into mutually exclusive substates.
node: A graph node to split.
*filters: Some optional filters (:class:`flax.nnx.filterlib.Filter`) to group the
state into mutually exclusive sub-States (:class:`flax.nnx.State`).
Returns:
:class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no
filters are passed, a single :class:`State` is returned.
A :class:`flax.nnx.GraphDef` and one or more :class:`flax.nnx.State`'s equal to the
number of filters (:class:`flax.nnx.filterlib.Filter`) passed. If no filters are
passed, a single :class:`flax.nnx.State` is returned.
"""
ref_index: RefMap[tp.Any, Index] = RefMap()
graphdef, state = flatten(node, ref_index)
Expand Down Expand Up @@ -1852,4 +1864,4 @@ def _unflatten_pytree(
type(None),
flatten=lambda x: ([], None),
unflatten=lambda _, __: None, # type: ignore
)
)

0 comments on commit 2d9b14d

Please sign in to comment.