diff --git a/docs_nnx/guides/haiku_to_flax.rst b/docs_nnx/guides/haiku_to_flax.rst
index 3e2a6f2915..4349b53d0b 100644
--- a/docs_nnx/guides/haiku_to_flax.rst
+++ b/docs_nnx/guides/haiku_to_flax.rst
@@ -1,9 +1,9 @@
 Migrating from Haiku to Flax
-##########
+############################

 This guide demonstrates the differences between Haiku and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Haiku.

-To get the most out of this guide, it is highly recommended to get go through `Flax NNX basics `__ document, which covers the :class:`nnx.Module` system, `Flax transformations `__, and the `Functional API `__ with examples.
+To get the most out of this guide, it is highly recommended to go through the `Flax NNX basics `__ document, which covers the :class:`nnx.Module` system, `Flax transformations `__, and the `Functional API `__ with examples.

 .. testsetup:: Haiku, Flax NNX

@@ -15,7 +15,7 @@ To get the most out of this guide, it is highly recommended to get go through `F
 Basic Module Definition
-======================
+=======================

 Both Haiku and Flax use the ``Module`` class as the default unit to express a neural network library layer. In the example below, you first create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function; then you use it as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer.

@@ -23,7 +23,7 @@ There are two fundamental differences between Haiku and Flax ``Module`` objects:

 * **Stateless vs. stateful**: A ``hk.Module`` instance is stateless - the variables are returned from a purely functional ``Module.init()`` call and managed separately. A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object.

-* **Lazy vs. eager**: A ``hk.Module`` only allocates space to create variables when they actually see the input when user calls the model (lazy). A ``flax.nnx.Module`` instance creates variables the moment they are instantiated, before seeing a sample input (eager).
+* **Lazy vs. eager**: A ``hk.Module`` only allocates space for its variables when it actually sees the input, i.e. when the user calls the model (lazy). A ``flax.nnx.Module`` instance creates variables the moment it is instantiated, before seeing a sample input (eager).

 .. codediff::

@@ -449,7 +449,7 @@ Let's start with an example:
 Next, we will define a ``RNN`` Module that will contain the logic for the entire RNN.
 In both cases, we use the library's ``scan`` call to run the ``RNNCell`` over the input sequence.

-The only difference is that Flax ``nnx.scan`` allows you to specify which axis to repeat over in arguments ``in_axes`` and ``out_axes``, which will be forwarded to the underlying `jax.lax.scan`__, wheras in Haiku you need to transpose the input and output explicitly.
+The only difference is that Flax ``nnx.scan`` allows you to specify which axis to repeat over via the arguments ``in_axes`` and ``out_axes``, which will be forwarded to the underlying `jax.lax.scan`__, whereas in Haiku you need to transpose the input and output explicitly.

 .. codediff::
    :title: Haiku, Flax NNX

@@ -580,7 +580,7 @@ In Flax, model initialization and calling code are completely decoupled, so we u
 There are a few other details to explain in the Flax example above:

-* **The `@nnx.split_rngs` decorator:** Flax transforms, like their JAX counterparts, are completely agnostic of PRNG state and relies on input for PRNG keys. The ``nnx.split_rngs`` decorator allows you to split the ``nnx.Rngs`` before passing them to the decorated function and 'lower' them afterwards, so they can be used outside.
+* **The `@nnx.split_rngs` decorator:** Flax transforms, like their JAX counterparts, are completely agnostic of PRNG state and rely on their inputs for PRNG keys. The ``nnx.split_rngs`` decorator allows you to split the ``nnx.Rngs`` before passing them to the decorated function and 'lower' them afterwards, so they can be used outside.

 * Here, you split the PRNG keys because ``jax.vmap`` and ``jax.lax.scan`` require a list of PRNG keys if each of its internal operations needs its own key.
   So for the 5 layers inside the ``MLP``, you split and provide 5 different PRNG keys from its arguments before going down to the JAX transform.

@@ -640,7 +640,7 @@ Top-level Haiku functions vs top-level Flax modules
 In Haiku, it is possible to write the entire model as a single function by using
 the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and
-states. It very common to write the top-level "Module" as a function instead.
+states. It is very common to write the top-level "Module" as a function instead.

 The Flax team recommends a more Module-centric approach that uses ``__call__`` to
 define the forward function. In Flax modules, the parameters and variables can

@@ -690,4 +690,6 @@ be set and accessed as normal using regular Python class semantics.
     model = FooModule(rngs=nnx.Rngs(0))

-    _, params, counter = nnx.split(model, nnx.Param, Counter)
\ No newline at end of file
+    _, params, counter = nnx.split(model, nnx.Param, Counter)
+
+
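The stateless/stateful and lazy/eager contrasts described in the first hunks of the patch can be sketched in plain Python. This is a conceptual sketch with made-up classes (``LazyLinear``, ``EagerLinear``), not the real Haiku or Flax NNX APIs, and the weights are zero-initialized lists rather than real arrays:

```python
class LazyLinear:
    """Haiku-style: no parameters exist until the first call sees an input."""

    def __init__(self, dout):
        self.dout = dout
        self.w = None  # nothing is allocated at construction time

    def __call__(self, x):
        if self.w is None:
            # Input shape is only known here, so allocation happens lazily.
            self.w = [[0.0] * self.dout for _ in range(len(x))]
        # Toy matrix-vector product: one output per column of w.
        return [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*self.w)]


class EagerLinear:
    """Flax-NNX-style: parameters are created the moment the object is built."""

    def __init__(self, din, dout):
        # Both shapes are given up front, so allocation happens eagerly,
        # and the object owns its variables as plain attributes.
        self.w = [[0.0] * dout for _ in range(din)]

    def __call__(self, x):
        return [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*self.w)]


lazy = LazyLinear(4)
assert lazy.w is None        # lazy: still unallocated before the first call
lazy([1.0, 2.0, 3.0])
assert len(lazy.w) == 3      # shape was inferred from the sample input

eager = EagerLinear(3, 4)
assert len(eager.w) == 3     # eager: allocated before any input is seen
```

The same contrast drives the stateless/stateful difference: the Haiku pattern returns parameters from a functional ``init`` and threads them through ``apply``, whereas the NNX pattern keeps them on the object, as ``EagerLinear`` does here.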
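The PRNG-splitting bullet in the patch (one key per layer before entering ``jax.vmap``/``jax.lax.scan``) can also be illustrated without JAX installed. ``split_key`` below is a hypothetical toy stand-in for key splitting, not ``jax.random.split`` or ``nnx.split_rngs``; it only shows the shape of the idea, namely deriving N independent child keys from one parent so each repeated layer gets its own:

```python
def split_key(seed: int, n: int) -> list[int]:
    # Toy derivation of n distinct child seeds from one parent seed.
    # (The real jax.random.split uses a counter-based PRNG, not arithmetic.)
    return [seed * 1_000_003 + i for i in range(n)]


# One key per layer of a 5-layer MLP, produced *before* handing control
# to the transform, mirroring what @nnx.split_rngs does for nnx.Rngs.
layer_keys = split_key(0, 5)
assert len(layer_keys) == 5
assert len(set(layer_keys)) == 5  # every layer gets a distinct key
```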