File tree Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -51,15 +51,15 @@ class ModuleMeta(ObjectMeta):
51
51
class Module (Object , metaclass = ModuleMeta ):
52
52
"""Base class for all neural network modules.
53
53
54
- Layers and models should subclass this class.
54
+ Flax NNX layers and models should subclass this :class`flax.nnx.Module` class.
55
55
56
- `` Module``'s can contain submodules, and in this way can be nested in a tree
57
- structure. Submodules can be assigned as regular attributes inside the
58
- ``__init__`` method.
56
+ An ``nnx. Module`` can contain sub-``Module``'s, allowing them to be nested in a
57
+ JAX pytree-like structure. Sub-``Module``'s can be assigned as regular attributes
58
+ inside the ``__init__`` method.
59
59
60
- You can define arbitrary "forward pass" methods on your ``Module`` subclass.
60
+ You can define arbitrary "forward pass" methods on your ``nnx. Module`` subclass.
61
61
While no methods are special-cased, ``__call__`` is a popular choice since
62
- you can call the ``Module`` directly::
62
+ you can call the ``nnx. Module`` directly::
63
63
64
64
>>> from flax import nnx
65
65
>>> import jax.numpy as jnp
You can’t perform that action at this time.
0 commit comments