
Merge pull request #4337 from 8bitmp3:fixup-haiku-to-flax
PiperOrigin-RevId: 689936781
Flax Authors committed Oct 25, 2024
2 parents cd5db46 + b83dce7 commit 0f044c8
Showing 1 changed file with 10 additions and 8 deletions: docs_nnx/guides/haiku_to_flax.rst
@@ -1,9 +1,9 @@
Migrating from Haiku to Flax
##########
############################

This guide demonstrates the differences between Haiku and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Haiku.

To get the most out of this guide, it is highly recommended to get go through `Flax NNX basics <https://flax.readthedocs.io/en/latest/nnx_basics.html>`__ document, which covers the :class:`nnx.Module<flax.nnx.Module>` system, `Flax transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__, and the `Functional API <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`__ with examples.
To get the most out of this guide, it is highly recommended to go through `Flax NNX basics <https://flax.readthedocs.io/en/latest/nnx_basics.html>`__ document, which covers the :class:`nnx.Module<flax.nnx.Module>` system, `Flax transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__, and the `Functional API <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`__ with examples.


.. testsetup:: Haiku, Flax NNX
@@ -15,15 +15,15 @@ To get the most out of this guide, it is highly recommended to get go through `F


Basic Module Definition
======================
=======================

Both Haiku and Flax use the ``Module`` class as the default unit for expressing a neural network layer. In the example below, you first create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function; then you use it as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer.

There are two fundamental differences between Haiku and Flax ``Module`` objects:

* **Stateless vs. stateful**: A ``hk.Module`` instance is stateless - the variables are returned from a purely functional ``Module.init()`` call and managed separately. A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object.

* **Lazy vs. eager**: A ``hk.Module`` only allocates space to create variables when they actually see the input when user calls the model (lazy). A ``flax.nnx.Module`` instance creates variables the moment they are instantiated, before seeing a sample input (eager).
* **Lazy vs. eager**: A ``hk.Module`` only allocates space for its variables when it first sees an input, that is, when the user calls the model (lazy). A ``flax.nnx.Module`` instance creates its variables the moment it is instantiated, before seeing a sample input (eager).


.. codediff::
@@ -449,7 +449,7 @@ Let's start with an example:

Next, we will define an ``RNN`` Module that will contain the logic for the entire RNN. In both cases, we use the library's ``scan`` call to run the ``RNNCell`` over the input sequence.

The only difference is that Flax ``nnx.scan`` allows you to specify which axis to repeat over in arguments ``in_axes`` and ``out_axes``, which will be forwarded to the underlying `jax.lax.scan<https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html>`__, wheras in Haiku you need to transpose the input and output explicitly.
The only difference is that Flax ``nnx.scan`` allows you to specify which axis to scan over via the ``in_axes`` and ``out_axes`` arguments, which are forwarded to the underlying `jax.lax.scan <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html>`__, whereas in Haiku you need to transpose the input and output explicitly.
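To illustrate the Haiku-side point in plain JAX (the toy cell and shapes below are illustrative assumptions): ``jax.lax.scan`` always scans over the leading axis, so a batch-major sequence must be transposed by hand before and after the scan.

```python
import jax
import jax.numpy as jnp

def cell(carry, x):
    # A toy recurrent cell: new carry = tanh(carry + input).
    new_carry = jnp.tanh(carry + x)
    return new_carry, new_carry

xs = jnp.ones((2, 5, 3))  # (batch, time, features)
carry0 = jnp.zeros((2, 3))

# jax.lax.scan scans over the leading axis, so transpose to
# (time, batch, features) first, then transpose the outputs back.
_, ys = jax.lax.scan(cell, carry0, jnp.swapaxes(xs, 0, 1))
ys = jnp.swapaxes(ys, 0, 1)  # back to (batch, time, features)
print(ys.shape)  # (2, 5, 3)
```

With ``nnx.scan`` the same effect is achieved by passing the time axis directly through ``in_axes``/``out_axes``.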

.. codediff::
:title: Haiku, Flax NNX
@@ -580,7 +580,7 @@ In Flax, model initialization and calling code are completely decoupled, so we u

There are a few other details to explain in the Flax example above:

* **The `@nnx.split_rngs` decorator:** Flax transforms, like their JAX counterparts, are completely agnostic of PRNG state and relies on input for PRNG keys. The ``nnx.split_rngs`` decorator allows you to split the ``nnx.Rngs`` before passing them to the decorated function and 'lower' them afterwards, so they can be used outside.
* **The `@nnx.split_rngs` decorator:** Flax transforms, like their JAX counterparts, are completely agnostic of the PRNG state and rely on input for PRNG keys. The ``nnx.split_rngs`` decorator allows you to split the ``nnx.Rngs`` before passing them to the decorated function and 'lower' them afterwards, so they can be used outside.

* Here, you split the PRNG keys because ``jax.vmap`` and ``jax.lax.scan`` require a separate PRNG key for each of their internal operations. So for the 5 layers inside the ``MLP``, you split the incoming ``nnx.Rngs`` into 5 different PRNG keys before handing control to the JAX transform.

@@ -640,7 +640,7 @@ Top-level Haiku functions vs top-level Flax modules

In Haiku, it is possible to write the entire model as a single function by using
the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and
states. It very common to write the top-level "Module" as a function instead.
states. It is very common to write the top-level "Module" as a function instead.

The Flax team recommends a more Module-centric approach that uses ``__call__`` to
define the forward function. In Flax modules, the parameters and variables can
@@ -690,4 +690,6 @@ be set and accessed as normal using regular Python class semantics.

model = FooModule(rngs=nnx.Rngs(0))

_, params, counter = nnx.split(model, nnx.Param, Counter)
_, params, counter = nnx.split(model, nnx.Param, Counter)

