
Commit

Merge branch 'main' into upgrade-haiku-linen-to-nnx
8bitmp3 authored Oct 28, 2024
2 parents 4ed2ba8 + ab122af commit 84fa316
Showing 7 changed files with 91 additions and 32 deletions.
8 changes: 8 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/dtypes.rst
@@ -0,0 +1,8 @@
Dtypes
------------------------

.. automodule:: flax.nnx.nn.dtypes
.. currentmodule:: flax.nnx.nn.dtypes

.. autofunction:: canonicalize_dtype
.. autofunction:: promote_dtype
3 changes: 3 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/index.rst
@@ -9,8 +9,11 @@ See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for

activations
attention
dtypes
initializers
linear
lora
normalization
recurrent
stochastic

15 changes: 15 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/lora.rst
@@ -0,0 +1,15 @@
LoRA
------------------------

NNX LoRA classes.

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. flax_module::
:module: flax.nnx
:class: LoRA

.. flax_module::
:module: flax.nnx
:class: LoRALinear
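
For orientation only (this example is not part of the commit), a minimal sketch of the ``nnx.LoRA`` adapter, assuming the constructor takes ``(in_features, lora_rank, out_features)`` plus ``rngs`` as in the NNX docstrings; treat the exact argument order and output shape as assumptions:

```python
from flax import nnx
import jax.numpy as jnp

rngs = nnx.Rngs(0)
# Rank-2 low-rank adapter mapping 3 input features to 4 output features.
layer = nnx.LoRA(3, 2, 4, rngs=rngs)
y = layer(jnp.ones((1, 3)))
print(y.shape)  # expected: (1, 4)
```

``nnx.LoRALinear`` is assumed to behave like ``nnx.Linear`` with the low-rank adapter added to the base layer's output.
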
32 changes: 32 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/recurrent.rst
@@ -0,0 +1,32 @@
Recurrent
------------------------

.. automodule:: flax.nnx.nn.recurrent
.. currentmodule:: flax.nnx.nn.recurrent

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: LSTMCell

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: OptimizedLSTMCell

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: SimpleCell

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: GRUCell

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: RNN

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: Bidirectional


.. autofunction:: flip_sequences
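
As a quick illustration (not part of this commit), a hedged sketch of wrapping a recurrent cell in ``nnx.RNN`` to scan it over the time axis; the cell constructor mirrors the ``GRUCell`` usage shown in the ``Bidirectional`` docstring later in this diff, and the output shape is an assumption:

```python
from flax import nnx
import jax.numpy as jnp

cell = nnx.GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))
rnn = nnx.RNN(cell)

x = jnp.ones((2, 5, 3))  # (batch, time, features)
y = rnn(x)               # hidden state at every time step
print(y.shape)           # expected: (2, 5, 4)
```
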
6 changes: 3 additions & 3 deletions docs_nnx/guides/haiku_to_flax.rst
@@ -410,7 +410,7 @@ To call those custom methods:


Transformations
===============
=======================

Both Haiku and `Flax transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__ provide their own set of transforms that wrap `JAX transforms <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__ so that they can be used with ``Module`` objects.
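
As a minimal illustration (not part of this diff), an NNX transform applied to a function that takes a ``Module`` argument; ``nnx.jit`` and ``nnx.Linear`` are used here as representative stand-ins for the broader set of wrapped transforms:

```python
from flax import nnx
import jax.numpy as jnp

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

@nnx.jit  # handles Module state, unlike applying jax.jit to the Module directly
def forward(model, x):
    return model(x)

print(forward(model, jnp.ones((1, 2))).shape)  # (1, 3)
```
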

@@ -497,7 +497,7 @@ The only difference is that Flax ``nnx.scan`` allows you to specify which axis t


Scan over layers
================
=======================

Most Haiku transforms should look similar to their Flax counterparts, since they all wrap the underlying JAX transforms, but the scan-over-layers use case is an exception.
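
For context (not part of this diff), a hedged sketch of the NNX scan-over-layers pattern: a stack of identical blocks is created with ``nnx.vmap`` plus ``nnx.split_rngs``, then iterated with ``nnx.scan``. The exact ``in_axes``/``out_axes`` and ``nnx.Carry`` usage below are assumptions about the NNX transform APIs, not a verbatim excerpt from the guide:

```python
from flax import nnx
import jax.numpy as jnp

class Block(nnx.Module):
    def __init__(self, dim, rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return nnx.relu(self.linear(x))

nlayers, dim = 4, 3

# Build nlayers blocks whose parameters are stacked along a leading layer axis.
@nnx.split_rngs(splits=nlayers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs):
    return Block(dim, rngs)

blocks = create_block(nnx.Rngs(0))

# Thread the activation through the stacked layers, one layer per scan step.
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(block, x):
    return block(x)

y = forward(blocks, jnp.ones((1, dim)))
print(y.shape)  # (1, 3)
```
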

@@ -645,7 +645,7 @@ Now inspect the variable pytree on both sides:
Top-level Haiku functions vs top-level Flax modules
================
=======================

In Haiku, it is possible to write the entire model as a single function by using
the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and
6 changes: 3 additions & 3 deletions docs_nnx/index.rst
@@ -17,9 +17,9 @@ Python objects. Flax NNX is an evolution of the previous `Flax Linen <https://fl
API, refined over years of experience into a simpler and more user-friendly design.

.. note::
Flax Linen API is not going to be deprecated in the near future as most of Flax users still
rely on this API. However, new users are encouraged to use Flax NNX.
For existing Flax Linen users planning to move to Flax NNX, check out the `evolution guide <guides/linen_to_nnx.html>`_ and `Why Flax NNX <why.html>`_.
The Flax Linen API is not going to be deprecated in the near future, as most Flax users still rely on it. However, new users are encouraged to use Flax NNX. Check out `Why Flax NNX <why.html>`_ for a comparison between Flax NNX and Linen, and for our reasoning behind the new API.

To move your Flax Linen codebase to Flax NNX, get familiar with the API in `NNX Basics <https://flax.readthedocs.io/en/latest/nnx_basics.html>`_ and then start your migration by following the `evolution guide <guides/linen_to_nnx.html>`_.

Features
^^^^^^^^^
53 changes: 27 additions & 26 deletions flax/nnx/nn/recurrent.py
@@ -742,16 +742,17 @@ def flip_sequences(
values for those sequences that were padded. This function keeps the padding
at the end, while flipping the rest of the elements.
Example:
```python
inputs = [[1, 0, 0],
[2, 3, 0]
[4, 5, 6]]
lengths = [1, 2, 3]
flip_sequences(inputs, lengths) = [[1, 0, 0],
[3, 2, 0],
[6, 5, 4]]
```
Example::
>>> from flax.nnx.nn.recurrent import flip_sequences
>>> from jax import numpy as jnp
>>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
>>> lengths = jnp.array([1, 2, 3])
>>> flip_sequences(inputs, lengths, 1, False)
Array([[1, 0, 0],
[3, 2, 0],
[6, 5, 4]], dtype=int32)
Args:
inputs: An array of input IDs <int>[batch_size, seq_length].
@@ -810,27 +811,27 @@ def __call__(
class Bidirectional(Module):
"""Processes the input in both directions and merges the results.
Example usage:
Example usage::
>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
```python
import nnx
import jax
import jax.numpy as jnp
>>> # Define forward and backward RNNs
>>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
# Define forward and backward RNNs
forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> # Create Bidirectional layer
>>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)
# Create Bidirectional layer
layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)
>>> # Input data
>>> x = jnp.ones((2, 3, 3))
# Input data
x = jnp.ones((2, 3, 3))
>>> # Apply the layer
>>> out = layer(x)
>>> print(out.shape)
(2, 3, 8)
# Apply the layer
out = layer(x)
print(out.shape)
```
"""

forward_rnn: RNNBase
