Commit b83dce7

Fix typos in Flax NNX Migrating from Haiku to Flax
1 parent dc3017a commit b83dce7

File tree

1 file changed: +10 -8 lines changed


docs_nnx/guides/haiku_to_flax.rst

Lines changed: 10 additions & 8 deletions
@@ -1,9 +1,9 @@
 Migrating from Haiku to Flax
-##########
+############################
 
 This guide demonstrates the differences between Haiku and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Haiku.
 
-To get the most out of this guide, it is highly recommended to get go through `Flax NNX basics <https://flax.readthedocs.io/en/latest/nnx_basics.html>`__ document, which covers the :class:`nnx.Module<flax.nnx.Module>` system, `Flax transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__, and the `Functional API <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`__ with examples.
+To get the most out of this guide, it is highly recommended to go through `Flax NNX basics <https://flax.readthedocs.io/en/latest/nnx_basics.html>`__ document, which covers the :class:`nnx.Module<flax.nnx.Module>` system, `Flax transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__, and the `Functional API <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`__ with examples.
 
 
 .. testsetup:: Haiku, Flax NNX
@@ -15,15 +15,15 @@ To get the most out of this guide, it is highly recommended to get go through `F
 
 
 Basic Module Definition
-======================
+=======================
 
 Both Haiku and Flax use the ``Module`` class as the default unit to express a neural network library layer. In the example below, you first create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function; then you use it as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer.
 
 There are two fundamental differences between Haiku and Flax ``Module`` objects:
 
 * **Stateless vs. stateful**: A ``hk.Module`` instance is stateless - the variables are returned from a purely functional ``Module.init()`` call and managed separately. A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object.
 
-* **Lazy vs. eager**: A ``hk.Module`` only allocates space to create variables when they actually see the input when user calls the model (lazy). A ``flax.nnx.Module`` instance creates variables the moment they are instantiated, before seeing a sample input (eager).
+* **Lazy vs. eager**: A ``hk.Module`` only allocates space to create variables when they actually see the input when the user calls the model (lazy). A ``flax.nnx.Module`` instance creates variables the moment they are instantiated, before seeing a sample input (eager).
 
 
 .. codediff::
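
For context on the "lazy vs. eager" point touched in this hunk, a minimal sketch of the eager behavior (assuming the current flax.nnx API; the snippet is illustrative and not part of this commit):

    import jax.numpy as jnp
    from flax import nnx

    class Block(nnx.Module):
        def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
            # Variables are allocated right here, at construction time (eager).
            self.linear = nnx.Linear(din, dout, rngs=rngs)
            self.dropout = nnx.Dropout(0.5, rngs=rngs)

        def __call__(self, x):
            return nnx.relu(self.dropout(self.linear(x)))

    block = Block(4, 8, rngs=nnx.Rngs(0))
    # The kernel already exists before the module has seen any input.
    print(block.linear.kernel.value.shape)  # (4, 8)
    y = block(jnp.ones((1, 4)))

A Haiku ``hk.Module``, by contrast, would only create these parameters inside ``hk.transform(...).init(key, x)``, once a sample input is available.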
@@ -449,7 +449,7 @@ Let's start with an example:
 
 Next, we will define a ``RNN`` Module that will contain the logic for the entire RNN. In both cases, we use the library's ``scan`` call to run the ``RNNCell`` over the input sequence.
 
-The only difference is that Flax ``nnx.scan`` allows you to specify which axis to repeat over in arguments ``in_axes`` and ``out_axes``, which will be forwarded to the underlying `jax.lax.scan<https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html>`__, wheras in Haiku you need to transpose the input and output explicitly.
+The only difference is that Flax ``nnx.scan`` allows you to specify which axis to repeat over in arguments ``in_axes`` and ``out_axes``, which will be forwarded to the underlying `jax.lax.scan<https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html>`__, whereas in Haiku you need to transpose the input and output explicitly.
 
 .. codediff::
 :title: Haiku, Flax NNX
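
A rough illustration of the ``in_axes``/``out_axes`` behavior described in this hunk (a sketch assuming the ``nnx.scan``/``nnx.Carry`` API from the Flax docs, not code from this commit):

    import jax.numpy as jnp
    from flax import nnx

    def step(carry, x_t):
        # One toy recurrent step: fold the current timestep into the carry.
        carry = carry + x_t
        return carry, carry

    x = jnp.ones((4, 10, 3))   # (batch, time, features)
    carry0 = jnp.zeros((4, 3))

    # Scan directly over the time axis (axis 1); no explicit transpose of the
    # batch-major input is needed, unlike a raw jax.lax.scan over axis 0.
    scan_step = nnx.scan(step, in_axes=(nnx.Carry, 1), out_axes=(nnx.Carry, 1))
    carry, ys = scan_step(carry0, x)
    print(ys.shape)            # (4, 10, 3)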
@@ -580,7 +580,7 @@ In Flax, model initialization and calling code are completely decoupled, so we u
 
 There are a few other details to explain in the Flax example above:
 
-* **The `@nnx.split_rngs` decorator:** Flax transforms, like their JAX counterparts, are completely agnostic of PRNG state and relies on input for PRNG keys. The ``nnx.split_rngs`` decorator allows you to split the ``nnx.Rngs`` before passing them to the decorated function and 'lower' them afterwards, so they can be used outside.
+* **The `@nnx.split_rngs` decorator:** Flax transforms, like their JAX counterparts, are completely agnostic of the PRNG state and rely on input for PRNG keys. The ``nnx.split_rngs`` decorator allows you to split the ``nnx.Rngs`` before passing them to the decorated function and 'lower' them afterwards, so they can be used outside.
 
 * Here, you split the PRNG keys because ``jax.vmap`` and ``jax.lax.scan`` require a list of PRNG keys if each of its internal operations needs its own key. So for the 5 layers inside the ``MLP``, you split and provide 5 different PRNG keys from its arguments before going down to the JAX transform.
 
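
A minimal sketch of the ``nnx.split_rngs`` pattern referenced in this hunk (the 5-way split and the ``create_layers`` helper are assumptions for illustration, not code from this commit):

    from flax import nnx

    @nnx.split_rngs(splits=5)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_layers(rngs: nnx.Rngs):
        # Each of the 5 vmapped instances receives its own PRNG key,
        # because split_rngs forks the Rngs before vmap sees them.
        return nnx.Linear(8, 8, rngs=rngs)

    layers = create_layers(nnx.Rngs(0))  # a Linear whose params carry a leading axis of 5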
@@ -640,7 +640,7 @@ Top-level Haiku functions vs top-level Flax modules
 
 In Haiku, it is possible to write the entire model as a single function by using
 the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and
-states. It very common to write the top-level "Module" as a function instead.
+states. It is very common to write the top-level "Module" as a function instead.
 
 The Flax team recommends a more Module-centric approach that uses ``__call__`` to
 define the forward function. In Flax modules, the parameters and variables can
@@ -690,4 +690,6 @@ be set and accessed as normal using regular Python class semantics.
 
 model = FooModule(rngs=nnx.Rngs(0))
 
-_, params, counter = nnx.split(model, nnx.Param, Counter)
+_, params, counter = nnx.split(model, nnx.Param, Counter)
+
+
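
For readers skimming only the diff, a self-contained sketch of what the touched line does; the ``Counter`` variable type and the ``FooModule`` body below are stand-ins reconstructed for illustration, not the guide's exact code:

    import jax.numpy as jnp
    from flax import nnx

    class Counter(nnx.Variable):
        """A custom variable collection, playing the role of Haiku state."""

    class FooModule(nnx.Module):
        def __init__(self, *, rngs: nnx.Rngs):
            self.linear = nnx.Linear(4, 4, rngs=rngs)
            self.count = Counter(jnp.array(0))

        def __call__(self, x):
            # Parameters and state are plain attributes, updated with normal
            # Python semantics rather than hk.get_state/hk.set_state.
            self.count.value += 1
            return self.linear(x)

    model = FooModule(rngs=nnx.Rngs(0))
    # nnx.split groups the state by variable type: nnx.Param in one bundle,
    # Counter in another, alongside the static graph definition.
    _, params, counter = nnx.split(model, nnx.Param, Counter)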