Skip to content

Commit

Permalink
Update NNX NodeDef docs in graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 16, 2024
1 parent 6bc9858 commit 240cc81
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,9 @@ def __treescope_repr__(self, path, subtree_renderer):

@dataclasses.dataclass(frozen=True, repr=False, slots=True)
class NodeDef(GraphDef[Node], reprlib.Representable):
"""A dataclass that denotes the tree structure of a
:class:`Module`. A ``GraphDef`` can be generated by either
calling :func:`split` or :func:`graphdef` on the :class:`Module`."""
"""A dataclass that denotes the JAX pytree structure of a :class:`flax.nnx.Module`.
A :class:`flax.nnx.GraphDef` can be generated by either calling
:func:`flax.nnx.split` or :func:`flax.nnx.graphdef` on the ``nnx.Module``."""

type: tp.Type[Node]
index: int
Expand Down Expand Up @@ -1852,4 +1852,4 @@ def _unflatten_pytree(
type(None),
flatten=lambda x: ([], None),
unflatten=lambda _, __: None, # type: ignore
)
)

0 comments on commit 240cc81

Please sign in to comment.