diff --git a/docs_nnx/guides/transforms_tree.ipynb b/docs_nnx/guides/transforms_tree.ipynb new file mode 100644 index 000000000..0977c7ac9 --- /dev/null +++ b/docs_nnx/guides/transforms_tree.ipynb @@ -0,0 +1,571 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "15ed01a7", + "metadata": {}, + "source": [ + "# Transforms\n", + "\n", + "NNX transforms (`nnx.jit`, `nnx.grad`, `nnx.vmap`, `nnx.scan`, ...) are thin wrappers over JAX transforms that provide the same APIs. Their main feature is **automatic state propagation**: input `Variable`'s state is tracked and automatically updated. Here is a sketch of how they work:\n", + "\n", + "```python\n", + "def transform_wrapper(*args):\n", + " if graph: args = to_tree(args)\n", + " check_no_aliases(args=args)\n", + " \n", + " @jax_transform\n", + " def transformed_f(*args):\n", + " updates, snapshot = updates_and_snapshot(args)\n", + " if graph: args = from_tree(args)\n", + " out = f(*args)\n", + " if graph: out = to_tree(out)\n", + " check_no_aliases(args=updates, out=out)\n", + " updates = mask_variable_updates(updates, snapshot)\n", + " return out, updates\n", + " \n", + " out, updates = transformed_f(*args)\n", + " apply_variable_updates(args, updates)\n", + " if graph: out = from_tree(out)\n", + " return out\n", + "```\n", + "\n", + "The transformed function tracks input Variable `updates`, applies `f`, and masks Variable updates (no updates for Variables that didn’t change). It also checks that there are no Variable aliases between the inputs and outputs (no shared references), and returns the user output plus the Variable updates. The wrapper function calls the transformed function, applies the Variable updates to the input Variables, and returns the user output. To support graphs theres some back forth conversion between object and tree representations at various points." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5a5678db", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from flax import nnx\n", + "import optax\n", + "\n", + "nnx.set_graph_mode(False)\n", + "nnx.set_graph_updates(False)\n", + "jax.config.update(\"jax_num_cpu_devices\", 8)" + ] + }, + { + "cell_type": "markdown", + "id": "56260468", + "metadata": {}, + "source": [ + "## Model definition\n", + "Throughout this guide we'll use a simple `Linear` layer and show how to use it with various transforms. This layer includes:\n", + "- A weight matrix (`w: Param`).\n", + "- A call counter (`count: Count`) — a custom `Variable` type with non-differentiable state.\n", + "- An `rngs` argument in `__call__` to add noise." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "63a98bf5", + "metadata": {}, + "outputs": [], + "source": [ + "class Count(nnx.Variable): pass\n", + "\n", + "class Linear(nnx.Pytree):\n", + " def __init__(self, din, dout, *, rngs):\n", + " self.din, self.dout = din, dout\n", + " self.w = nnx.Param(rngs.normal((din, dout)))\n", + " self.count = Count(jnp.array(0))\n", + "\n", + " def __call__(self, x: jax.Array, *, rngs: nnx.Rngs):\n", + " self.count[...] += 1\n", + " y = x @ self.w\n", + " return y + rngs.normal(y.shape) * 0.1 # noise" + ] + }, + { + "cell_type": "markdown", + "id": "3b7774e0", + "metadata": {}, + "source": [ + "## jit — forward pass\n", + "\n", + "`nnx.jit` compiles and caches the function just like `jax.jit`. Variable updates made\n", + "inside the function are automatically propagated back." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "88d4caf7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (3, 5)\n", + "model.count[...] 
= 1\n" + ] + } + ], + "source": [ + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "rngs = nnx.Rngs(1)\n", + "x = jnp.ones((3, 2))\n", + "\n", + "@nnx.jit\n", + "def forward(model, x, rngs):\n", + " return model(x, rngs=rngs)\n", + "\n", + "y = forward(model, x, rngs)\n", + "print(f'{y.shape = }')\n", + "print(f'{model.count[...] = !s}') # called once" + ] + }, + { + "cell_type": "markdown", + "id": "76258e93", + "metadata": {}, + "source": [ + "## jit + grad — training step\n", + "\n", + "`nnx.grad` differentiates with respect to `nnx.Param` variables by default, treating all other state as non-differentiable. The `wrt` argument accepts any [Filter](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to select which Variable types to differentiate. It handles `split`/`merge`/`clone` internally, so you only need to write the loss function.\n", + "\n", + "`nnx.Optimizer` wraps an [Optax](https://optax.readthedocs.io/) optimizer and provides a simple `update(model, grads)` method that performs in-place updates to both the optimizer state and model parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "169ae60d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count[...] = 1\n", + "optimizer.step[...] 
= 1\n" + ] + } + ], + "source": [ + "x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "rngs = nnx.Rngs(1)\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "@nnx.jit\n", + "def train_step(model, optimizer, x, y, rngs):\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + " def loss_fn(params, nondiff, rngs):\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " return jnp.mean((model(x, rngs=rngs) - y) ** 2)\n", + "\n", + " grads = nnx.grad(loss_fn)(params, nondiff, rngs)\n", + " optimizer.update(model, grads)\n", + "\n", + "train_step(model, optimizer, x, y, rngs)\n", + "\n", + "print(f'{model.count[...] = !s}') # called once\n", + "print(f'{optimizer.step[...] = !s}') # one optimizer step" + ] + }, + { + "cell_type": "markdown", + "id": "11b38885", + "metadata": {}, + "source": [ + "## vmap — vectorized forward pass\n", + "\n", + "`nnx.vmap` vectorizes a function over an axis dimension. NNX objects participate in\n", + "`in_axes` / `out_axes` just like any other pytree. Here we broadcast the model and `rngs`\n", + "(via `None`) and vectorize only the data across a batch of inputs." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "21d3dc68", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (10, 1, 3)\n", + "model.count[...] = 1\n" + ] + } + ], + "source": [ + "rngs = nnx.Rngs(1)\n", + "model = Linear(3, 3, rngs=rngs)\n", + "x = jnp.ones((1, 3, 10))\n", + "\n", + "@nnx.vmap(in_axes=(None, 2, None), out_axes=0)\n", + "def batched_forward(model, x, rngs): # model & rngs broadcast, x vectorized\n", + " return model(x, rngs=rngs) \n", + "\n", + "y = batched_forward(model, x, rngs)\n", + "print(f'{y.shape = !s}') # (1, 5, 10)\n", + "print(f'{model.count[...] 
= !s}') # called once (broadcast)" + ] + }, + { + "cell_type": "markdown", + "id": "60f29b83", + "metadata": {}, + "source": [ + "Because the model is passed with `in_axes=None`, it is broadcast — the same weights\n", + "are shared across all vectorized inputs. The same applies to `rngs`, so every input\n", + "sees identical noise." + ] + }, + { + "cell_type": "markdown", + "id": "728a8dc6", + "metadata": {}, + "source": [ + "## vmap + scan — scan over layers\n", + "\n", + "A common pattern is to stack many identical layers and apply them sequentially.\n", + "We use `nnx.vmap` to initialize a stack of layers, then `nnx.scan` to iterate\n", + "over them. The hidden state `x` is passed as a **Carry**, while the layer stack and the split `rngs` are the **scanned** over axis 0. The new state `x` is returned as a Carry for the next iteration." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0a0ecc4f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "stack.w.shape = (5, 3, 3)\n", + "y.shape = (1, 3)\n", + "stack.count[...] = [1 1 1 1 1]\n", + "rngs.default.count[...] = 2\n" + ] + } + ], + "source": [ + "rngs = nnx.Rngs(0)\n", + "# --- initialize a stack of layers with vmap ---\n", + "@nnx.vmap(in_axes=0, out_axes=0)\n", + "def create_stack(rngs):\n", + " return Linear(3, 3, rngs=rngs)\n", + "\n", + "stack = create_stack(rngs.split(5))\n", + "print(f'{stack.w.shape = }') # (5, 3, 3) — one weight per layer\n", + "\n", + "# --- scan over the layer stack ---\n", + "@nnx.scan(in_axes=(nnx.Carry, 0, 0), out_axes=nnx.Carry)\n", + "def apply_stack(x, layer, rngs):\n", + " x = layer(x, rngs=rngs)\n", + " return x\n", + "\n", + "x = jnp.ones((1, 3))\n", + "y = apply_stack(x, stack, rngs.split(5))\n", + "\n", + "print(f'{y.shape = !s}') # (1, 3) — final output after all layers\n", + "print(f'{stack.count[...] = !s}') # each layer called once\n", + "print(f'{rngs.default.count[...] 
= !s}') # rngs used 2 times (one per split)" + ] + }, + { + "cell_type": "markdown", + "id": "8bdc1214", + "metadata": {}, + "source": [ + "Updates to `count` Variables are propagated out automatically." + ] + }, + { + "cell_type": "markdown", + "id": "1e6fac5a", + "metadata": {}, + "source": [ + "## Graph Mode\n", + "\n", + "Setting `graph=True` on any NNX transform allows passing NNX objects with shared references — that is, inputs that form a graph rather than a strict tree. By default transforms require each leaf to appear exactly once; passing the same `Variable` in two arguments violates that constraint and raises an error. With `graph=True`, the transform detects shared `Variable`s, handles them correctly, and propagates updates back to the original object. Sharing not only applies to `Variable`s, but also to any `nnx.Pytree` which is the base type for `Module`, `Optimizer`, `Metric`, etc.\n", + "\n", + "The example below shares a single `Variable` between two arguments:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8d8f7093", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "v[...] = 1\n" + ] + } + ], + "source": [ + "@nnx.jit(graph=True)\n", + "def f(v1, v2):\n", + " assert v1 is v2 # relative identities are preserved in graph mode\n", + " v1[...] += 1\n", + "\n", + "v = nnx.Variable(jnp.array(0))\n", + "f(v, v)\n", + "\n", + "print(f'{v[...] = !s}') # v is updated in-place, so should be 1" + ] + }, + { + "cell_type": "markdown", + "id": "a8577497", + "metadata": {}, + "source": [ + "Graph mode does have one important limitation: aliased `Variable`s must be treated consistently across all arguments. 
For example, if the same `Variable` is passed to two arguments that have different `in_axes`, the transform cannot resolve the conflict and will raise an error:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2f8596bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error: Inconsistent aliasing detected. The following nodes have different prefixes:\n", + "Node: Variable\n", + " 0/0: None\n", + " 0/1: 0\n" + ] + } + ], + "source": [ + "@nnx.vmap(in_axes=(None, 0), graph=True)\n", + "def f(v1, v2):\n", + " pass\n", + "\n", + "v = nnx.Variable(jnp.array(0))\n", + "\n", + "try:\n", + " f(v, v)\n", + "except Exception as e:\n", + " print(f'Error: {e}')" + ] + }, + { + "cell_type": "markdown", + "id": "b71f8cfb", + "metadata": {}, + "source": [ + "This is roughly saying that the same Variable (`v`) received `in_axes` of `None` on the first argument and `0` on the second argument, which is a conflict." + ] + }, + { + "cell_type": "markdown", + "id": "0a5b5130", + "metadata": {}, + "source": [ + "## Legacy: Graph Updates and Prefix Filters\n", + "\n", + "NNX transforms also supports a legacy **graph updates** mode which requires setting `graph=True` and `graph_updates=True` on each transform. In this mode updates to the graph objects (e.g. Modules) are also tracked and propagated. In this mode **prefix filters** like `StateAxes`, `DiffState`, `StateSharding` can be used to specify how graph substates are treated by transforms. For convenience the legacy behavior of the transforms can used via the `nnx.compat` module, this simply sets the `graph` and `graph_updates` to `True` on each transform.\n", + "\n", + "In this section we will explain how to use prefix filters for users that still rely on the behavior.\n", + "\n", + "### StateAxes\n", + "\n", + "`nnx.StateAxes` lets you specify substate axis behavior inside `nnx.vmap`, `nnx.scan`, and `nnx.pmap`. 
It maps [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) like Variable types or path predicates to axis indices or `None` (broadcast).\n", + "\n", + "For example, you might want to vectorize the `Param` weights on axis 0 but broadcast\n", + "the `Count` state:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f9463828", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (3, 10)\n", + "weights.count[...] = 1\n" + ] + } + ], + "source": [ + "class Weights(nnx.Module):\n", + " def __init__(self, kernel, count):\n", + " self.kernel = nnx.Param(kernel)\n", + " self.count = Count(count)\n", + "\n", + "rngs = nnx.Rngs(0)\n", + "weights = Weights(\n", + " kernel=rngs.uniform((10, 2, 3)),\n", + " count=jnp.array(0), # single scalar, not vectorized\n", + ")\n", + "x = rngs.normal((10, 2))\n", + "\n", + "state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count\n", + "\n", + "@nnx.compat.vmap(in_axes=(state_axes, 0), out_axes=1)\n", + "def forward(weights, x):\n", + " weights.count[...] += 1\n", + " return x @ weights.kernel\n", + "\n", + "y = forward(weights, x)\n", + "print(f'{y.shape = !s}')\n", + "print(f'{weights.count[...] = !s}')" + ] + }, + { + "cell_type": "markdown", + "id": "1f5ab45a", + "metadata": {}, + "source": [ + "### DiffState\n", + "\n", + "`nnx.DiffState` lets you control which sub-state of an argument participates in\n", + "differentiation with `nnx.grad`. 
It wraps an argument index and a filter:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "db9c20cc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grads_m1: \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254m'kernel'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m})\u001b[0m\n", + "grads_m2: \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254m'bias'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m3\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m})\u001b[0m\n" + ] + } + ], + "source": [ + "m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))\n", + "m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1))\n", + "\n", + "# only differentiate m1's kernel and m2's bias\n", + "@nnx.compat.grad(argnums=(\n", + " nnx.DiffState(0, nnx.PathContains('kernel')),\n", + " nnx.DiffState(1, nnx.PathContains('bias')),\n", + "))\n", + "def loss_fn(m1, m2):\n", + " return jnp.mean(m1.kernel * m2.kernel) + jnp.mean(m1.bias * m2.bias)\n", + "\n", + "grads_m1, grads_m2 = loss_fn(m1, m2)\n", + "print(f'grads_m1: {jax.tree.map(jnp.shape, grads_m1)}')\n", + "print(f'grads_m2: {jax.tree.map(jnp.shape, 
grads_m2)}')" + ] + }, + { + "cell_type": "markdown", + "id": "c6f97d23", + "metadata": {}, + "source": [ + "Without graph updates, you achieve the same effect using `nnx.split` to separate the parts\n", + "you want to differentiate, then pass them to `grad` directly.\n", + "\n", + "### StateSharding\n", + "\n", + "`nnx.StateSharding` maps Variable types to JAX shardings for use with `nnx.jit`. It\n", + "has the same structure as `StateAxes` but values are sharding specs instead of axis\n", + "indices:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b2465669", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (16, 16)\n", + "weights.count[...] = 1\n" + ] + } + ], + "source": [ + "rngs = nnx.Rngs(1)\n", + "mesh = jax.make_mesh((8,), ('devices',))\n", + "\n", + "def sharding(*args):\n", + " return jax.sharding.NamedSharding(mesh, jax.P(*args))\n", + "\n", + "# Create weights outside mesh context so arrays are uncommitted\n", + "weights = Weights(\n", + " kernel=rngs.uniform((16, 16)),\n", + " count=jnp.array(0),\n", + ")\n", + "x = jnp.ones((16, 16))\n", + "\n", + "# Define sharding for different Variable types\n", + "state_sharding = nnx.StateSharding({\n", + " nnx.Param: sharding(None, 'devices'), # shard Param on second axis\n", + " Count: sharding(), # replicate Count\n", + "})\n", + "\n", + "@nnx.compat.jit(in_shardings=(state_sharding, sharding('devices')))\n", + "def forward(weights, x):\n", + " weights.count[...] += 1\n", + " return x @ weights.kernel\n", + "\n", + "y = forward(weights, x)\n", + "print(f'{y.shape = }')\n", + "print(f'{weights.count[...] = !s}')" + ] + }, + { + "cell_type": "markdown", + "id": "1a71ae23", + "metadata": {}, + "source": [ + "Without graph updates, you can use standard pytree-based `in_shardings` / `out_shardings` with `nnx.jit` or `jax.jit` directly." 
+ ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,md:myst", + "main_language": "python" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs_nnx/guides/transforms_tree.md b/docs_nnx/guides/transforms_tree.md new file mode 100644 index 000000000..6fbc1e666 --- /dev/null +++ b/docs_nnx/guides/transforms_tree.md @@ -0,0 +1,315 @@ +--- +jupytext: + cell_metadata_filter: -all + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# Transforms + +NNX transforms (`nnx.jit`, `nnx.grad`, `nnx.vmap`, `nnx.scan`, ...) are thin wrappers over JAX transforms that provide the same APIs. Their main feature is **automatic state propagation**: input `Variable`'s state is tracked and automatically updated. Here is a sketch of how they work: + +```python +def transform_wrapper(*args): + if graph: args = to_tree(args) + check_no_aliases(args=args) + + @jax_transform + def transformed_f(*args): + updates, snapshot = updates_and_snapshot(args) + if graph: args = from_tree(args) + out = f(*args) + if graph: out = to_tree(out) + check_no_aliases(args=updates, out=out) + updates = mask_variable_updates(updates, snapshot) + return out, updates + + out, updates = transformed_f(*args) + apply_variable_updates(args, updates) + if graph: out = from_tree(out) + return out +``` + +The transformed function tracks input Variable `updates`, applies `f`, and masks Variable updates (no updates for Variables that didn’t change). 
It also checks that there are no Variable aliases between the inputs and outputs (no shared references), and returns the user output plus the Variable updates. The wrapper function calls the transformed function, applies the Variable updates to the input Variables, and returns the user output. To support graphs there's some back-and-forth conversion between object and tree representations at various points. + +```{code-cell} ipython3 +import jax +import jax.numpy as jnp +from flax import nnx +import optax + +nnx.set_graph_mode(False) +nnx.set_graph_updates(False) +jax.config.update("jax_num_cpu_devices", 8) +``` + +## Model definition +Throughout this guide we'll use a simple `Linear` layer and show how to use it with various transforms. This layer includes: +- A weight matrix (`w: Param`). +- A call counter (`count: Count`) — a custom `Variable` type with non-differentiable state. +- An `rngs` argument in `__call__` to add noise. + +```{code-cell} ipython3 +class Count(nnx.Variable): pass + +class Linear(nnx.Pytree): + def __init__(self, din, dout, *, rngs): + self.din, self.dout = din, dout + self.w = nnx.Param(rngs.normal((din, dout))) + self.count = Count(jnp.array(0)) + + def __call__(self, x: jax.Array, *, rngs: nnx.Rngs): + self.count[...] += 1 + y = x @ self.w + return y + rngs.normal(y.shape) * 0.1 # noise +``` + +## jit — forward pass + +`nnx.jit` compiles and caches the function just like `jax.jit`. Variable updates made +inside the function are automatically propagated back. + +```{code-cell} ipython3 +model = Linear(2, 5, rngs=nnx.Rngs(0)) +rngs = nnx.Rngs(1) +x = jnp.ones((3, 2)) + +@nnx.jit +def forward(model, x, rngs): + return model(x, rngs=rngs) + +y = forward(model, x, rngs) +print(f'{y.shape = }') +print(f'{model.count[...] = !s}') # called once +``` + +## jit + grad — training step + +`nnx.grad` differentiates with respect to `nnx.Param` variables by default, treating all other state as non-differentiable.
The `wrt` argument accepts any [Filter](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to select which Variable types to differentiate. It handles `split`/`merge`/`clone` internally, so you only need to write the loss function. + +`nnx.Optimizer` wraps an [Optax](https://optax.readthedocs.io/) optimizer and provides a simple `update(model, grads)` method that performs in-place updates to both the optimizer state and model parameters. + +```{code-cell} ipython3 +x, y = jnp.ones((3, 2)), jnp.ones((3, 5)) +model = Linear(2, 5, rngs=nnx.Rngs(0)) +rngs = nnx.Rngs(1) +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + +@nnx.jit +def train_step(model, optimizer, x, y, rngs): + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) + def loss_fn(params, nondiff, rngs): + model = nnx.merge(graphdef, params, nondiff) + return jnp.mean((model(x, rngs=rngs) - y) ** 2) + + grads = nnx.grad(loss_fn)(params, nondiff, rngs) + optimizer.update(model, grads) + +train_step(model, optimizer, x, y, rngs) + +print(f'{model.count[...] = !s}') # called once +print(f'{optimizer.step[...] = !s}') # one optimizer step +``` + +## vmap — vectorized forward pass + +`nnx.vmap` vectorizes a function over an axis dimension. NNX objects participate in +`in_axes` / `out_axes` just like any other pytree. Here we broadcast the model and `rngs` +(via `None`) and vectorize only the data across a batch of inputs. + +```{code-cell} ipython3 +rngs = nnx.Rngs(1) +model = Linear(3, 3, rngs=rngs) +x = jnp.ones((1, 3, 10)) + +@nnx.vmap(in_axes=(None, 2, None), out_axes=0) +def batched_forward(model, x, rngs): # model & rngs broadcast, x vectorized + return model(x, rngs=rngs) + +y = batched_forward(model, x, rngs) +print(f'{y.shape = !s}') # (10, 1, 3) +print(f'{model.count[...] = !s}') # called once (broadcast) +``` + +Because the model is passed with `in_axes=None`, it is broadcast — the same weights +are shared across all vectorized inputs.
The same applies to `rngs`, so every input +sees identical noise. + ++++ + +## vmap + scan — scan over layers + +A common pattern is to stack many identical layers and apply them sequentially. +We use `nnx.vmap` to initialize a stack of layers, then `nnx.scan` to iterate +over them. The hidden state `x` is passed as a **Carry**, while the layer stack and the split `rngs` are the **scanned** over axis 0. The new state `x` is returned as a Carry for the next iteration. + +```{code-cell} ipython3 +rngs = nnx.Rngs(0) +# --- initialize a stack of layers with vmap --- +@nnx.vmap(in_axes=0, out_axes=0) +def create_stack(rngs): + return Linear(3, 3, rngs=rngs) + +stack = create_stack(rngs.split(5)) +print(f'{stack.w.shape = }') # (5, 3, 3) — one weight per layer + +# --- scan over the layer stack --- +@nnx.scan(in_axes=(nnx.Carry, 0, 0), out_axes=nnx.Carry) +def apply_stack(x, layer, rngs): + x = layer(x, rngs=rngs) + return x + +x = jnp.ones((1, 3)) +y = apply_stack(x, stack, rngs.split(5)) + +print(f'{y.shape = !s}') # (1, 3) — final output after all layers +print(f'{stack.count[...] = !s}') # each layer called once +print(f'{rngs.default.count[...] = !s}') # rngs used 2 times (one per split) +``` + +Updates to `count` Variables are propagated out automatically. + ++++ + +## Graph Mode + +Setting `graph=True` on any NNX transform allows passing NNX objects with shared references — that is, inputs that form a graph rather than a strict tree. By default transforms require each leaf to appear exactly once; passing the same `Variable` in two arguments violates that constraint and raises an error. With `graph=True`, the transform detects shared `Variable`s, handles them correctly, and propagates updates back to the original object. Sharing not only applies to `Variable`s, but also to any `nnx.Pytree` which is the base type for `Module`, `Optimizer`, `Metric`, etc. 
+ +The example below shares a single `Variable` between two arguments: + +```{code-cell} ipython3 +@nnx.jit(graph=True) +def f(v1, v2): + assert v1 is v2 # relative identities are preserved in graph mode + v1[...] += 1 + +v = nnx.Variable(jnp.array(0)) +f(v, v) + +print(f'{v[...] = !s}') # v is updated in-place, so should be 1 +``` + +Graph mode does have one important limitation: aliased `Variable`s must be treated consistently across all arguments. For example, if the same `Variable` is passed to two arguments that have different `in_axes`, the transform cannot resolve the conflict and will raise an error: + +```{code-cell} ipython3 +@nnx.vmap(in_axes=(None, 0), graph=True) +def f(v1, v2): + pass + +v = nnx.Variable(jnp.array(0)) + +try: + f(v, v) +except Exception as e: + print(f'Error: {e}') +``` + +This is roughly saying that the same Variable (`v`) received `in_axes` of `None` on the first argument and `0` on the second argument, which is a conflict. + ++++ + +## Legacy: Graph Updates and Prefix Filters + +NNX transforms also support a legacy **graph updates** mode which requires setting `graph=True` and `graph_updates=True` on each transform. In this mode updates to the graph objects (e.g. Modules) are also tracked and propagated. In this mode **prefix filters** like `StateAxes`, `DiffState`, `StateSharding` can be used to specify how graph substates are treated by transforms. For convenience the legacy behavior of the transforms can be used via the `nnx.compat` module; this simply sets the `graph` and `graph_updates` to `True` on each transform. + +In this section we will explain how to use prefix filters for users that still rely on the behavior. + +### StateAxes + +`nnx.StateAxes` lets you specify substate axis behavior inside `nnx.vmap`, `nnx.scan`, and `nnx.pmap`.
+ +For example, you might want to vectorize the `Param` weights on axis 0 but broadcast +the `Count` state: + +```{code-cell} ipython3 +class Weights(nnx.Module): + def __init__(self, kernel, count): + self.kernel = nnx.Param(kernel) + self.count = Count(count) + +rngs = nnx.Rngs(0) +weights = Weights( + kernel=rngs.uniform((10, 2, 3)), + count=jnp.array(0), # single scalar, not vectorized +) +x = rngs.normal((10, 2)) + +state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count + +@nnx.compat.vmap(in_axes=(state_axes, 0), out_axes=1) +def forward(weights, x): + weights.count[...] += 1 + return x @ weights.kernel + +y = forward(weights, x) +print(f'{y.shape = !s}') +print(f'{weights.count[...] = !s}') +``` + +### DiffState + +`nnx.DiffState` lets you control which sub-state of an argument participates in +differentiation with `nnx.grad`. It wraps an argument index and a filter: + +```{code-cell} ipython3 +m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) +m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1)) + +# only differentiate m1's kernel and m2's bias +@nnx.compat.grad(argnums=( + nnx.DiffState(0, nnx.PathContains('kernel')), + nnx.DiffState(1, nnx.PathContains('bias')), +)) +def loss_fn(m1, m2): + return jnp.mean(m1.kernel * m2.kernel) + jnp.mean(m1.bias * m2.bias) + +grads_m1, grads_m2 = loss_fn(m1, m2) +print(f'grads_m1: {jax.tree.map(jnp.shape, grads_m1)}') +print(f'grads_m2: {jax.tree.map(jnp.shape, grads_m2)}') +``` + +Without graph updates, you achieve the same effect using `nnx.split` to separate the parts +you want to differentiate, then pass them to `grad` directly. + +### StateSharding + +`nnx.StateSharding` maps Variable types to JAX shardings for use with `nnx.jit`. 
It +has the same structure as `StateAxes` but values are sharding specs instead of axis +indices: + +```{code-cell} ipython3 +rngs = nnx.Rngs(1) +mesh = jax.make_mesh((8,), ('devices',)) + +def sharding(*args): + return jax.sharding.NamedSharding(mesh, jax.P(*args)) + +# Create weights outside mesh context so arrays are uncommitted +weights = Weights( + kernel=rngs.uniform((16, 16)), + count=jnp.array(0), +) +x = jnp.ones((16, 16)) + +# Define sharding for different Variable types +state_sharding = nnx.StateSharding({ + nnx.Param: sharding(None, 'devices'), # shard Param on second axis + Count: sharding(), # replicate Count +}) + +@nnx.compat.jit(in_shardings=(state_sharding, sharding('devices'))) +def forward(weights, x): + weights.count[...] += 1 + return x @ weights.kernel + +y = forward(weights, x) +print(f'{y.shape = }') +print(f'{weights.count[...] = !s}') +``` + +Without graph updates, you can use standard pytree-based `in_shardings` / `out_shardings` with `nnx.jit` or `jax.jit` directly. diff --git a/docs_nnx/nnx_basics_tree.ipynb b/docs_nnx/nnx_basics_tree.ipynb new file mode 100644 index 000000000..ee63ff443 --- /dev/null +++ b/docs_nnx/nnx_basics_tree.ipynb @@ -0,0 +1,675 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NNX Basics\n", + "\n", + "NNX is a Neural Networks library for JAX. NNX provides the tools to structure modeling code as [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) so it can work with transforms, `jax.tree.*` utilities, and all standard JAX APIs. This guide covers the core concepts you need to get started." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from flax import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "nnx.graphlib.set_graph_mode(False)\n", + "nnx.graphlib.set_graph_updates(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "NNX's main build blocks are:\n", + "\n", + "- **`nnx.Pytree`**: Base class for pytree-compatible objects. Defines the tree structure of your model.\n", + "- **`nnx.Variable`**: Wraps array data and tracks mutable state. Subclasses like `nnx.Param` categorize different kinds of state.\n", + "- **State APIs** (`nnx.{state, split, merge, update}`): Extract, partition, reconstruct, and apply state updates.\n", + "- **NNX Transforms** (`nnx.{jit, grad, scan, ...}`): Thin wrappers over JAX transforms that automate state propagation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pytrees and Variables\n", + "\n", + "`nnx.Pytree` and `nnx.Variable` are two orthogonal systems. **Pytrees** define the structure of your model as a JAX-compatible tree. **Variables** wrap array data and enable expressing state updates via in-place mutation. \n", + "\n", + "`Pytree`s are python objects that define its tree structure dynamically through its attributes, these are split into two categories: **Static attributes** (e.g. `int`, `str`) are embedded in the tree structure definition and are not traced by JAX. **Data attributes** (e.g. `nnx.Variable`, `jax.Array`) are the leaves of the tree and are traced by JAX. 
For more details see the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html).\n", + "\n", + "Here's a typical layer definition:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class Count(nnx.Variable): pass # custom Variable types\n", + "\n", + "class Linear(nnx.Pytree):\n", + " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", + " self.din, self.dout = din, dout # static attributes\n", + " self.w = nnx.Param(rngs.uniform((din, dout))) # data attribute\n", + " self.count = Count(jnp.array(0)) # data attribute\n", + "\n", + " def __call__(self, x: jax.Array):\n", + " self.count[...] += 1 # inplace state updates\n", + " return x @ self.w # Variables are Array-like\n", + "\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "\n", + "nnx.display(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> **Note:** Most user code uses `nnx.Module`, which is a subclass of `nnx.Pytree` with additional features such as support for metric reporting.\n", + "\n", + "As we can see above, Variables are array-like; they support arithmetic operators, indexing, and can be used directly in JAX expressions. You can update their value in-place using `variable[...] = new_value`. Since NNX Pytrees are standard JAX pytrees, you can use `jax.tree.*` functions directly on them:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (3, 5), model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "\n", + "model.w sum: 4.1854\n", + "doubled.w sum: 8.3709\n", + "\n", + "Pytree leaves:\n", + ".count.value: Array(1, dtype=int32, weak_type=True)\n", + ".w.value: Array([[0.8423141 , 0.18237865, 0.2271781 , 0.12072563, 0.19181347],\n", + " [0.722015 , 0.7654456 , 0.15254045, 0.9517063 , 0.02931046]], dtype=float32)\n" + ] + } + ], + "source": [ + "x = jnp.ones((3, 2))\n", + "y = model(x)\n", + "print(f'{y.shape = }, {model.count[...] 
= }')\n", + "\n", + "# jax.tree.map works directly on NNX Pytrees\n", + "doubled_model = jax.tree.map(lambda x: x * 2, model)\n", + "print(f'\\nmodel.w sum: {model.w.sum():.4f}')\n", + "print(f'doubled.w sum: {doubled_model.w.sum():.4f}')\n", + "\n", + "# jax.tree.leaves_with_path shows the full tree structure\n", + "print('\\nPytree leaves:')\n", + "for path, value in jax.tree.leaves_with_path(model):\n", + " print(f'{jax.tree_util.keystr(path)}: {value!r}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here `jax.tree.map` was first used to create a new model with each leaf Array doubled, and then `jax.tree.flatten_with_path` was used to show how JAX sees the tree structure. Notice that because Variables are also JAX pytrees containing a single element (their inner value) we see `value` as part of the leaf path." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rngs\n", + "`nnx.Rngs` simplify managing [JAX PRNG state](https://jax.readthedocs.io/en/latest/random-numbers.html). It is itself an `nnx.Pytree` that stores a seed `key` and an incrementing `counter` in `Variable`s internally. 
By calling it, `Rngs` can produce new PRNG keys:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "key1 = Array((), dtype=key) overlaying:\n", + "[1797259609 2579123966]\n", + "key2 = Array((), dtype=key) overlaying:\n", + "[ 928981903 3453687069]\n", + "arr = Array([[ 1.2956359 , 1.3550105 , -0.40960556],\n", + " [-0.77188545, 0.38094172, 0.01888919]], dtype=float32)\n", + "\u001b[38;2;79;201;177mRngs\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # RngState: 2 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mdefault\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngStream\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # RngState: 2 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m,\n", + " \u001b[38;2;156;220;254mkey\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray((), dtype=key) overlaying:\n", + " [0 0],\n", + " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mcount\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(3, dtype=uint32),\n", + " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ 
+ "rngs = nnx.Rngs(0) # seeded with 0\n", + "\n", + "key1 = rngs() # get a raw key\n", + "key2 = rngs() # different key (counter incremented)\n", + "arr = rngs.normal((2, 3)) # draw samples directly\n", + "\n", + "print(f'{key1 = }')\n", + "print(f'{key2 = }')\n", + "print(f'{arr = }')\n", + "print(rngs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we've seen so far, `Rngs` conveniently exposes every `jax.random.*` distribution as a method (e.g. `rngs.uniform(...)`, `rngs.normal(...)`) without requiring the `key` argument and returning different random values every time they are called, this highly simplifies the user experience. In general `Rngs` can hold multiple keys and counters in structures called `RngStream`s, above we see that the `default` stream is being used. For more information check out the [Randomness guide](https://flax.readthedocs.io/en/latest/guides/randomness.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Nested Modules\n", + "\n", + "Pytree subclasses compose naturally, you can assign one as an attribute of another to build nested models. The example below builds a simple `MLP` from two `Linear` layers:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (3, 5)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class MLP(nnx.Pytree):\n", + " def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):\n", + " self.din, self.dmid, self.dout = din, dmid, dout # static attributes\n", + " self.linear1 = Linear(din, dmid, rngs=rngs) # data attribute\n", + " self.linear2 = Linear(dmid, dout, rngs=rngs) # data attribute\n", + "\n", + " def __call__(self, x: jax.Array):\n", + " x = nnx.relu(self.linear1(x))\n", + " return self.linear2(x)\n", + "\n", + "mlp = MLP(2, 16, 5, rngs=nnx.Rngs(0))\n", + "y = mlp(jnp.ones((3, 2)))\n", + "print(f'{y.shape = }')\n", + "\n", + "nnx.display(mlp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because the entire model is a single pytree, all the `jax.tree.*` functions, JAX transforms, and NNX state APIs work on the full nested structure at once. For more info check out the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## JAX Transforms\n", + "\n", + "NNX models can be passed directly to JAX transforms like `jax.jit`. 
However, JAX transforms create pure functions, meaning that they won't propagate side effects such as Variable state updates back to the caller:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + } + ], + "source": [ + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "\n", + "@jax.jit\n", + "def forward(model, x): # pure function\n", + " y = model(x)\n", + " return y\n", + "\n", + "y = forward(model, x)\n", + "\n", + "print(model.count[...]) # no state update" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here `count` was not updated because inside `jax.jit` new Variable copies are created so any updates inside will not be reflected outside. To propagate updates we can use two NNX helpers. `nnx.state(obj, *filters)` extracts the current state of all Variables in `obj` as a nested `State` dict; you can pass **filters** to select specific Variable types, for example `nnx.state(model, Count)` extracts only `Count` Variables (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for details). `nnx.update(obj, state)` writes a `State` back into the corresponding Variables of `obj`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + } + ], + "source": [ + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "\n", + "@jax.jit\n", + "def forward(model, x):\n", + " y = model(x)\n", + " return y, nnx.state(model, Count) # propagate state\n", + "\n", + "y, updates = forward(model, x)\n", + "nnx.update(model, updates) # apply state updates\n", + "\n", + "print(model.count[...]) # updated successfully" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example we could've also chosen to return the entire `model` and replace its reference outside, however the use `nnx.state/update` is preferred as NNX promotes preserving existing Variable references." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training step with JAX transforms\n", + "\n", + "For a full training step we also need to differentiate with respect to some parameters while keeping the rest non-differentiable. `nnx.split` and `nnx.merge` let us partition and reconstruct the model. `nnx.split(obj, *filters)` returns a structure definition (`GraphDef`) followed by one `State` group per filter, where the catch-all filter `...` matches everything not yet matched by a previous filter (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for the full filter language). `nnx.merge(graphdef, *states)` reconstructs a copy of the object from its definition and state groups. We will use these to select the differentiable parameters when passing them to `jax.grad`.\n", + "\n", + "The example below shows a complete training step using raw JAX transforms. 
`nnx.Optimizer` wraps an [Optax](https://optax.readthedocs.io/) optimizer and stores its internal state as Variables, providing a simple `update(model, grads)` method that performs in-place updates to both the optimizer state and model parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "optimizer.step[...] = Array(1, dtype=uint32)\n" + ] + } + ], + "source": [ + "import optax\n", + "\n", + "x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "@jax.jit\n", + "def train_step(model, optimizer, x, y):\n", + " # use same filter as Optimizer's `wrt`\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + "\n", + " def loss_fn(params, nondiff):\n", + " nondiff = nnx.clone(nondiff) # refresh trace state\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " loss = jnp.mean((model(x) - y) ** 2)\n", + " return loss, nnx.state(model, Count) # propagate state\n", + "\n", + " grads, updates = jax.grad(loss_fn, has_aux=True)(params, nondiff)\n", + " nnx.update(model, updates)\n", + " optimizer.update(model, grads)\n", + "\n", + " return nnx.state((model, optimizer))\n", + "\n", + "updates = train_step(model, optimizer, x, y)\n", + "nnx.update((model, optimizer), updates)\n", + "\n", + "print(f'{model.count[...] = }')\n", + "print(f'{optimizer.step[...] = }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A few things to note. The state of the `model` and `optimizer` is extracted at once by packing them in a tuple (or any pytree), and `nnx.update` accepts the same structure. By default `jax.grad` differentiates with respect to the first positional argument only, `params` in our case. 
Finally, `nnx.clone` is needed because `jax.grad` passes non differentiable inputs (here `nondiff`) directly without tracing them, so we must manually clone them to refresh the trace state of their Variables - preventing tracer leakage. Omitting `nnx.clone` raises an error." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## NNX Transforms\n", + "\n", + "NNX transforms (`nnx.jit`, `nnx.grad`, ...) are thin wrappers over JAX transforms that provide the exact same APIs. Their main feature is **automatic state propagation**: the state of all input Variables is tracked and updated automatically behind the scenes. This removes the need for the `nnx.state/update` boilerplate and the use of `nnx.clone`:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "optimizer.step[...] = Array(1, dtype=uint32)\n" + ] + } + ], + "source": [ + "x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "@nnx.jit # automatic state propagation\n", + "def train_step(model, optimizer, x, y):\n", + " # use same filter as Optimizer's `wrt`\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + "\n", + " def loss_fn(params, nondiff):\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " loss = jnp.mean((model(x) - y) ** 2)\n", + " return loss\n", + "\n", + " grads = nnx.grad(loss_fn)(params, nondiff)\n", + " optimizer.update(model, grads)\n", + "\n", + "train_step(model, optimizer, x, y)\n", + "\n", + "print(f'{model.count[...] = }')\n", + "print(f'{optimizer.step[...] 
= }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that `train_step` doesn't need to return anything as `nnx.jit` propagates all Variable updates (model parameters, optimizer state, counts) automatically." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Graph Mode\n", + "\n", + "Certain programs are easier to express by sharing references between objects on different parts of a structure, however this is not compatible with JAX's pytree model. If we create a simple model that shares a reference to the same Variable in two different attributes, NNX transforms and most other APIs will raise an error as sharing can result in inconsistencies:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error: Variable at [0][0].b was already seen at [0][0].a. tree-mode jit does not support shared Variable references.\n" + ] + } + ], + "source": [ + "@nnx.dataclass\n", + "class Foo(nnx.Module):\n", + " a: nnx.Param\n", + " b: nnx.Param\n", + "\n", + "p = nnx.Param(jnp.array(1.0))\n", + "model = Foo(p, p) # shared Param\n", + "\n", + "@nnx.jit\n", + "def forward(model, x):\n", + " model.a[...] += 1.0\n", + " return model.a * x + model.b\n", + "\n", + "try:\n", + " forward(model, jnp.array(1.0))\n", + "except ValueError as e:\n", + " print(f'Error: {e}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, at the cost of some Python overhead, `graph=True` can be passed to NNX APIs to enable **graph mode**. In graph mode, general graph structures are allowed as long as their Variables are transformed consistently. We can fix the above example by enabling graph mode in `nnx.jit`:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y = 6.0, model.a[...] = 3.0, model.b[...] 
= 3.0\n" + ] + } + ], + "source": [ + "@nnx.jit(graph=True)\n", + "def forward(model, x):\n", + " model.a[...] += 1.0\n", + " return model.a * x + model.b\n", + "\n", + "y = forward(model, jnp.array(1.0))\n", + "\n", + "print(f'{y = !s}, {model.a[...] = !s}, {model.b[...] = !s}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hijax (experimental)\n", + "\n", + "JAX's experimental **Hijax** API allows custom mutable types whose state updates propagate automatically through JAX transforms. When enabled via `nnx.var_default(hijax=True)`, plain JAX transforms like `jax.jit` handle state propagation of `Variable`s without any manual `nnx.state` / `nnx.update` calls. As a bonus, in hijax mode Variables can also be passed as captures, further simplifying the loss function:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Count: 1 (4 B), Param: 10 (40 B), Total: 11 (44 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mdin\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m,\n", + " \u001b[38;2;156;220;254mdout\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m5\u001b[0m,\n", + " \u001b[38;2;156;220;254mw\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 10 (40 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m5\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, 
\u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mcount\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mCount\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "optimizer.step[...] = Array(1, dtype=uint32)\n" + ] + } + ], + "source": [ + "with nnx.var_defaults(hijax=True): # enables Hijax Variables\n", + " x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + " model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + " optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "print(model) # display Hijax Variables\n", + "\n", + "@jax.jit # automatic state propagation\n", + "def train_step(model, optimizer, x, y):\n", + " # use same filter as Optimizer's `wrt`\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + "\n", + " def loss_fn(params):\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " loss = jnp.mean((model(x) - y) ** 2)\n", + " return loss\n", + "\n", + " grads = jax.grad(loss_fn)(nnx.vars_as(params, hijax=False)) # disable hijax for param grads\n", + " optimizer.update(model, grads)\n", + "\n", + "train_step(model, optimizer, x, y)\n", + "\n", + "print(f'{model.count[...] = }')\n", + "print(f'{optimizer.step[...] 
= }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a temporary limitation, `jax.grad` does not yet handle mutable Hijax types. We work around this by converting `params` to regular Variables via `nnx.vars_as(params, hijax=False)` before passing them to `grad`. Hijax can also be enabled on a per-Variable basis by passing `hijax=True` to the constructor:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "v[...] = 1\n", + "v[...] = 2\n" + ] + } + ], + "source": [ + "v = nnx.Variable(jnp.array(1), hijax=True)\n", + "\n", + "@jax.jit\n", + "def inc(v):\n", + " v[...] += 1\n", + "\n", + "print(f'{v[...] = !s}')\n", + "inc(v)\n", + "print(f'{v[...] = !s}')" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs_nnx/nnx_basics_tree.md b/docs_nnx/nnx_basics_tree.md new file mode 100644 index 000000000..2bdc6f577 --- /dev/null +++ b/docs_nnx/nnx_basics_tree.md @@ -0,0 +1,319 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# NNX Basics + +NNX is a Neural Networks library for JAX. NNX provides the tools to structure modeling code as [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) so it can work with transforms, `jax.tree.*` utilities, and all standard JAX APIs. This guide covers the core concepts you need to get started. 
+ +```{code-cell} ipython3 +from flax import nnx +import jax +import jax.numpy as jnp + +nnx.graphlib.set_graph_mode(False) +nnx.graphlib.set_graph_updates(False) +``` + +NNX's main building blocks are: + +- **`nnx.Pytree`**: Base class for pytree-compatible objects. Defines the tree structure of your model. +- **`nnx.Variable`**: Wraps array data and tracks mutable state. Subclasses like `nnx.Param` categorize different kinds of state. +- **State APIs** (`nnx.{state, split, merge, update}`): Extract, partition, reconstruct, and apply state updates. +- **NNX Transforms** (`nnx.{jit, grad, scan, ...}`): Thin wrappers over JAX transforms that automate state propagation. + ++++ + +## Pytrees and Variables + +`nnx.Pytree` and `nnx.Variable` are two orthogonal systems. **Pytrees** define the structure of your model as a JAX-compatible tree. **Variables** wrap array data and enable expressing state updates via in-place mutation. + +`Pytree`s are python objects that define their tree structure dynamically through their attributes; these are split into two categories: **Static attributes** (e.g. `int`, `str`) are embedded in the tree structure definition and are not traced by JAX. **Data attributes** (e.g. `nnx.Variable`, `jax.Array`) are the leaves of the tree and are traced by JAX. For more details see the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html). + +Here's a typical layer definition: + +```{code-cell} ipython3 +class Count(nnx.Variable): pass # custom Variable types + +class Linear(nnx.Pytree): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.din, self.dout = din, dout # static attributes + self.w = nnx.Param(rngs.uniform((din, dout))) # data attribute + self.count = Count(jnp.array(0)) # data attribute + + def __call__(self, x: jax.Array): + self.count[...] 
+= 1 # inplace state updates + return x @ self.w # Variables are Array-like + +model = Linear(2, 5, rngs=nnx.Rngs(0)) + +nnx.display(model) +``` + +> **Note:** Most user code uses `nnx.Module`, which is a subclass of `nnx.Pytree` with additional features such as support for metric reporting. + +As we can see above, Variables are array-like; they support arithmetic operators, indexing, and can be used directly in JAX expressions. You can update their value in-place using `variable[...] = new_value`. Since NNX Pytrees are standard JAX pytrees, you can use `jax.tree.*` functions directly on them: + +```{code-cell} ipython3 +x = jnp.ones((3, 2)) +y = model(x) +print(f'{y.shape = }, {model.count[...] = }') + +# jax.tree.map works directly on NNX Pytrees +doubled_model = jax.tree.map(lambda x: x * 2, model) +print(f'\nmodel.w sum: {model.w.sum():.4f}') +print(f'doubled.w sum: {doubled_model.w.sum():.4f}') + +# jax.tree.leaves_with_path shows the full tree structure +print('\nPytree leaves:') +for path, value in jax.tree.leaves_with_path(model): + print(f'{jax.tree_util.keystr(path)}: {value!r}') +``` + +Here `jax.tree.map` was first used to create a new model with each leaf Array doubled, and then `jax.tree.flatten_with_path` was used to show how JAX sees the tree structure. Notice that because Variables are also JAX pytrees containing a single element (their inner value) we see `value` as part of the leaf path. + ++++ + +## Rngs +`nnx.Rngs` simplify managing [JAX PRNG state](https://jax.readthedocs.io/en/latest/random-numbers.html). It is itself an `nnx.Pytree` that stores a seed `key` and an incrementing `counter` in `Variable`s internally. 
By calling it, `Rngs` can produce new PRNG keys: + +```{code-cell} ipython3 +rngs = nnx.Rngs(0) # seeded with 0 + +key1 = rngs() # get a raw key +key2 = rngs() # different key (counter incremented) +arr = rngs.normal((2, 3)) # draw samples directly + +print(f'{key1 = }') +print(f'{key2 = }') +print(f'{arr = }') +print(rngs) +``` + +As we've seen so far, `Rngs` conveniently exposes every `jax.random.*` distribution as a method (e.g. `rngs.uniform(...)`, `rngs.normal(...)`) without requiring the `key` argument and returning different random values every time they are called, this highly simplifies the user experience. In general `Rngs` can hold multiple keys and counters in structures called `RngStream`s, above we see that the `default` stream is being used. For more information check out the [Randomness guide](https://flax.readthedocs.io/en/latest/guides/randomness.html). + ++++ + +## Nested Modules + +Pytree subclasses compose naturally, you can assign one as an attribute of another to build nested models. The example below builds a simple `MLP` from two `Linear` layers: + +```{code-cell} ipython3 +class MLP(nnx.Pytree): + def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): + self.din, self.dmid, self.dout = din, dmid, dout # static attributes + self.linear1 = Linear(din, dmid, rngs=rngs) # data attribute + self.linear2 = Linear(dmid, dout, rngs=rngs) # data attribute + + def __call__(self, x: jax.Array): + x = nnx.relu(self.linear1(x)) + return self.linear2(x) + +mlp = MLP(2, 16, 5, rngs=nnx.Rngs(0)) +y = mlp(jnp.ones((3, 2))) +print(f'{y.shape = }') + +nnx.display(mlp) +``` + +Because the entire model is a single pytree, all the `jax.tree.*` functions, JAX transforms, and NNX state APIs work on the full nested structure at once. For more info check out the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html). + ++++ + +## JAX Transforms + +NNX models can be passed directly to JAX transforms like `jax.jit`. 
However, JAX transforms create pure functions, meaning that they won't propagate side effects such as Variable state updates back to the caller: + +```{code-cell} ipython3 +model = Linear(2, 5, rngs=nnx.Rngs(0)) + +@jax.jit +def forward(model, x): # pure function + y = model(x) + return y + +y = forward(model, x) + +print(model.count[...]) # no state update +``` + +Here `count` was not updated because inside `jax.jit` new Variable copies are created so any updates inside will not be reflected outside. To propagate updates we can use two NNX helpers. `nnx.state(obj, *filters)` extracts the current state of all Variables in `obj` as a nested `State` dict; you can pass **filters** to select specific Variable types, for example `nnx.state(model, Count)` extracts only `Count` Variables (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for details). `nnx.update(obj, state)` writes a `State` back into the corresponding Variables of `obj`. + +```{code-cell} ipython3 +model = Linear(2, 5, rngs=nnx.Rngs(0)) + +@jax.jit +def forward(model, x): + y = model(x) + return y, nnx.state(model, Count) # propagate state + +y, updates = forward(model, x) +nnx.update(model, updates) # apply state updates + +print(model.count[...]) # updated successfully +``` + +In this example we could've also chosen to return the entire `model` and replace its reference outside, however the use `nnx.state/update` is preferred as NNX promotes preserving existing Variable references. + ++++ + +### Training step with JAX transforms + +For a full training step we also need to differentiate with respect to some parameters while keeping the rest non-differentiable. `nnx.split` and `nnx.merge` let us partition and reconstruct the model. 
`nnx.split(obj, *filters)` returns a structure definition (`GraphDef`) followed by one `State` group per filter, where the catch-all filter `...` matches everything not yet matched by a previous filter (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for the full filter language). `nnx.merge(graphdef, *states)` reconstructs a copy of the object from its definition and state groups. We will use these to select the differentiable parameters when passing them to `jax.grad`. + +The example below shows a complete training step using raw JAX transforms. `nnx.Optimizer` wraps an [Optax](https://optax.readthedocs.io/) optimizer and stores its internal state as Variables, providing a simple `update(model, grads)` method that performs in-place updates to both the optimizer state and model parameters: + +```{code-cell} ipython3 +import optax + +x, y = jnp.ones((3, 2)), jnp.ones((3, 5)) +model = Linear(2, 5, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + +@jax.jit +def train_step(model, optimizer, x, y): + # use same filter as Optimizer's `wrt` + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) + + def loss_fn(params, nondiff): + nondiff = nnx.clone(nondiff) # refresh trace state + model = nnx.merge(graphdef, params, nondiff) + loss = jnp.mean((model(x) - y) ** 2) + return loss, nnx.state(model, Count) # propagate state + + grads, updates = jax.grad(loss_fn, has_aux=True)(params, nondiff) + nnx.update(model, updates) + optimizer.update(model, grads) + + return nnx.state((model, optimizer)) + +updates = train_step(model, optimizer, x, y) +nnx.update((model, optimizer), updates) + +print(f'{model.count[...] = }') +print(f'{optimizer.step[...] = }') +``` + +A few things to note. The state of the `model` and `optimizer` is extracted at once by packing them in a tuple (or any pytree), and `nnx.update` accepts the same structure. 
By default `jax.grad` differentiates with respect to the first positional argument only, `params` in our case. Finally, `nnx.clone` is needed because `jax.grad` passes non differentiable inputs (here `nondiff`) directly without tracing them, so we must manually clone them to refresh the trace state of their Variables - preventing tracer leakage. Omitting `nnx.clone` raises an error. + ++++ + +## NNX Transforms + +NNX transforms (`nnx.jit`, `nnx.grad`, ...) are thin wrappers over JAX transforms that provide the exact same APIs. Their main feature is **automatic state propagation**: the state of all input Variables is tracked and updated automatically behind the scenes. This removes the need for the `nnx.state/update` boilerplate and the use of `nnx.clone`: + +```{code-cell} ipython3 +x, y = jnp.ones((3, 2)), jnp.ones((3, 5)) +model = Linear(2, 5, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + +@nnx.jit # automatic state propagation +def train_step(model, optimizer, x, y): + # use same filter as Optimizer's `wrt` + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) + + def loss_fn(params, nondiff): + model = nnx.merge(graphdef, params, nondiff) + loss = jnp.mean((model(x) - y) ** 2) + return loss + + grads = nnx.grad(loss_fn)(params, nondiff) + optimizer.update(model, grads) + +train_step(model, optimizer, x, y) + +print(f'{model.count[...] = }') +print(f'{optimizer.step[...] = }') +``` + +Notice that `train_step` doesn't need to return anything as `nnx.jit` propagates all Variable updates (model parameters, optimizer state, counts) automatically. + ++++ + +## Graph Mode + +Certain programs are easier to express by sharing references between objects on different parts of a structure, however this is not compatible with JAX's pytree model. 
If we create a simple model that shares a reference to the same Variable in two different attributes, NNX transforms and most other APIs will raise an error as sharing can result in inconsistencies:
+
+```{code-cell} ipython3
+@nnx.dataclass
+class Foo(nnx.Module):
+  a: nnx.Param
+  b: nnx.Param
+
+p = nnx.Param(jnp.array(1.0))
+model = Foo(p, p) # shared Param
+
+@nnx.jit
+def forward(model, x):
+  model.a[...] += 1.0
+  return model.a * x + model.b
+
+try:
+  forward(model, jnp.array(1.0))
+except ValueError as e:
+  print(f'Error: {e}')
+```
+
+However, at the cost of some Python overhead, `graph=True` can be passed to NNX APIs to enable **graph mode**. In graph mode, general graph structures are allowed as long as their Variables are transformed consistently. We can fix the above example by enabling graph mode in `nnx.jit`:
+
+```{code-cell} ipython3
+@nnx.jit(graph=True)
+def forward(model, x):
+  model.a[...] += 1.0
+  return model.a * x + model.b
+
+y = forward(model, jnp.array(1.0))
+
+print(f'{y = !s}, {model.a[...] = !s}, {model.b[...] = !s}')
+```
+
+## Hijax (experimental)
+
+JAX's experimental **Hijax** API allows custom mutable types whose state updates propagate automatically through JAX transforms. When enabled via `nnx.var_defaults(hijax=True)`, plain JAX transforms like `jax.jit` handle state propagation of `Variable`s without any manual `nnx.state` / `nnx.update` calls. As a bonus, in hijax mode Variables can also be passed as captures, further simplifying the loss function:
+
+```{code-cell} ipython3
+with nnx.var_defaults(hijax=True): # enables Hijax Variables
+  x, y = jnp.ones((3, 2)), jnp.ones((3, 5))
+  model = Linear(2, 5, rngs=nnx.Rngs(0))
+  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
+
+print(model) # display Hijax Variables
+
+@jax.jit # automatic state propagation
+def train_step(model, optimizer, x, y):
+  # use same filter as Optimizer's `wrt`
+  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) 
+ + def loss_fn(params): + model = nnx.merge(graphdef, params, nondiff) + loss = jnp.mean((model(x) - y) ** 2) + return loss + + grads = jax.grad(loss_fn)(nnx.vars_as(params, hijax=False)) # disable hijax for param grads + optimizer.update(model, grads) + +train_step(model, optimizer, x, y) + +print(f'{model.count[...] = }') +print(f'{optimizer.step[...] = }') +``` + +As a temporary limitation, `jax.grad` does not yet handle mutable Hijax types. We work around this by converting `params` to regular Variables via `nnx.vars_as(params, hijax=False)` before passing them to `grad`. Hijax can also be enabled on a per-Variable basis by passing `hijax=True` to the constructor: + +```{code-cell} ipython3 +v = nnx.Variable(jnp.array(1), hijax=True) + +@jax.jit +def inc(v): + v[...] += 1 + +print(f'{v[...] = !s}') +inc(v) +print(f'{v[...] = !s}') +``` diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 4004b02c0..37e2247af 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -25,7 +25,7 @@ from flax.nnx.pytreelib import Pytree from flax.nnx.variablelib import Variable -M = tp.TypeVar('M', bound=nnx.Module) +M = tp.TypeVar('M') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) class OptState(Variable): diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index db179a9af..a70936055 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -395,6 +395,53 @@ def constrain_object(m): self.assertIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) + def test_state_sharding_with_variable_types(self): + """Test StateSharding with different Variable types.""" + + class Count(nnx.Variable): + pass + + class Weights(nnx.Module): + def __init__(self, kernel, count): + self.kernel = nnx.Param(kernel) + self.count = Count(count) + + # Use multiple CPU devices for testing + n_devices = min(jax.local_device_count(), 8) + if n_devices < 2: + self.skipTest('Test requires at least 2 
devices') + + rngs = nnx.Rngs(1) + devices = mesh_utils.create_device_mesh((n_devices,)) + mesh = jax.sharding.Mesh(devices, ('devices',)) + + def sharding(*args): + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*args)) + + weights = Weights( + kernel=rngs.uniform((16, 16)), + count=jnp.array(0), + ) + x = jnp.ones((16, 16)) + + # Define sharding for different Variable types + state_sharding = nnx.StateSharding( + { + nnx.Param: sharding(None, 'devices'), # shard Param + Count: sharding(), # replicate Count + } + ) + + @nnx.graph.jit(in_shardings=(state_sharding, sharding('devices'))) + def forward(weights, x): + weights.count[...] += 1 + return x @ weights.kernel + + y = forward(weights, x) + + self.assertEqual(y.shape, (16, 16)) + self.assertEqual(weights.count[...], 1) + def test_cache_args(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))