From e14150418d066d59aec0f0adc38b740721ad6450 Mon Sep 17 00:00:00 2001
From: IvyZX
Date: Mon, 28 Oct 2024 09:41:04 -0700
Subject: [PATCH] Add API reference for flax.nnx.nn and improve landing page

---
 docs_nnx/api_reference/flax.nnx/nn/dtypes.rst |  8 +++
 docs_nnx/api_reference/flax.nnx/nn/index.rst  |  3 ++
 docs_nnx/api_reference/flax.nnx/nn/lora.rst   | 15 ++++++
 .../api_reference/flax.nnx/nn/recurrent.rst   | 32 +++++++++++
 docs_nnx/guides/haiku_to_flax.rst             | 10 ++--
 docs_nnx/index.rst                            |  6 +--
 flax/nnx/nn/recurrent.py                      | 53 ++++++++++---------
 7 files changed, 93 insertions(+), 34 deletions(-)
 create mode 100644 docs_nnx/api_reference/flax.nnx/nn/dtypes.rst
 create mode 100644 docs_nnx/api_reference/flax.nnx/nn/lora.rst
 create mode 100644 docs_nnx/api_reference/flax.nnx/nn/recurrent.rst

diff --git a/docs_nnx/api_reference/flax.nnx/nn/dtypes.rst b/docs_nnx/api_reference/flax.nnx/nn/dtypes.rst
new file mode 100644
index 0000000000..74edaf8dfc
--- /dev/null
+++ b/docs_nnx/api_reference/flax.nnx/nn/dtypes.rst
@@ -0,0 +1,8 @@
+Dtypes
+------------------------
+
+.. automodule:: flax.nnx.nn.dtypes
+.. currentmodule:: flax.nnx.nn.dtypes
+
+.. autofunction:: canonicalize_dtype
+.. autofunction:: promote_dtype
\ No newline at end of file
diff --git a/docs_nnx/api_reference/flax.nnx/nn/index.rst b/docs_nnx/api_reference/flax.nnx/nn/index.rst
index 4b7600b0f0..e42d58428b 100644
--- a/docs_nnx/api_reference/flax.nnx/nn/index.rst
+++ b/docs_nnx/api_reference/flax.nnx/nn/index.rst
@@ -9,8 +9,11 @@ See the `NNX page `__ for
 
   activations
   attention
+  dtypes
   initializers
   linear
+  lora
   normalization
+  recurrent
   stochastic
 
diff --git a/docs_nnx/api_reference/flax.nnx/nn/lora.rst b/docs_nnx/api_reference/flax.nnx/nn/lora.rst
new file mode 100644
index 0000000000..43461027cf
--- /dev/null
+++ b/docs_nnx/api_reference/flax.nnx/nn/lora.rst
@@ -0,0 +1,15 @@
+LoRA
+------------------------
+
+NNX LoRA classes.
+
+.. automodule:: flax.nnx
+.. currentmodule:: flax.nnx
+
+.. flax_module::
+  :module: flax.nnx
+  :class: LoRA
+
+.. flax_module::
+  :module: flax.nnx
+  :class: LoRALinear
diff --git a/docs_nnx/api_reference/flax.nnx/nn/recurrent.rst b/docs_nnx/api_reference/flax.nnx/nn/recurrent.rst
new file mode 100644
index 0000000000..b3270d95dd
--- /dev/null
+++ b/docs_nnx/api_reference/flax.nnx/nn/recurrent.rst
@@ -0,0 +1,32 @@
+Recurrent
+------------------------
+
+.. automodule:: flax.nnx.nn.recurrent
+.. currentmodule:: flax.nnx.nn.recurrent
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: LSTMCell
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: OptimizedLSTMCell
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: SimpleCell
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: GRUCell
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: RNN
+
+.. flax_module::
+  :module: flax.nnx.nn.recurrent
+  :class: Bidirectional
+
+
+.. autofunction:: flip_sequences
\ No newline at end of file
diff --git a/docs_nnx/guides/haiku_to_flax.rst b/docs_nnx/guides/haiku_to_flax.rst
index 4349b53d0b..96d05e66da 100644
--- a/docs_nnx/guides/haiku_to_flax.rst
+++ b/docs_nnx/guides/haiku_to_flax.rst
@@ -82,7 +82,7 @@ There are two fundamental differences between Haiku and Flax ``Module`` objects:
 
 
 Variable creation
-======================
+=======================
 
 Next, let's discuss instantiating the model and initializing its parameters
@@ -122,7 +122,7 @@ If you want to access Flax model parameters in the stateless, dictionary-like fa
   assert model.block.linear.kernel.value.shape == (784, 256)
 
 Training step and compilation
-======================
+=======================
 
 Now, let's proceed to writing a training step and compiling it using `JAX just-in-time compilation `__.
 Below are certain differences between Haiku and Flax NNX approaches.
@@ -401,7 +401,7 @@ To call those custom methods:
 
 
 Transformations
-===============
+=======================
 
 Both Haiku and `Flax transformations `__ provide their own set of transforms that wrap
 `JAX transforms `__ in a way that they can be used with ``Module`` objects.
@@ -488,7 +488,7 @@ The only difference is that Flax ``nnx.scan`` allows you to specify which axis t
 
 
 Scan over layers
-================
+=======================
 
 Most Haiku transforms should look similar with Flax, since they all wraps their JAX counterparts,
 but the scan-over-layers use case is an exception.
@@ -636,7 +636,7 @@ Now inspect the variable pytree on both sides:
 
 
 Top-level Haiku functions vs top-level Flax modules
-================
+=======================
 
 In Haiku, it is possible to write the entire model as a single function by using
 the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and
diff --git a/docs_nnx/index.rst b/docs_nnx/index.rst
index 982d73b0a9..87d584a6a3 100644
--- a/docs_nnx/index.rst
+++ b/docs_nnx/index.rst
@@ -17,9 +17,9 @@ Python objects.
 
 Flax NNX is an evolution of the previous `Flax Linen `_ and `Why Flax NNX `_.
+   Flax Linen API is not going to be deprecated in the near future as most of Flax users still rely on this API. However, new users are encouraged to use Flax NNX. Check out `Why Flax NNX `_ for a comparison between Flax NNX and Linen, and our reasoning to make the new API.
+
+   To move your Flax Linen codebase to Flax NNX, get familiarized with the API in `NNX Basics `_ and then start your move following the `evolution guide `_.
 
 Features
 ^^^^^^^^^
diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py
index e659144afb..44b89ad979 100644
--- a/flax/nnx/nn/recurrent.py
+++ b/flax/nnx/nn/recurrent.py
@@ -742,16 +742,17 @@ def flip_sequences(
   values for those sequences that were padded. This function keeps the padding
   at the end, while flipping the rest of the elements.
 
-  Example:
-  ```python
-  inputs = [[1, 0, 0],
-            [2, 3, 0]
-            [4, 5, 6]]
-  lengths = [1, 2, 3]
-  flip_sequences(inputs, lengths) = [[1, 0, 0],
-                                     [3, 2, 0],
-                                     [6, 5, 4]]
-  ```
+  Example::
+
+    >>> from flax.nnx.nn.recurrent import flip_sequences
+    >>> from jax import numpy as jnp
+    >>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
+    >>> lengths = jnp.array([1, 2, 3])
+    >>> flip_sequences(inputs, lengths, 1, False)
+    Array([[1, 0, 0],
+           [3, 2, 0],
+           [6, 5, 4]], dtype=int32)
+
 
   Args:
     inputs: An array of input IDs [batch_size, seq_length].
@@ -810,27 +811,27 @@ def __call__(
 class Bidirectional(Module):
   """Processes the input in both directions and merges the results.
 
-  Example usage:
+  Example usage::
+
+    >>> from flax import nnx
+    >>> import jax
+    >>> import jax.numpy as jnp
 
-  ```python
-  import nnx
-  import jax
-  import jax.numpy as jnp
+    >>> # Define forward and backward RNNs
+    >>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
+    >>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
 
-  # Define forward and backward RNNs
-  forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
-  backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
+    >>> # Create Bidirectional layer
+    >>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)
 
-  # Create Bidirectional layer
-  layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)
+    >>> # Input data
+    >>> x = jnp.ones((2, 3, 3))
 
-  # Input data
-  x = jnp.ones((2, 3, 3))
+    >>> # Apply the layer
+    >>> out = layer(x)
+    >>> print(out.shape)
+    (2, 3, 8)
 
-  # Apply the layer
-  out = layer(x)
-  print(out.shape)
-  ```
+
   """
 
   forward_rnn: RNNBase
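
Not part of the patch itself: the new ``lora.rst`` page documents ``LoRA`` and ``LoRALinear`` but, unlike the recurrent docstrings above, ships no usage snippet. Below is a minimal doctest-style sketch for exercising the newly documented module locally. It assumes the public ``flax.nnx.LoRA`` constructor takes ``(in_features, lora_rank, out_features, *, rngs)``; double-check shapes against the rendered API reference before relying on it::

    >>> import jax.numpy as jnp
    >>> from flax import nnx
    >>> # Rank-2 LoRA layer mapping 3 input features to 4 output features.
    >>> layer = nnx.LoRA(in_features=3, lora_rank=2, out_features=4, rngs=nnx.Rngs(0))
    >>> # With no base module attached, only the low-rank update x @ A @ B is applied.
    >>> y = layer(jnp.ones((16, 3)))
    >>> y.shape
    (16, 4)

If the released API also exposes a ``base_module`` argument (as the class docstring suggests), passing an existing ``nnx.Linear`` there adds the low-rank delta on top of the base layer's output instead of replacing it.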