Skip to content

Commit cccc806

Browse files
committed
Update NNX Module class docs in module.py
1 parent 6bc9858 commit cccc806

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

flax/nnx/module.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ class ModuleMeta(ObjectMeta):
5151
class Module(Object, metaclass=ModuleMeta):
5252
"""Base class for all neural network modules.
5353
54-
Layers and models should subclass this class.
54+
Flax NNX layers and models should subclass this :class`flax.nnx.Module` class.
5555
56-
``Module``'s can contain submodules, and in this way can be nested in a tree
57-
structure. Submodules can be assigned as regular attributes inside the
58-
``__init__`` method.
56+
An ``nnx.Module`` can contain sub-``Module``'s, allowing them to be nested in a
57+
JAX pytree-like structure. Sub-``Module``'s can be assigned as regular attributes
58+
inside the ``__init__`` method.
5959
60-
You can define arbitrary "forward pass" methods on your ``Module`` subclass.
60+
You can define arbitrary "forward pass" methods on your ``nnx.Module`` subclass.
6161
While no methods are special-cased, ``__call__`` is a popular choice since
62-
you can call the ``Module`` directly::
62+
you can call the ``nnx.Module`` directly::
6363
6464
>>> from flax import nnx
6565
>>> import jax.numpy as jnp

0 commit comments

Comments
 (0)