From 240cc819a02d5a2e151567722e1d60b6f7f6de4a Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Wed, 27 Nov 2024 20:36:43 +0000 Subject: [PATCH] Update NNX NodeDef docs in graph.py --- flax/nnx/graph.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index fec21add..708ece43 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -292,9 +292,9 @@ def __treescope_repr__(self, path, subtree_renderer): @dataclasses.dataclass(frozen=True, repr=False, slots=True) class NodeDef(GraphDef[Node], reprlib.Representable): - """A dataclass that denotes the tree structure of a - :class:`Module`. A ``GraphDef`` can be generated by either - calling :func:`split` or :func:`graphdef` on the :class:`Module`.""" + """A dataclass that denotes the JAX pytree structure of a :class:`flax.nnx.Module`. + A :class:`flax.nnx.GraphDef` can be generated by either calling + :func:`flax.nnx.split` or :func:`flax.nnx.graphdef` on the ``nnx.Module``.""" type: tp.Type[Node] index: int @@ -1852,4 +1852,4 @@ def _unflatten_pytree( type(None), flatten=lambda x: ([], None), unflatten=lambda _, __: None, # type: ignore -) \ No newline at end of file +)