diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index fec21add2..11bf14584 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -1297,10 +1297,18 @@ def merge( /, *states: tp.Mapping[KeyT, tp.Any], ) -> A: - """The inverse of :func:`split`. + """The inverse of :func:`flax.nnx.split`. - ``merge`` takes a :class:`GraphDef` and one or more :class:`State`'s and creates - a new node with the same structure as the original node. + ``nnx.merge`` takes a :class:`flax.nnx.GraphDef` and one or more :class:`flax.nnx.State`'s + and creates a new node with the same structure as the original node. + + Recall: :func:`flax.nnx.split` is used to represent a :class:`flax.nnx.Module` + by: 1) a static ``nnx.GraphDef`` that captures its Pythonic static information; + and 2) one or more :class:`flax.nnx.Variable` ``nnx.State``'(s) that capture + its ``jax.Array``'s in the form of JAX pytrees. + + ``nnx.merge`` is used in conjunction with ``nnx.split`` to switch seamlessly + between stateful and stateless representations of the graph. Example usage:: @@ -1320,17 +1328,17 @@ def merge( >>> assert isinstance(new_node.batch_norm, nnx.BatchNorm) >>> assert isinstance(new_node.linear, nnx.Linear) - :func:`split` and :func:`merge` are primarily used to interact directly with JAX - transformations, see - `Functional API `__ + ``nnx.split`` and ``nnx.merge`` are primarily used to interact directly with JAX + transformations (refer to + `Functional API `__ for more information. Args: - graphdef: A :class:`GraphDef` object. - state: A :class:`State` object. - *states: Additional :class:`State` objects. + graphdef: A :class:`flax.nnx.GraphDef` object. + state: A :class:`flax.nnx.State` object. + *states: Additional :class:`flax.nnx.State` objects. Returns: - The merged :class:`Module`. + The merged :class:`flax.nnx.Module`. """ state = State.merge(state, *states) node = unflatten(graphdef, state) @@ -1852,4 +1860,4 @@ def _unflatten_pytree( type(None), flatten=lambda x: ([], None), unflatten=lambda _, __: None, # type: ignore -) \ No newline at end of file +)