Commit
Merge pull request #4338 from IvyZX:docfix
PiperOrigin-RevId: 690684746
Flax Authors committed Oct 28, 2024
2 parents 0f044c8 + e141504 commit ab122af
Showing 7 changed files with 93 additions and 34 deletions.
8 changes: 8 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/dtypes.rst
@@ -0,0 +1,8 @@
Dtypes
------------------------

.. automodule:: flax.nnx.nn.dtypes
.. currentmodule:: flax.nnx.nn.dtypes

.. autofunction:: canonicalize_dtype
.. autofunction:: promote_dtype
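
For context, a minimal usage sketch of the two helpers documented above (hedged: the NNX variant of ``promote_dtype`` is assumed here to take its arrays as a single tuple, unlike the Linen version; verify the exact signatures against the rendered API page):

```python
from flax.nnx.nn import dtypes
import jax.numpy as jnp

x = jnp.ones((4,), dtype=jnp.bfloat16)
kernel = jnp.ones((4,), dtype=jnp.float32)

# Infer the common promotion dtype of the inputs (float32 here).
common = dtypes.canonicalize_dtype(x, kernel)

# Cast the arrays to that dtype (or to an explicit `dtype=` argument).
# Assumption: the NNX helper takes the arrays as one tuple argument.
x, kernel = dtypes.promote_dtype((x, kernel), dtype=common)
```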
3 changes: 3 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/index.rst
@@ -9,8 +9,11 @@ See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for

activations
attention
dtypes
initializers
linear
lora
normalization
recurrent
stochastic

15 changes: 15 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/lora.rst
@@ -0,0 +1,15 @@
LoRA
------------------------

NNX LoRA classes.

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. flax_module::
:module: flax.nnx
:class: LoRA

.. flax_module::
:module: flax.nnx
:class: LoRALinear
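
A brief, hedged usage sketch for the classes listed on this new page (constructor arguments follow the current ``nnx.LoRA``/``nnx.LoRALinear`` docstrings; treat the exact signatures as assumptions to verify against the rendered docs):

```python
from flax import nnx
import jax.numpy as jnp

# Standalone low-rank adapter: routes 3 input features through a rank-2
# bottleneck (lora_a: 3x2, lora_b: 2x4) to produce 4 output features.
lora = nnx.LoRA(in_features=3, lora_rank=2, out_features=4, rngs=nnx.Rngs(0))
y = lora(jnp.ones((16, 3)))
assert y.shape == (16, 4)

# LoRALinear behaves like nnx.Linear but adds the low-rank path on top of
# the dense kernel (the `lora_rank` keyword is an assumption here).
linear = nnx.LoRALinear(in_features=3, out_features=4, lora_rank=2, rngs=nnx.Rngs(0))
z = linear(jnp.ones((16, 3)))
assert z.shape == (16, 4)
```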
32 changes: 32 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/recurrent.rst
@@ -0,0 +1,32 @@
Recurrent
------------------------

.. automodule:: flax.nnx.nn.recurrent
.. currentmodule:: flax.nnx.nn.recurrent

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: LSTMCell

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: OptimizedLSTMCell

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: SimpleCell

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: GRUCell

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: RNN

.. flax_module::
:module: flax.nnx.nn.recurrent
:class: Bidirectional


.. autofunction:: flip_sequences
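
Since this page only lists the classes, here is a short sketch of the usual cell-plus-``RNN`` pattern, mirroring the ``Bidirectional`` docstring example further down in this commit (the output shape assumes the default mode where ``RNN`` returns the full sequence of hidden states):

```python
from flax import nnx
from flax.nnx.nn.recurrent import GRUCell, RNN
import jax.numpy as jnp

# A GRU cell mapping 3 input features to a hidden state of size 4,
# scanned over the time axis by RNN.
cell = GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))
rnn = RNN(cell)

x = jnp.ones((2, 5, 3))   # (batch, time, features)
y = rnn(x)
print(y.shape)            # expected: (2, 5, 4), one hidden vector per step
```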
10 changes: 5 additions & 5 deletions docs_nnx/guides/haiku_to_flax.rst
@@ -82,7 +82,7 @@ There are two fundamental differences between Haiku and Flax ``Module`` objects:


Variable creation
======================
=======================

Next, let's discuss instantiating the model and initializing its parameters

@@ -122,7 +122,7 @@ If you want to access Flax model parameters in the stateless, dictionary-like fa
assert model.block.linear.kernel.value.shape == (784, 256)

Training step and compilation
======================
=======================


Now, let's proceed to writing a training step and compiling it using `JAX just-in-time compilation <https://jax.readthedocs.io/en/latest/jit-compilation.html>`__. Below are some key differences between the Haiku and Flax NNX approaches.
@@ -401,7 +401,7 @@ To call those custom methods:


Transformations
===============
=======================

Both Haiku and `Flax transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__ provide their own set of transforms that wrap `JAX transforms <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__ so that they can be used with ``Module`` objects.

@@ -488,7 +488,7 @@ The only difference is that Flax ``nnx.scan`` allows you to specify which axis t


Scan over layers
================
=======================

Most Haiku transforms should look similar to their Flax counterparts, since they all wrap the underlying JAX transforms, but the scan-over-layers use case is an exception.

@@ -636,7 +636,7 @@ Now inspect the variable pytree on both sides:
Top-level Haiku functions vs top-level Flax modules
================
=======================

In Haiku, it is possible to write the entire model as a single function by using
the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and
6 changes: 3 additions & 3 deletions docs_nnx/index.rst
@@ -17,9 +17,9 @@ Python objects. Flax NNX is an evolution of the previous `Flax Linen <https://fl
API, and years of experience went into making it simpler and more user-friendly.

.. note::
Flax Linen API is not going to be deprecated in the near future as most of Flax users still
rely on this API. However, new users are encouraged to use Flax NNX.
For existing Flax Linen users planning to move to Flax NNX, check out the `evolution guide <guides/linen_to_nnx.html>`_ and `Why Flax NNX <why.html>`_.
The Flax Linen API is not going to be deprecated in the near future, as most Flax users still rely on it. However, new users are encouraged to use Flax NNX. Check out `Why Flax NNX <why.html>`_ for a comparison between Flax NNX and Linen, and our reasoning for creating the new API.

To move your Flax Linen codebase to Flax NNX, familiarize yourself with the API in `NNX Basics <https://flax.readthedocs.io/en/latest/nnx_basics.html>`_ and then start your move following the `evolution guide <guides/linen_to_nnx.html>`_.

Features
^^^^^^^^^
53 changes: 27 additions & 26 deletions flax/nnx/nn/recurrent.py
@@ -742,16 +742,17 @@ def flip_sequences(
values for those sequences that were padded. This function keeps the padding
at the end, while flipping the rest of the elements.
Example:
```python
inputs = [[1, 0, 0],
[2, 3, 0]
[4, 5, 6]]
lengths = [1, 2, 3]
flip_sequences(inputs, lengths) = [[1, 0, 0],
[3, 2, 0],
[6, 5, 4]]
```
Example::
>>> from flax.nnx.nn.recurrent import flip_sequences
>>> from jax import numpy as jnp
>>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
>>> lengths = jnp.array([1, 2, 3])
>>> flip_sequences(inputs, lengths, 1, False)
Array([[1, 0, 0],
[3, 2, 0],
[6, 5, 4]], dtype=int32)
Args:
inputs: An array of input IDs <int>[batch_size, seq_length].
@@ -810,27 +810,27 @@ def __call__(
class Bidirectional(Module):
"""Processes the input in both directions and merges the results.
Example usage:
Example usage::
>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
```python
import nnx
import jax
import jax.numpy as jnp
>>> # Define forward and backward RNNs
>>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
# Define forward and backward RNNs
forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> # Create Bidirectional layer
>>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)
# Create Bidirectional layer
layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)
>>> # Input data
>>> x = jnp.ones((2, 3, 3))
# Input data
x = jnp.ones((2, 3, 3))
>>> # Apply the layer
>>> out = layer(x)
>>> print(out.shape)
(2, 3, 8)
# Apply the layer
out = layer(x)
print(out.shape)
```
"""

forward_rnn: RNNBase
