From d7474269a4a93be273eaa7485240a0a64d934220 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Wed, 25 Mar 2026 15:49:51 -0500 Subject: [PATCH 1/5] Rename NNX view functions to have common naming convention. --- docs_nnx/guides/view.ipynb | 18 +++---- docs_nnx/guides/view.md | 18 +++---- docs_nnx/hijax/hijax.ipynb | 20 ++++---- docs_nnx/hijax/hijax.md | 20 ++++---- docs_nnx/mnist_tutorial.ipynb | 8 ++-- docs_nnx/mnist_tutorial.md | 8 ++-- examples/nnx_toy_examples/hijax_basic.py | 2 +- examples/nnx_toy_examples/hijax_demo.py | 2 +- flax/nnx/__init__.py | 8 ++-- flax/nnx/compat.py | 2 +- flax/nnx/graphlib.py | 16 +++---- flax/nnx/module.py | 2 +- flax/nnx/spmd.py | 2 +- flax/nnx/training/optimizer.py | 8 ++-- tests/nnx/integration_test.py | 10 ++-- tests/nnx/module_test.py | 6 +-- tests/nnx/mutable_array_test.py | 61 ++++++++++++------------ tests/nnx/nn/stochastic_test.py | 6 +-- tests/nnx/spmd_test.py | 6 +-- tests/nnx/transforms_test.py | 6 +-- 20 files changed, 114 insertions(+), 115 deletions(-) diff --git a/docs_nnx/guides/view.ipynb b/docs_nnx/guides/view.ipynb index c864f9e56..4b5378fb7 100644 --- a/docs_nnx/guides/view.ipynb +++ b/docs_nnx/guides/view.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "source": [ "# Model Views\n", - "This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example:" + "This guide covers how to use NNX \"Views\", which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX views allow you to modify static attributes of the model while still sharing the same data. 
For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that overwrites module attributes." ] }, { @@ -25,8 +25,8 @@ ")\n", "\n", "# set train and eval modes\n", - "train_model = nnx.view(model, deterministic=False, use_running_average=False)\n", - "eval_model = nnx.view(model, deterministic=True, use_running_average=True)\n", + "train_model = nnx.with_modules(model, deterministic=False, use_running_average=False)\n", + "eval_model = nnx.with_modules(model, deterministic=True, use_running_average=True)\n", "\n", "# Can see deterministic is different between train_model and eval_model\n", "assert train_model.layers[2].deterministic == False\n", @@ -35,7 +35,7 @@ "# Weights are shared between the models\n", "assert train_model.layers[0].kernel is eval_model.layers[0].kernel\n", "\n", - "# Print information about kwargs for nnx.view with nnx.view_info\n", + "# Print information about kwargs for nnx.with_modules with nnx.view_info\n", "print(nnx.view_info(model))" ] }, @@ -125,8 +125,8 @@ "metadata": {}, "outputs": [], "source": [ - "train_model = nnx.view(model, deterministic=False)\n", - "eval_model = nnx.view(model, deterministic=True)\n", + "train_model = nnx.with_modules(model, deterministic=False)\n", + "eval_model = nnx.with_modules(model, deterministic=True)\n", "\n", "# weights are references to the same data\n", "assert train_model.lin1.kernel is eval_model.lin1.kernel\n", @@ -196,8 +196,8 @@ "source": [ "model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=rngs)\n", "optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)\n", - "train_model = nnx.view(model, deterministic=False) # training view\n", - "eval_model = nnx.view(model, deterministic=True) # eval view\n", + "train_model = nnx.with_modules(model, deterministic=False) # training view\n", + "eval_model = nnx.with_modules(model, deterministic=True) # eval view\n", "\n", "eval_results = []\n", "for epoch in range(total_epochs):\n", @@ -293,7 +293,7 @@ "\n", 
"\n", "model = PrintLayer()\n", - "model_print = nnx.view(model, msg='Hello, World!')\n", + "model_print = nnx.with_modules(model, msg='Hello, World!')\n", "\n", "model() # nothing printed\n", "model_print() # prints \"Hello, World!\"" diff --git a/docs_nnx/guides/view.md b/docs_nnx/guides/view.md index d2fb3f8d4..31ad97ae4 100644 --- a/docs_nnx/guides/view.md +++ b/docs_nnx/guides/view.md @@ -9,7 +9,7 @@ jupytext: --- # Model Views -This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example: +This guide covers how to use NNX "Views", which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that overwrites module attributes. 
```{code-cell} from flax import nnx @@ -21,8 +21,8 @@ model = nnx.Sequential( ) # set train and eval modes -train_model = nnx.view(model, deterministic=False, use_running_average=False) -eval_model = nnx.view(model, deterministic=True, use_running_average=True) +train_model = nnx.with_modules(model, deterministic=False, use_running_average=False) +eval_model = nnx.with_modules(model, deterministic=True, use_running_average=True) # Can see deterministic is different between train_model and eval_model assert train_model.layers[2].deterministic == False @@ -31,7 +31,7 @@ assert eval_model.layers[2].deterministic == True # Weights are shared between the models assert train_model.layers[0].kernel is eval_model.layers[0].kernel -# Print information about kwargs for nnx.view with nnx.view_info +# Print information about kwargs for nnx.with_modules with nnx.view_info print(nnx.view_info(model)) ``` @@ -85,8 +85,8 @@ From the model display, we can see that `Dropout` has `deterministic == False`, This is where `nnx.view` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below. ```{code-cell} -train_model = nnx.view(model, deterministic=False) -eval_model = nnx.view(model, deterministic=True) +train_model = nnx.with_modules(model, deterministic=False) +eval_model = nnx.with_modules(model, deterministic=True) # weights are references to the same data assert train_model.lin1.kernel is eval_model.lin1.kernel @@ -128,8 +128,8 @@ Now we create `train_model` and `eval_model` views up front. 
During the training ```{code-cell} model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=rngs) optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param) -train_model = nnx.view(model, deterministic=False) # training view -eval_model = nnx.view(model, deterministic=True) # eval view +train_model = nnx.with_modules(model, deterministic=False) # training view +eval_model = nnx.with_modules(model, deterministic=True) # eval view eval_results = [] for epoch in range(total_epochs): @@ -201,7 +201,7 @@ class PrintLayer(nnx.Module): model = PrintLayer() -model_print = nnx.view(model, msg='Hello, World!') +model_print = nnx.with_modules(model, msg='Hello, World!') model() # nothing printed model_print() # prints "Hello, World!" diff --git a/docs_nnx/hijax/hijax.ipynb b/docs_nnx/hijax/hijax.ipynb index 4b1c8d40f..bcb7c83ce 100644 --- a/docs_nnx/hijax/hijax.ipynb +++ b/docs_nnx/hijax/hijax.ipynb @@ -49,7 +49,7 @@ "@jax.jit\n", "def train_step(x, y):\n", " loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)\n", - " loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad\n", + " loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad\n", " optimizer.update(model, grads)\n", " return loss\n", "\n", @@ -297,8 +297,8 @@ "\n", "model = Linear(1, 3, rngs=nnx.Rngs(0))\n", "\n", - "print(f\"{nnx.vars_as(model, mutable=False) = !s}\")\n", - "print(f\"{nnx.vars_as(model, mutable=True) = !s}\")" + "print(f\"{nnx.with_vars(model, mutable=False) = !s}\")\n", + "print(f\"{nnx.with_vars(model, mutable=True) = !s}\")" ] }, { @@ -317,7 +317,7 @@ ], "source": [ "v = nnx.Variable(jnp.array(0))\n", - "v_immut = nnx.vars_as(v, mutable=False)\n", + "v_immut = nnx.with_vars(v, mutable=False)\n", "assert not v_immut.mutable\n", "\n", "try:\n", @@ -355,7 +355,7 @@ ], "source": [ "v = nnx.Variable(jnp.array(0))\n", - "v_ref = nnx.vars_as(v, ref=True)\n", + "v_ref = nnx.with_vars(v, ref=True)\n", "assert 
v_ref.ref\n", "print(v_ref)\n", "print(v_ref.get_raw_value())" @@ -386,11 +386,11 @@ } ], "source": [ - "v_immut = nnx.vars_as(v_ref, mutable=False)\n", + "v_immut = nnx.with_vars(v_ref, mutable=False)\n", "assert not v_immut.ref\n", "print(\"immutable =\", v_immut)\n", "\n", - "v_ref = nnx.vars_as(v_immut, mutable=True)\n", + "v_ref = nnx.with_vars(v_immut, mutable=True)\n", "assert v_ref.ref\n", "print(\"mutable =\", v_ref)" ] @@ -458,7 +458,7 @@ " model = nnx.merge(graphdef, params, nondiff)\n", " return ((model(x) - y) ** 2).mean()\n", "\n", - " loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad\n", + " loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad\n", " optimizer.update(model, grads)\n", "\n", " return loss\n", @@ -563,9 +563,9 @@ "source": [ "@jax.jit\n", "def create_model(rngs):\n", - " return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)\n", + " return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False)\n", "\n", - "model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)\n", + "model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True)\n", "\n", "print(\"model.linear =\", model.linear)" ] diff --git a/docs_nnx/hijax/hijax.md b/docs_nnx/hijax/hijax.md index 6d6123022..be1a65796 100644 --- a/docs_nnx/hijax/hijax.md +++ b/docs_nnx/hijax/hijax.md @@ -29,7 +29,7 @@ optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param) @jax.jit def train_step(x, y): loss_fn = lambda m: jnp.mean((m(x) - y) ** 2) - loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad + loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad optimizer.update(model, grads) return loss @@ -112,13 +112,13 @@ class Linear(nnx.Module): model = Linear(1, 3, rngs=nnx.Rngs(0)) -print(f"{nnx.vars_as(model, mutable=False) = !s}") -print(f"{nnx.vars_as(model, mutable=True) = 
!s}") +print(f"{nnx.with_vars(model, mutable=False) = !s}") +print(f"{nnx.with_vars(model, mutable=True) = !s}") ``` ```{code-cell} ipython3 v = nnx.Variable(jnp.array(0)) -v_immut = nnx.vars_as(v, mutable=False) +v_immut = nnx.with_vars(v, mutable=False) assert not v_immut.mutable try: @@ -131,18 +131,18 @@ except Exception as e: ```{code-cell} ipython3 v = nnx.Variable(jnp.array(0)) -v_ref = nnx.vars_as(v, ref=True) +v_ref = nnx.with_vars(v, ref=True) assert v_ref.ref print(v_ref) print(v_ref.get_raw_value()) ``` ```{code-cell} ipython3 -v_immut = nnx.vars_as(v_ref, mutable=False) +v_immut = nnx.with_vars(v_ref, mutable=False) assert not v_immut.ref print("immutable =", v_immut) -v_ref = nnx.vars_as(v_immut, mutable=True) +v_ref = nnx.with_vars(v_immut, mutable=True) assert v_ref.ref print("mutable =", v_ref) ``` @@ -176,7 +176,7 @@ def train_step(model, optimizer, x, y): model = nnx.merge(graphdef, params, nondiff) return ((model(x) - y) ** 2).mean() - loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad + loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad optimizer.update(model, grads) return loss @@ -226,9 +226,9 @@ except Exception as e: ```{code-cell} ipython3 @jax.jit def create_model(rngs): - return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False) + return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False) -model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True) +model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True) print("model.linear =", model.linear) ``` diff --git a/docs_nnx/mnist_tutorial.ipynb b/docs_nnx/mnist_tutorial.ipynb index df3b8373d..557d92df1 100644 --- a/docs_nnx/mnist_tutorial.ipynb +++ b/docs_nnx/mnist_tutorial.ipynb @@ -303,7 +303,7 @@ "\n", "## 6. Train and evaluate the model\n", "\n", - "Now, you can train the CNN model. 
Before the training loop, we use [`nnx.view`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation." + "Now, you can train the CNN model. Before the training loop, we use [`nnx.with_modules`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation." ] }, { @@ -335,8 +335,8 @@ "}\n", "\n", "rngs = nnx.Rngs(0)\n", - "train_model = nnx.view(model, deterministic=False, use_running_average=False)\n", - "eval_model = nnx.view(model, deterministic=True, use_running_average=True)\n", + "train_model = nnx.with_modules(model, deterministic=False, use_running_average=False)\n", + "eval_model = nnx.with_modules(model, deterministic=True, use_running_average=True)\n", "\n", "for step, batch in enumerate(train_ds.as_numpy_iterator()):\n", " # Run the optimization for one step and make a stateful update to the following:\n", @@ -380,7 +380,7 @@ "source": [ "## 7. Perform inference on the test set\n", "\n", - "Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (an `nnx.view` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance." 
+ "Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (using `nnx.with_modules` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance." ] }, { diff --git a/docs_nnx/mnist_tutorial.md b/docs_nnx/mnist_tutorial.md index cb04156ad..f34a28578 100644 --- a/docs_nnx/mnist_tutorial.md +++ b/docs_nnx/mnist_tutorial.md @@ -173,7 +173,7 @@ In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_ref ## 6. Train and evaluate the model -Now, you can train the CNN model. Before the training loop, we use [`nnx.view`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation. +Now, you can train the CNN model. Before the training loop, we use [`nnx.with_modules`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation. 
```{code-cell} ipython3 from IPython.display import clear_output @@ -187,8 +187,8 @@ metrics_history = { } rngs = nnx.Rngs(0) -train_model = nnx.view(model, deterministic=False, use_running_average=False) -eval_model = nnx.view(model, deterministic=True, use_running_average=True) +train_model = nnx.with_modules(model, deterministic=False, use_running_average=False) +eval_model = nnx.with_modules(model, deterministic=True, use_running_average=True) for step, batch in enumerate(train_ds.as_numpy_iterator()): # Run the optimization for one step and make a stateful update to the following: @@ -227,7 +227,7 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()): ## 7. Perform inference on the test set -Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (an `nnx.view` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance. +Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (using `nnx.with_modules` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance. 
```{code-cell} ipython3 @nnx.jit diff --git a/examples/nnx_toy_examples/hijax_basic.py b/examples/nnx_toy_examples/hijax_basic.py index 9ffe2b67d..fde8a76d2 100644 --- a/examples/nnx_toy_examples/hijax_basic.py +++ b/examples/nnx_toy_examples/hijax_basic.py @@ -68,7 +68,7 @@ def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((y - model(x)) ** 2) - grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False)) + grads = jax.grad(loss_fn)(nnx.with_vars(params, is_mutable=False)) optimizer.update(model, grads) @jax.jit diff --git a/examples/nnx_toy_examples/hijax_demo.py b/examples/nnx_toy_examples/hijax_demo.py index 5380fc592..78863483a 100644 --- a/examples/nnx_toy_examples/hijax_demo.py +++ b/examples/nnx_toy_examples/hijax_demo.py @@ -238,7 +238,7 @@ def loss_fn(params): # For the time being we have to use 'immutable' # as 'jax.grad' doesn't support QDD types yet. - grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False)) + grads = jax.grad(loss_fn)(nnx.with_vars(params, is_mutable=False)) # 'update' mutates the optimizer's state and the params in place # so we don't need to return anything 🚀 optimizer.update(params, grads) diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index ac0fd2391..134931825 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -50,7 +50,7 @@ from .module import M as M from .module import Module as Module from .module import capture as capture -from .module import view as view +from .module import with_modules as with_modules from .module import view_info as view_info from .module import with_attributes as with_attributes from .module import iter_children as iter_children, iter_modules as iter_modules @@ -75,8 +75,8 @@ from .graphlib import MergeContext as MergeContext from .graphlib import merge_context as merge_context from .graphlib import variables as variables -from .graphlib import vars_as as vars_as -from .graphlib import pure as pure +from .graphlib import with_vars as 
with_vars +from .graphlib import as_pure as as_pure from .graphlib import cached_partial as cached_partial from .graphlib import flatten as flatten from .graphlib import unflatten as unflatten @@ -152,7 +152,7 @@ from .spmd import get_named_sharding as get_named_sharding from .spmd import with_partitioning as with_partitioning from .spmd import get_abstract_model as get_abstract_model -from .spmd import abstract_with_sharding as abstract_with_sharding +from .spmd import as_abstract as as_abstract from .statelib import FlatState as FlatState from .statelib import State as State from .statelib import to_flat_state as to_flat_state diff --git a/flax/nnx/compat.py b/flax/nnx/compat.py index 78a8a3432..035357e18 100644 --- a/flax/nnx/compat.py +++ b/flax/nnx/compat.py @@ -39,7 +39,7 @@ recursive_map = functools.partial(_graphlib.recursive_map, graph=True) # module -view = functools.partial(_module.view, graph=True) +view = functools.partial(_module.with_modules, graph=True) view_info = functools.partial(_module.view_info, graph=True) iter_modules = functools.partial(_module.iter_modules, graph=True) iter_children = functools.partial(_module.iter_children, graph=True) # type: ignore[has-type] diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index ed1095dbb..c368d8c74 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -33,6 +33,7 @@ ) from flax.nnx.statelib import FlatState, State, map_state from flax.nnx.variablelib import Variable, is_array_ref, V +import flax.nnx.graphlib as graphlib from flax.typing import BaseConfigContext, HashableMapping, Key, PathParts, is_key_like import jax import numpy as np @@ -2705,7 +2706,7 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N return merge(graphdef, state, copy=variables) -def vars_as( +def with_vars( node: A, /, *, @@ -2714,6 +2715,7 @@ def vars_as( mutable: bool | None = None, only: filterlib.Filter = ..., allow_duplicates: bool = False, + graph: bool | None = False, ) -> 
A: """ """ new_attrs: dict[str, bool] = {} @@ -2743,20 +2745,18 @@ def _different_vars(path, x): duplicates_strs += '\n ---' raise ValueError(f'Found duplicate at paths:{duplicates_strs}') - def _to_refs(jax_path, x): - if predicate(jax_to_nnx_path(jax_path), x): + def _to_refs(path, x): + if predicate(path, x): assert isinstance(x, Variable) variable = x.copy(**new_attrs) return variable return x - node = jax.tree.map_with_path( - _to_refs, node, is_leaf=lambda x: isinstance(x, Variable) - ) + node = graphlib.map(_to_refs, node, graph=graph) return node -def pure(tree: A) -> A: +def as_pure(tree: A) -> A: """Returns a new tree with all ``Variable`` objects replaced with inner values. This can be used to remove Variable metadata when its is not needed for tasks like @@ -2795,7 +2795,7 @@ def pure(tree: A) -> A: def _pure_fn(x): if isinstance(x, Variable): - return pure(x.get_raw_value()) + return as_pure(x.get_raw_value()) elif variablelib.is_array_ref(x): return x[...] return x diff --git a/flax/nnx/module.py b/flax/nnx/module.py index a0eeb857c..5dc30133b 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -437,7 +437,7 @@ def eval(self, **attributes): raise_if_not_found=False, ) -def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool = True, graph: bool | None = None, **kwargs) -> A: +def with_modules(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool = True, graph: bool | None = None, **kwargs) -> A: """Creates a new node with static attributes updated according to ``**kwargs``. The new node contains references to jax arrays in the original node. 
If a diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index c7914a4f6..56c07b15d 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -182,7 +182,7 @@ def get_abstract_model(init_fn, mesh, *, graph: bool | None = None): return gdef, abs_state -def abstract_with_sharding( +def as_abstract( tree: A, graph: bool | None = None ) -> A: """Add sharding information to abstract Variables. diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 4004b02c0..7e25bff64 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -200,10 +200,10 @@ def update(self, model: M, grads, /, **kwargs): **kwargs: additional keyword arguments passed to the tx.update, to support ``GradientTransformationExtraArgs``, such as ``optax.scale_by_backtracking_linesearch``. """ - param_arrays = nnx.pure(nnx.state(model, self.wrt)) - grad_arrays = nnx.pure(nnx.state(grads, self.wrt)) - opt_state_arrays = nnx.pure(self.opt_state) - kwargs_arrays = nnx.pure(kwargs) + param_arrays = nnx.as_pure(nnx.state(model, self.wrt)) + grad_arrays = nnx.as_pure(nnx.state(grads, self.wrt)) + opt_state_arrays = nnx.as_pure(self.opt_state) + kwargs_arrays = nnx.as_pure(kwargs) updates, new_opt_state = self.tx.update( grad_arrays, opt_state_arrays, param_arrays, **kwargs_arrays diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py index 6f16842cb..ec1a384c3 100644 --- a/tests/nnx/integration_test.py +++ b/tests/nnx/integration_test.py @@ -45,10 +45,10 @@ def __call__(self, x): return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization - train_model = nnx.view( + train_model = nnx.with_modules( model, deterministic=False, use_running_average=False ) - eval_model = nnx.view( + eval_model = nnx.with_modules( model, deterministic=True, use_running_average=True ) optimizer = nnx.Optimizer(train_model, optax.adam(1e-3), wrt=nnx.Param) @@ -174,7 +174,7 @@ def loss_fn(model: Model): x = np.random.uniform(size=(4, 2)) y 
= np.random.uniform(size=(4, 2)) - new_model = nnx.view(model, use_running_average=False) + new_model = nnx.with_modules(model, use_running_average=False) for _i in range(3): train_step(model, x, y) @@ -267,7 +267,7 @@ def __call__(self, x): @jax.jit def train_step(state: nnx.State, graphdef: nnx.GraphDef[Model], x, y): model = nnx.merge(graphdef, state) - new_model = nnx.view(model, use_running_average=False, graph=True) + new_model = nnx.with_modules(model, use_running_average=False, graph=True) @nnx.grad def loss_fn(model: Model): @@ -487,7 +487,7 @@ def loss_fn(params): return ((model(x) - y) ** 2).mean() # call methods directly loss, grads = jax.value_and_grad(loss_fn)( - nnx.vars_as(params, hijax=False) + nnx.with_vars(params, hijax=False) ) optimizer.update(model, grads) # in-place updates diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 2706fa53a..161c2df79 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -854,13 +854,13 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): assert block.dropout.deterministic == False assert block.batch_norm.use_running_average == False - new_block = nnx.view(block, deterministic=True, use_running_average=True, graph=graph) + new_block = nnx.with_modules(block, deterministic=True, use_running_average=True, graph=graph) assert new_block.dropout.deterministic == True assert new_block.batch_norm.use_running_average == True assert new_block.linear.kernel is block.linear.kernel block = Block(2, 5, rngs=nnx.Rngs(0)) - new_block = nnx.view(block, only=nnx.Dropout, deterministic=True, graph=graph) + new_block = nnx.with_modules(block, only=nnx.Dropout, deterministic=True, graph=graph) assert new_block.dropout.deterministic == True assert new_block.batch_norm.use_running_average == False @@ -951,7 +951,7 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): "Unused keys found in nnx.view: \\['unknown'\\]" ), ): - nnx.view(block, deterministic=True, use_running_average=True, unknown=True, 
graph=graph) + nnx.with_modules(block, deterministic=True, use_running_average=True, unknown=True, graph=graph) def test_cloud_pickle(self): import platform diff --git a/tests/nnx/mutable_array_test.py b/tests/nnx/mutable_array_test.py index 570d6910b..6597d1f20 100644 --- a/tests/nnx/mutable_array_test.py +++ b/tests/nnx/mutable_array_test.py @@ -29,7 +29,7 @@ def __init__(self): self.node = jnp.array(1) self.meta = 1 - m = nnx.vars_as(Foo(), ref=True) + m = nnx.with_vars(Foo(), ref=True) m = jax.tree.map(lambda x: x + 1, m) @@ -126,17 +126,17 @@ def test_split_mutable_array(self): def test_to_arrays_example(self): node = [nnx.Variable(1.0), nnx.Variable(2.0, mode='ref')] - mutable_node = nnx.vars_as(node, ref=True) + mutable_node = nnx.with_vars(node, ref=True) assert isinstance(mutable_node[0].get_raw_value(), jax.Ref) assert isinstance(mutable_node[1].get_raw_value(), jax.Ref) shared_array = nnx.Variable(1.0, mode='pytree') node = [shared_array, shared_array] with self.assertRaisesRegex(ValueError, 'Found duplicate at path'): - nnx.vars_as(node, ref=True) + nnx.with_vars(node, ref=True) node = [nnx.Variable(1.0), nnx.Variable(2.0)] - mutable_node = nnx.vars_as( + mutable_node = nnx.with_vars( node, ref=True, only=lambda path, x: path[0] == 0 ) assert isinstance(mutable_node[0].get_raw_value(), jax.Ref) @@ -148,18 +148,18 @@ def __init__(self): self.a = nnx.Param(1) self.b = nnx.BatchStat(2) - m = nnx.vars_as(Foo(), hijax=True, ref=True) + m = nnx.with_vars(Foo(), hijax=True, ref=True) self.assertEqual(m.a.ref, True) self.assertEqual(m.b.ref, True) - m2 = nnx.vars_as(m, hijax=False, only=nnx.BatchStat) + m2 = nnx.with_vars(m, hijax=False, only=nnx.BatchStat) self.assertEqual(m2.a.ref, True) self.assertEqual(m2.a.hijax, True) self.assertEqual(m2.b.ref, True) self.assertEqual(m2.b.hijax, False) self.assertIsNot(m, m2) - m3 = nnx.vars_as(m2, hijax=True, only=nnx.BatchStat) + m3 = nnx.with_vars(m2, hijax=True, only=nnx.BatchStat) self.assertEqual(m3.a.ref, True) 
self.assertEqual(m3.b.ref, True) self.assertEqual(m3.b.hijax, True) @@ -175,7 +175,7 @@ def __init__(self): m = Foo() with self.assertRaisesRegex(ValueError, 'Found duplicate at path'): - nnx.vars_as(m, ref=True) + nnx.with_vars(m, ref=True) def test_mutable_array_split(self): class Foo(nnx.Module): @@ -233,7 +233,7 @@ def test_mutable_example(self): tree = [nnx.Variable(1.0), nnx.Variable(2.0, ref=True)] assert tree[0].ref == False assert tree[1].ref == True - mutable_tree = nnx.vars_as(tree, ref=True) + mutable_tree = nnx.with_vars(tree, ref=True) assert isinstance(mutable_tree[0].get_raw_value(), jax.Ref) assert isinstance(mutable_tree[1].get_raw_value(), jax.Ref) @@ -247,7 +247,7 @@ def __init__(self): ref_map = nnx.graphlib.RefMap() graphdef, state = nnx.graphlib.flatten(m, ref_index=ref_map, graph=True) - state = nnx.vars_as(state, hijax=False) + state = nnx.with_vars(state, hijax=False) self.assertLen(state, 1) m1 = nnx.merge(graphdef, state) @@ -255,7 +255,7 @@ def __init__(self): self.assertIsInstance(m1.a, jax.Ref) def test_update_context(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.split(m1) @@ -263,7 +263,7 @@ def test_update_context(self): with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.split((m2, m_out1, m2)) @@ -290,7 +290,7 @@ def test_update_context(self): self.assertIsNot(m_out2, m_out1) def test_update_context_flatten(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): with 
nnx.split_context('example') as ctx: graphdef, state = ctx.flatten(m1) @@ -298,7 +298,7 @@ def test_update_context_flatten(self): with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.flatten((m2, m_out1, m2)) @@ -327,13 +327,13 @@ def test_update_context_flatten(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree1(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) @@ -366,13 +366,13 @@ def test_update_context_to_tree1(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree2(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example') as ctx: m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) @@ -405,13 +405,13 @@ def test_update_context_to_tree2(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree_trivial_prefix(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = 
nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example', prefix=0) (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True, prefix=0) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) @@ -444,13 +444,13 @@ def test_update_context_to_tree_trivial_prefix(self): self.assertIsNot(m_out2, m_out1) def test_simple_jit(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) m_out1 = None @nnx.jit def f(m2): nonlocal m_out1 - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) return m_out1 m_out2 = f(m1) @@ -516,7 +516,7 @@ def __init__(self, din: int, dout: int): self.count = nnx.Variable(jnp.array(0)) params = Params(3, 4) - params = nnx.vars_as(params, ref=True) + params = nnx.with_vars(params, ref=True) paths_leaves, treedef = jax.tree.flatten_with_path(params) paths, leaves = zip(*paths_leaves) @@ -682,7 +682,7 @@ def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((model(x) - y) ** 2) - loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, hijax=False)) + loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, hijax=False)) optimizer.update(params, grads) return loss @@ -693,7 +693,7 @@ def loss_fn(params): class TestHijaxVariables(parameterized.TestCase): def test_variable_to_hijax(self): v_low = nnx.Param(jnp.array(1), a='hi') - v_hi = nnx.vars_as(v_low, hijax=True) + v_hi = nnx.with_vars(v_low, hijax=True) self.assertTrue(v_hi.hijax) self.assertEqual(v_hi[...], 1) @@ -715,7 +715,7 @@ def set(v_hi, a): self.assertEqual(v_hi[...], 15) self.assertEqual(y, 17) - v_low = 
nnx.vars_as(v_hi, hijax=False) + v_low = nnx.with_vars(v_hi, hijax=False) self.assertIsInstance(v_low, nnx.Param) self.assertFalse(v_low.hijax) self.assertEqual(v_low[...], 15) @@ -741,7 +741,7 @@ def test_variable_to_hijax_clean(self): print() print(v_low) assert not v_low.hijax - v_hi = nnx.vars_as(v_low, hijax=True) + v_hi = nnx.with_vars(v_low, hijax=True) v_hi[...] = jnp.array([2]) assert v_hi.hijax print(v_hi) @@ -757,7 +757,7 @@ def set(v_hi, a): assert v_hi[...] == 10 - v_low = nnx.vars_as(v_hi, hijax=False) + v_low = nnx.with_vars(v_hi, hijax=False) assert not v_low.hijax assert v_low[...] == 10 @@ -801,7 +801,7 @@ def __init__(self, din, dout, rngs: nnx.Rngs): assert not foo.w.hijax assert not foo.b.hijax - foo = nnx.vars_as(foo, hijax=True) + foo = nnx.with_vars(foo, hijax=True) assert foo.w.hijax assert foo.b.hijax @@ -949,7 +949,7 @@ def test_variable_copy_properties(self, hijax, ref): def test_variable_vars_as_properties(self, hijax, ref): v_original = nnx.Variable(jnp.array(1)) - v = nnx.vars_as(v_original, hijax=hijax, ref=ref) + v = nnx.with_vars(v_original, hijax=hijax, ref=ref) self.assertEqual(v.hijax, hijax) self.assertEqual(v.ref, ref) if hijax: @@ -1213,4 +1213,3 @@ def f(v, v2): if __name__ == '__main__': absltest.main() - diff --git a/tests/nnx/nn/stochastic_test.py b/tests/nnx/nn/stochastic_test.py index 296776de0..9ab81fd39 100644 --- a/tests/nnx/nn/stochastic_test.py +++ b/tests/nnx/nn/stochastic_test.py @@ -95,7 +95,7 @@ def test_dropout_arg_override_view(self): # deterministic call arg provided m(x, deterministic=True) # deterministic constructor arg provided - new_m = nnx.view(m, deterministic=True) + new_m = nnx.with_modules(m, deterministic=True) y = new_m(x) # both deterministic call and constructor arg provided with pytest.raises(AssertionError): @@ -103,9 +103,9 @@ def test_dropout_arg_override_view(self): y, new_m(x, deterministic=False, rngs=nnx.Rngs(dropout=0)) ) # no rng arg provided - new_m = nnx.view(m, deterministic=False) 
+ new_m = nnx.with_modules(m, deterministic=False) with pytest.raises( ValueError, match='`deterministic` is False, but no `rngs` argument was provided to Dropout', ): - new_m(x) \ No newline at end of file + new_m(x) diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 9b81e680e..23e727dc6 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -496,7 +496,7 @@ def test_get_abstract_with_abstract_mesh(self): kernel_metadata={'out_sharding': ('a', 'b')}, ) ) - abs_model = nnx.abstract_with_sharding(abs_model) + abs_model = nnx.as_abstract(abs_model) self.assertIsInstance(abs_model.kernel, nnx.Param) self.assertEqual(abs_model.kernel.sharding.spec, P('a', 'b')) @@ -533,7 +533,7 @@ def __init__(self): ) abs_model = nnx.eval_shape(lambda: Model()) - abs_model = nnx.abstract_with_sharding(abs_model) + abs_model = nnx.as_abstract(abs_model) self.assertEqual(abs_model.p1.kernel.sharding.spec, P('a', 'b')) self.assertEqual(abs_model.p1.kernel.sharding.mesh, mesh1) @@ -542,7 +542,7 @@ def __init__(self): def test_get_abstract_no_sharding_metadata(self): abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0))) - abs_model = nnx.abstract_with_sharding(abs_model) + abs_model = nnx.as_abstract(abs_model) self.assertIsInstance(abs_model.kernel, nnx.Param) self.assertIsNone( diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 1e9bf06e8..4acb7a82d 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -3054,7 +3054,7 @@ def __call__(self, x: jax.Array): return x, None module = MLP(rngs=nnx.Rngs(0)) - new_module = nnx.view(module, deterministic=False, use_running_average=False) + new_module = nnx.with_modules(module, deterministic=False, use_running_average=False) assert new_module.linear.kernel.shape == (5, 3, 3) assert new_module.linear.bias.shape == (5, 3) @@ -3120,7 +3120,7 @@ def __call__(self, x: jax.Array): return x, None module = MLP(rngs=nnx.Rngs(params=0, dropout=1)) - new_module 
= nnx.view(module, deterministic=False, use_running_average=False) + new_module = nnx.with_modules(module, deterministic=False, use_running_average=False) assert new_module.linear.kernel.shape == (5, 3, 3) assert new_module.linear.bias.shape == (5, 3) @@ -3188,7 +3188,7 @@ def __call__(self, x: jax.Array): return x, None module = Block(rngs=nnx.Rngs(0)) - new_module = nnx.view(module, deterministic=False, use_running_average=False) + new_module = nnx.with_modules(module, deterministic=False, use_running_average=False) assert new_module.d == 3 assert new_module.linear.kernel.shape == (5, 3, 3) From 96e05c5ef4e9e760c59b248577c2bfe1b97a67d2 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Wed, 25 Mar 2026 15:46:11 -0500 Subject: [PATCH 2/5] Add flag for variables in nnx.merge --- flax/nnx/graphlib.py | 12 +++++++++--- flax/nnx/pytreelib.py | 2 +- tests/nnx/graph_utils_test.py | 8 ++++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index c368d8c74..655e235de 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -1089,6 +1089,7 @@ def unflatten( # type: ignore[invalid-annotation] index_ref: IndexMap | None = None, outer_index_outer_ref: IndexMap | None = None, copy_variables: bool = False, + auto_create_variables: bool = True, ) -> Node: """Unflattens a graphdef into a node with the given state. @@ -1150,6 +1151,7 @@ def unflatten( # type: ignore[invalid-annotation] index_ref, outer_index_outer_ref, copy_variables, + auto_create_variables ) try: @@ -1171,6 +1173,7 @@ def _graph_unflatten( index_ref: IndexMap, outer_index_outer_ref: IndexMap | None, copy_variables: bool, + auto_create_variables: bool ) -> Node: """Recursive helper for graph_unflatten. 
@@ -1265,7 +1268,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): variable.set_raw_value(value) else: # variabledef.index not in index_ref_cache # variable reference does not exist outside, create a new one - if isinstance(value, Variable): + if isinstance(value, Variable) or not auto_create_variables: variable = value else: variable = variabledef.type.from_metadata( @@ -1314,6 +1317,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]: index_ref, outer_index_outer_ref, copy_variables, + auto_create_variables ) else: raise RuntimeError(f'Unknown node definition: {node_def!r}') @@ -2359,6 +2363,7 @@ def merge( # type: ignore[invalid-annotation] /, *states: tp.Any, copy: bool = False, + auto_create_variables: bool = True, ) -> A: """The inverse of :func:`flax.nnx.split`. @@ -2410,7 +2415,7 @@ def merge( # type: ignore[invalid-annotation] _state = state else: _state = _merge_to_flat_state((state, *states)) - node = unflatten(graphdef, _state, copy_variables=copy) + node = unflatten(graphdef, _state, copy_variables=copy, auto_create_variables=auto_create_variables) return node @@ -2534,6 +2539,7 @@ def map( /, *, graph: bool | None = None, + auto_create_variables: bool = True, ) -> A: """Map a function over the state of a graph node. 
@@ -2567,7 +2573,7 @@ def map( """ graphdef, state = split(node, graph=graph) state = statelib.map_state(f, state) - return merge(graphdef, state) + return merge(graphdef, state, auto_create_variables=auto_create_variables) def graphdef( diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index c53a71314..593205779 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -1061,4 +1061,4 @@ def _maybe_int(x): return x def _get_str(x): - return x if isinstance(x, str) else str(x) \ No newline at end of file + return x if isinstance(x, str) else str(x) diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 97bbf4b81..fe1429695 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -1699,6 +1699,14 @@ def test_map_replace(self): np.testing.assert_array_equal(new_model.kernel[...], jnp.zeros((2, 3))) np.testing.assert_array_equal(new_model.bias[...], jnp.zeros((3,))) + def test_map_auto_create_variables_false(self): + rngs = nnx.Rngs(0) + new_rngs = nnx.map( + lambda path, x: 0, rngs, auto_create_variables=False + ) + self.assertNotIsInstance(new_rngs.default.count, nnx.Variable) + self.assertEqual(new_rngs.default.count, 0) + if __name__ == '__main__': absltest.main() From fbf5dbe1750e9b530f94d13db51342e83e556e4c Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Fri, 27 Mar 2026 14:20:02 -0500 Subject: [PATCH 3/5] Make all NNX view functions use recursive_map or nnx.map (to handle graphs) --- docs_nnx/guides/view.ipynb | 110 ++++++++++++++++++++++++++++++---- docs_nnx/guides/view.md | 82 ++++++++++++++++++++++--- flax/nnx/graphlib.py | 10 +--- flax/nnx/rnglib.py | 89 +++++++++++++++++++++++++-- tests/nnx/graph_utils_test.py | 32 ++++++++++ tests/nnx/rngs_test.py | 86 ++++++++++++++++++++++++++ 6 files changed, 380 insertions(+), 29 deletions(-) diff --git a/docs_nnx/guides/view.ipynb b/docs_nnx/guides/view.ipynb index 4b5378fb7..686ed4d43 100644 --- a/docs_nnx/guides/view.ipynb +++ 
b/docs_nnx/guides/view.ipynb @@ -6,7 +6,9 @@ "metadata": {}, "source": [ "# Model Views\n", - "This guide covers how to use NNX \"Views\", which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that overwrites module attributes." + "This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that sets module modes.\n", + "\n", + "NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original." ] }, { @@ -48,7 +50,7 @@ "\n", "Some layers in ML inherently involve state. Consider for example the `nnx.Dropout` layer, which behaves differently during training and evaluation. In these different scenarios, we need a simple way to ensure that the model behaves as intended to avoid silent bugs. A common pattern in other frameworks is to mutate a single `model` object to switch between training and evaluation modes. 
This requires the programmer to remember to toggle modes in many places throughout the code, which can hurt readability and lead to subtle bugs when a mode switch is forgotten.\n", "\n", - "`nnx.view` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below." + "`nnx.with_modules` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below." ] }, { @@ -115,7 +117,7 @@ "source": [ "From the model display, we can see that `Dropout` has `deterministic == False`, suggesting that the model is in training mode. In order to know this, we had to display the model and/or know that `Dropout` is set to training mode by default. It is not clear what state the model is in just by looking at the code without additional inspection. We instead want to be very explicit about what state the model is in. \n", "\n", - "This is where `nnx.view` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below." + "This is where `nnx.with_modules` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below." 
] }, { @@ -141,7 +143,7 @@ "id": "5c1ee1db", "metadata": {}, "source": [ - "## Example with `nnx.view`" + "## Example with `nnx.with_modules`" ] }, { @@ -216,7 +218,7 @@ "metadata": {}, "source": [ "## Getting information with `nnx.view_info`\n", - "To see more information about the options for `nnx.view`, we can use the `nnx.view_info` function to display information about the arguments. This will display each submodule which contains a `set_view` method. It also provides information about the keyword arguments accepted by each submodule, including type information, default values, and docstring descriptions." + "To see more information about the options for `nnx.with_modules`, we can use the `nnx.view_info` function to display information about the arguments. This will display each submodule which contains a `set_view` method. It also provides information about the keyword arguments accepted by each submodule, including type information, default values, and docstring descriptions." ] }, { @@ -234,9 +236,9 @@ "id": "47479be6", "metadata": {}, "source": [ - "## Writing modules compatible with `nnx.view`\n", + "## Writing modules compatible with `nnx.with_modules`\n", "\n", - "You can make any custom module work with `nnx.view` by defining a `set_view` method. When `nnx.view` is called, it traverses the module tree and calls `set_view` on every submodule that defines one. `nnx.view` inspects the signature of each `set_view` method and only passes the keyword arguments that match the method's declared parameters. This means each module only receives the kwargs it cares about.\n", + "You can make any custom module work with `nnx.view` by defining a `set_view` method. When `nnx.with_modules` is called, it traverses the module tree and calls `set_view` on every submodule that defines one. `nnx.with_modules` inspects the signature of each `set_view` method and only passes the keyword arguments that match the method's declared parameters. 
This means each module only receives the kwargs it cares about.\n", "\n", "Your `set_view` method should follow these conventions:\n", "\n", @@ -320,14 +322,76 @@ }, { "cell_type": "markdown", - "id": "1acbcc09", + "id": "984b8eca", "metadata": {}, "source": [ "The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.\n", "\n", + "## Using `with_vars`\n", + "\n", + "{func}`nnx.with_vars ` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `with_modules` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX.\n", + "\n", + "The flags it controls are:\n", + "\n", + "- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state.\n", + "- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step.\n", + "- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform.\n", + "\n", + "The `only` argument accepts a {doc}`Filter ` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0938e1f6", + "metadata": {}, + "outputs": [], + "source": [ + "from flax import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "class SimpleModel(nnx.Module):\n", + " def __init__(self, rngs):\n", + " self.linear = nnx.Linear(2, 3, rngs=rngs)\n", + "\n", + "model = SimpleModel(nnx.Rngs(0))\n", + "\n", + "# ref=True: expose Variable values as JAX refs so jax.tree.map can update them\n", + "ref_model = nnx.with_vars(model, ref=True)\n", + "ref_model = jax.tree.map(lambda x: x * 2, ref_model)\n", + "\n", + "# The original model's kernel is unchanged; ref_model has doubled values\n", + "assert model.linear.kernel is not ref_model.linear.kernel" + ] + }, + { + "cell_type": "markdown", + "id": "67c71d62", + "metadata": {}, + "source": [ + "Use the `only` filter to convert only a subset of Variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "254fe344", + "metadata": {}, + "outputs": [], + "source": [ + "# only convert Param variables, leave BatchStat variables unchanged\n", + "ref_params = nnx.with_vars(model, ref=True, only=nnx.Param)" + ] + }, + { + "cell_type": "markdown", + "id": "1acbcc09", + "metadata": {}, + "source": [ "## Using `with_attributes`\n", "\n", - "If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes ` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged." + "If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes ` to create views by directly replacing their attributes. Like `nnx.with_modules`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged." 
] }, { @@ -412,7 +476,33 @@ "id": "bf521e45", "metadata": {}, "source": [ - "Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`." + "Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`.\n", + "\n", + "## Other NNX views\n", + "\n", + "Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees:\n", + "\n", + "- {func}`nnx.as_pure ` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed.\n", + "\n", + " ```python\n", + " _, state = nnx.split(model)\n", + " pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain\n", + " ```\n", + "\n", + "- {func}`nnx.as_abstract ` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes.\n", + "\n", + " ```python\n", + " with jax.set_mesh(mesh):\n", + " abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))\n", + " abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars\n", + " ```\n", + "\n", + "- {func}`nnx.with_rngs ` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. 
Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys.\n", + "\n", + " ```python\n", + " # Split params stream into 4 keys (one per vmap replica); fork the rest\n", + " vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)\n", + " ```" ] } ], diff --git a/docs_nnx/guides/view.md index 31ad97ae4..8518d422b 100644 --- a/docs_nnx/guides/view.md +++ b/docs_nnx/guides/view.md @@ -9,7 +9,9 @@ jupytext: --- # Model Views -This guide covers how to use NNX "Views", which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that overwrites module attributes. +This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that sets module modes. + +NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original. ```{code-cell} from flax import nnx @@ -39,7 +41,7 @@ print(nnx.view_info(model)) Some layers in ML inherently involve state. Consider for example the `nnx.Dropout` layer, which behaves differently during training and evaluation. 
In these different scenarios, we need a simple way to ensure that the model behaves as intended to avoid silent bugs. A common pattern in other frameworks is to mutate a single `model` object to switch between training and evaluation modes. This requires the programmer to remember to toggle modes in many places throughout the code, which can hurt readability and lead to subtle bugs when a mode switch is forgotten. -`nnx.view` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below. +`nnx.with_modules` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below. ```{code-cell} import jax @@ -82,7 +84,7 @@ assert model.do.deterministic == False From the model display, we can see that `Dropout` has `deterministic == False`, suggesting that the model is in training mode. In order to know this, we had to display the model and/or know that `Dropout` is set to training mode by default. It is not clear what state the model is in just by looking at the code without additional inspection. We instead want to be very explicit about what state the model is in. -This is where `nnx.view` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below. +This is where `nnx.with_modules` comes in. 
This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below. +This is where `nnx.with_modules` comes in. 
When `nnx.with_modules` is called, it traverses the module tree and calls `set_view` on every submodule that defines one. `nnx.with_modules` inspects the signature of each `set_view` method and only passes the keyword arguments that match the method's declared parameters. This means each module only receives the kwargs it cares about. Your `set_view` method should follow these conventions: @@ -216,9 +218,47 @@ print(nnx.view_info(model)) The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree. +## Using `with_vars` + +{func}`nnx.with_vars ` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `with_modules` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX. + +The flags it controls are: + +- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state. +- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step. +- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform. + +The `only` argument accepts a {doc}`Filter ` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original). 
+ +```{code-cell} +from flax import nnx +import jax +import jax.numpy as jnp + +class SimpleModel(nnx.Module): + def __init__(self, rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + +model = SimpleModel(nnx.Rngs(0)) + +# ref=True: expose Variable values as JAX refs so jax.tree.map can update them +ref_model = nnx.with_vars(model, ref=True) +ref_model = jax.tree.map(lambda x: x * 2, ref_model) + +# The original model's kernel is unchanged; ref_model has doubled values +assert model.linear.kernel is not ref_model.linear.kernel +``` + +Use the `only` filter to convert only a subset of Variables: + +```{code-cell} +# only convert Param variables, leave BatchStat variables unchanged +ref_params = nnx.with_vars(model, ref=True, only=nnx.Param) +``` + ## Using `with_attributes` -If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes ` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged. +If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes ` to create views by directly replacing their attributes. Like `nnx.with_modules`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged. ```{code-cell} class NoisyLinear(nnx.Module): @@ -280,3 +320,29 @@ print(noisy_model)s ``` Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`. + +## Other NNX views + +Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees: + +- {func}`nnx.as_pure ` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. 
This is useful for serialization or export, where Variable metadata is not needed. + + ```python + _, state = nnx.split(model) + pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain + ``` + +- {func}`nnx.as_abstract ` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes. + + ```python + with jax.set_mesh(mesh): + abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0))) + abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars + ``` + +- {func}`nnx.with_rngs ` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys. + + ```python + # Split params stream into 4 keys (one per vmap replica); fork the rest + vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...) + ``` diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index 655e235de..a2837bed3 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -2799,18 +2799,14 @@ def as_pure(tree: A) -> A: inner values. """ - def _pure_fn(x): + def _pure_fn(_, x): if isinstance(x, Variable): - return as_pure(x.get_raw_value()) + return x.get_raw_value() elif variablelib.is_array_ref(x): return x[...] 
return x - return jax.tree.map( - _pure_fn, - tree, - is_leaf=lambda x: isinstance(x, Variable), - ) + return map(_pure_fn, tree, auto_create_variables=False) def call( diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 6e7d0c9db..4f03b948a 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -26,6 +26,7 @@ from flax.nnx.nn import initializers from flax.nnx.variablelib import Variable from flax.nnx import filterlib +from flax.nnx import graphlib from flax.nnx.pytreelib import Pytree from flax.typing import MISSING, Key, Missing import warnings @@ -464,12 +465,10 @@ def split(self, k: tp.Mapping[filterlib.Filter, int | tuple[int, ...]] | int | t >>> assert new_rngs.dropout.key.shape == () >>> assert new_rngs.noise.key.shape == (2, 5) """ - if isinstance(k, int): - k = {...: k} - elif isinstance(k, tuple): + if isinstance(k, (int, tuple)): k = {...: k} - split_predicates = {filterlib.to_predicate(k): v for k, v in k.items()} + split_predicates = {filterlib.to_predicate(pred): v for pred, v in k.items()} keys: dict[str, RngStream] = {} for name, stream in self.items(): for predicate, num_splits in split_predicates.items(): @@ -725,6 +724,88 @@ def __enter__(self): def __exit__(self, *args): restore_rngs(self) +def with_rngs(tree, split=None, fork=None, graph=False): + """Returns a copy of ``tree`` with ``RngStream`` objects replaced according to + ``split`` and ``fork`` rules. + + ``split`` controls which streams are **split** — after splitting, each call + to the stream produces one key from an array of pre-generated keys rather + than a single key. ``fork`` controls which of the remaining streams are + **forked** — each call to a forked stream produces a unique key derived from + the parent counter. Streams that match neither rule are returned unchanged. + + Args: + tree: A pytree that may contain ``RngStream`` objects (e.g. an ``Rngs`` + instance, a module, or any nested structure). + split: Specifies which streams to split and into what shape. 
Can be: + + * An ``int`` or ``tuple[int, ...]`` — split *all* streams into this + shape, equivalent to ``{...: split}``. + * A :class:`~flax.nnx.filterlib.Filter`-keyed mapping where each value + is an ``int`` or ``tuple[int, ...]``. The first matching filter wins. + + fork: A :class:`~flax.nnx.filterlib.Filter` selecting which streams not + already handled by ``split`` should be forked. Pass ``...`` to fork all + remaining streams. + graph: If ``True``, uses graph-mode which supports the full + NNX feature set including shared references. If ``False``, uses + tree-mode which treats Modules as regular JAX pytrees, avoiding + the overhead of the graph protocol. + + Returns: + A new tree of the same structure as ``tree`` with ``RngStream`` objects + replaced by split or forked copies as specified. + + Example — split all streams:: + + >>> from flax import nnx + ... + >>> rngs = nnx.Rngs(params=0, dropout=1) + >>> new_rngs = nnx.with_rngs(rngs, split=4) + >>> new_rngs.params.key.shape + (4,) + >>> new_rngs.dropout.key.shape + (4,) + + Example — split some streams, fork the rest:: + + >>> rngs = nnx.Rngs(params=0, dropout=1) + >>> new_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...) + >>> new_rngs.params.key.shape + (4,) + >>> new_rngs.dropout.key.shape # forked: scalar key, advanced counter + () + + Example — per-filter split shapes:: + + >>> rngs = nnx.Rngs(params=0, dropout=1, noise=2) + >>> new_rngs = nnx.with_rngs(rngs, split={ + ... 'params': 4, # split params into 4 keys + ... ...: (2, 4), # split anything else into 2×4 keys + ... 
}) + >>> new_rngs.params.key.shape + (4,) + >>> new_rngs.noise.key.shape + (2, 4) + + """ + if split is None: + split = {} + elif isinstance(split, (int, tuple)): + split = {...: split} + split_predicates = {filterlib.to_predicate(k): v for k, v in split.items()} + fork_predicate = filterlib.to_predicate(fork) + + def f(path, val): + if isinstance(val, RngStream): + for predicate, num_splits in split_predicates.items(): + if predicate(path, val): + return val.split(num_splits) + if fork_predicate(path, val): + return val.fork() + return val + + return graphlib.recursive_map(f, tree, graph=graph) @tp.overload def split_rngs( diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index fe1429695..b633485fb 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -1252,6 +1252,38 @@ def test_split_graph_error(self): ): graphdef, state = nnx.split((v, v)) + def test_as_pure_replaces_variables_with_values(self): + model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + pure_model = nnx.as_pure(model) + self.assertNotIsInstance(pure_model.kernel, nnx.Variable) + self.assertNotIsInstance(pure_model.bias, nnx.Variable) + self.assertIsInstance(pure_model.kernel, jax.Array) + self.assertIsInstance(pure_model.bias, jax.Array) + _, state = nnx.split(model) + pure_state = nnx.as_pure(state) + self.assertNotIsInstance(pure_state['kernel'], nnx.Variable) + self.assertNotIsInstance(pure_state['bias'], nnx.Variable) + self.assertIsInstance(pure_state['kernel'], jax.Array) + self.assertIsInstance(pure_state['bias'], jax.Array) + + def test_as_pure_preserves_non_variable_leaves(self): + tree = {'a': jnp.array(1.0), 'b': nnx.Param(jnp.array(2.0)), 'c': 42} + + pure = nnx.as_pure(tree) + + np.testing.assert_array_equal(pure['a'], jnp.array(1.0)) + np.testing.assert_array_equal(pure['b'], jnp.array(2.0)) + self.assertEqual(pure['c'], 42) + + def test_as_pure_nested_variables(self): + # Variable wrapping another Variable — inner value should be unwrapped 
+ inner = nnx.Param(jnp.array(3.0)) + outer = nnx.Param(inner) + + pure = nnx.as_pure(outer) + + np.testing.assert_array_equal(pure, jnp.array(3.0)) + class SimpleModule(nnx.Module): pass diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index 9f01582bd..739cd3aa3 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -22,6 +22,7 @@ from flax import nnx from flax import errors +from flax.nnx.rnglib import with_rngs class TestRngs(parameterized.TestCase): @@ -239,5 +240,90 @@ def test_random_helpers(self): ) np.testing.assert_allclose(x_nnx, x_jax) +class TestWithRngs(parameterized.TestCase): + def test_split_int_splits_all_streams(self): + rngs = nnx.Rngs(params=0, dropout=1) + new_rngs = with_rngs(rngs, split=4) + + self.assertEqual(new_rngs.params.key.shape, (4,)) + self.assertEqual(new_rngs['dropout'].key.shape, (4,)) + + def test_split_tuple_splits_all_streams(self): + rngs = nnx.Rngs(params=0, dropout=1) + new_rngs = with_rngs(rngs, split=(2, 3)) + + self.assertEqual(new_rngs.params.key.shape, (2, 3)) + self.assertEqual(new_rngs['dropout'].key.shape, (2, 3)) + + def test_fork_forks_all_streams(self): + rngs = nnx.Rngs(params=0, dropout=1) + original_params_key = rngs.params.key[...] + original_dropout_key = rngs['dropout'].key[...] + + new_rngs = with_rngs(rngs, fork=...) 
+ + # Forked keys are scalar and differ from originals + self.assertEqual(new_rngs.params.key.shape, ()) + self.assertEqual(new_rngs['dropout'].key.shape, ()) + self.assertFalse(jnp.array_equal(new_rngs.params.key[...], original_params_key)) + self.assertFalse(jnp.array_equal(new_rngs['dropout'].key[...], original_dropout_key)) + + def test_split_mapping_applies_per_filter(self): + rngs = nnx.Rngs(params=0, dropout=1, noise=2) + new_rngs = with_rngs(rngs, split={'params': 4, ...: (2, 3)}) + + self.assertEqual(new_rngs.params.key.shape, (4,)) + self.assertEqual(new_rngs['dropout'].key.shape, (2, 3)) + self.assertEqual(new_rngs.noise.key.shape, (2, 3)) + + def test_split_mapping_first_matching_filter_wins(self): + rngs = nnx.Rngs(params=0, dropout=1) + # 'params' filter comes before '...' so it should match first + new_rngs = with_rngs(rngs, split={'params': 4, ...: 8}) + + self.assertEqual(new_rngs.params.key.shape, (4,)) + self.assertEqual(new_rngs['dropout'].key.shape, (8,)) + + def test_split_some_fork_rest(self): + rngs = nnx.Rngs(params=0, dropout=1) + new_rngs = with_rngs(rngs, split={'params': 4}, fork=...) + + self.assertEqual(new_rngs.params.key.shape, (4,)) + # dropout not matched by split → forked (scalar) + self.assertEqual(new_rngs['dropout'].key.shape, ()) + + def test_original_base_key_not_replaced(self): + # with_rngs advances the original stream's counter (consuming one step to + # derive the new keys) but does not replace the original's base key. + rngs = nnx.Rngs(params=0, dropout=1) + original_key_var = rngs.params.key + + with_rngs(rngs, split=4) + + self.assertIs(rngs.params.key, original_key_var) + self.assertEqual(rngs.params.key.shape, ()) + + def test_unmatched_streams_returned_unchanged(self): + rngs = nnx.Rngs(params=0, dropout=1) + # Only fork 'params'; 'dropout' matches neither split nor fork + new_rngs = with_rngs(rngs, fork='params') + + self.assertIsNot(new_rngs['dropout'], rngs['dropout']) # new tree, but... 
+ self.assertTrue(jnp.array_equal(new_rngs['dropout'].key[...], rngs['dropout'].key[...])) + self.assertEqual(new_rngs['dropout'].key.shape, ()) + + def test_works_on_plain_pytree(self): + params_stream = nnx.RngStream(0, tag='params') + dropout_stream = nnx.RngStream(1, tag='dropout') + tree = {'a': params_stream, 'b': dropout_stream} + + new_tree = with_rngs(tree, split=4) + + self.assertEqual(new_tree['a'].key.shape, (4,)) + self.assertEqual(new_tree['b'].key.shape, (4,)) + # Originals unchanged + self.assertEqual(params_stream.key.shape, ()) + + if __name__ == '__main__': absltest.main() From 1e6668e35dd4575940d91a9849b6c6133cdff4b9 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Fri, 27 Mar 2026 15:43:36 -0500 Subject: [PATCH 4/5] Add deprecations for old NNX view functions --- flax/nnx/__init__.py | 8 ++--- flax/nnx/deprecations.py | 69 ++++++++++++++++++++++++++++++++++++++++ flax/nnx/graphlib.py | 3 ++ flax/nnx/module.py | 3 ++ flax/nnx/spmd.py | 3 ++ 5 files changed, 82 insertions(+), 4 deletions(-) create mode 100644 flax/nnx/deprecations.py diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 134931825..5a37b098f 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -50,7 +50,7 @@ from .module import M as M from .module import Module as Module from .module import capture as capture -from .module import with_modules as with_modules +from .module import with_modules as with_modules, view as view from .module import view_info as view_info from .module import with_attributes as with_attributes from .module import iter_children as iter_children, iter_modules as iter_modules @@ -75,8 +75,8 @@ from .graphlib import MergeContext as MergeContext from .graphlib import merge_context as merge_context from .graphlib import variables as variables -from .graphlib import with_vars as with_vars -from .graphlib import as_pure as as_pure +from .graphlib import with_vars as with_vars, vars_as as vars_as +from .graphlib import as_pure as as_pure, pure 
as pure from .graphlib import cached_partial as cached_partial from .graphlib import flatten as flatten from .graphlib import unflatten as unflatten @@ -152,7 +152,7 @@ from .spmd import get_named_sharding as get_named_sharding from .spmd import with_partitioning as with_partitioning from .spmd import get_abstract_model as get_abstract_model -from .spmd import as_abstract as as_abstract +from .spmd import as_abstract as as_abstract, abstract_with_sharding as abstract_with_sharding from .statelib import FlatState as FlatState from .statelib import State as State from .statelib import to_flat_state as to_flat_state diff --git a/flax/nnx/deprecations.py b/flax/nnx/deprecations.py new file mode 100644 index 000000000..e1e9c8236 --- /dev/null +++ b/flax/nnx/deprecations.py @@ -0,0 +1,69 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import warnings +from typing import TypeVar +from collections.abc import Callable + +F = TypeVar('F', bound=Callable) + + +def deprecated(new_fn: F) -> F: + """Creates a deprecated alias for a renamed function. + + .. deprecated:: + This decorator is for marking functions as deprecated. The returned + wrapper emits a :class:`DeprecationWarning` on every call and then + delegates to ``new_fn``. + + The returned callable copies the signature, type annotations, and + docstring of ``new_fn``, with a deprecation notice prepended to the + docstring. 
This keeps IDE autocomplete and type-checking working while + clearly communicating that callers should migrate. + + Args: + new_fn: The current, non-deprecated function to delegate to. + + Returns: + A wrapper that emits a :class:`DeprecationWarning` and then calls + ``new_fn`` with the same arguments. + + Example:: + + >>> from flax.nnx.deprecations import deprecated + >>> def new_api(x): + ... return x * 2 + >>> old_api = deprecated(new_api) + >>> old_api(3) # emits DeprecationWarning: use new_api instead + 6 + + """ + + @functools.wraps(new_fn) + def wrapper(*args, **kwargs): + warnings.warn( + f'This function is deprecated. Use {new_fn.__qualname__} instead.', + DeprecationWarning, + stacklevel=2, + ) + return new_fn(*args, **kwargs) + + dep_notice = ( + f'.. deprecated::\n' + f' Use :func:`{new_fn.__qualname__}` instead.\n\n' + ) + wrapper.__doc__ = dep_notice + (new_fn.__doc__ or '') + + return wrapper # type: ignore[return-value] diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index a2837bed3..b4f0a1646 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -26,6 +26,7 @@ from flax import config from flax.nnx import filterlib, reprlib, traversals, variablelib from flax.nnx import statelib +from flax.nnx.deprecations import deprecated from flax.nnx.proxy_caller import ( ApplyCaller, CallableProxy, @@ -2761,6 +2762,7 @@ def _to_refs(path, x): node = graphlib.map(_to_refs, node, graph=graph) return node +vars_as = deprecated(with_vars) def as_pure(tree: A) -> A: """Returns a new tree with all ``Variable`` objects replaced with inner values. 
@@ -2808,6 +2810,7 @@ def _pure_fn(_, x): return map(_pure_fn, tree, auto_create_variables=False) +pure = deprecated(as_pure) def call( graphdef_state: tuple[GraphDef[A], GraphState], / diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 5dc30133b..4f0627bc1 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -29,6 +29,7 @@ from flax.nnx.pytreelib import Pytree, PytreeMeta from flax.nnx.graphlib import GraphState from flax.nnx.statelib import split_state, State +from flax.nnx.deprecations import deprecated import functools as ft from flax.typing import Key, Path, PathParts from collections.abc import MutableMapping @@ -515,6 +516,8 @@ def _set_mode_fn(path, node): return out +view = deprecated(with_modules) + def with_attributes( node: A, /, diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index 56c07b15d..e3a8d34f1 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -16,6 +16,7 @@ import flax.core.spmd as core_spmd from flax.nnx import variablelib, graphlib +from flax.nnx.deprecations import deprecated from flax.nnx.transforms.transforms import eval_shape from flax.typing import ( Sharding, @@ -238,3 +239,5 @@ def add_sharding(_path, x): return abs_var return x return graphlib.map(add_sharding, tree, graph=graph) + +abstract_with_sharding = deprecated(as_abstract) From c5f457335a607b99543b262e1b60b233b8523bfb Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Fri, 27 Mar 2026 17:55:44 -0500 Subject: [PATCH 5/5] Fix _to_nested_state bug --- flax/nnx/graphlib.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index b4f0a1646..dbf9a62db 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -2573,7 +2573,10 @@ def map( A :class:`State` with the mapped values. 
""" graphdef, state = split(node, graph=graph) - state = statelib.map_state(f, state) + if isinstance(state, statelib.State): + state = statelib.map_state(f, state) + else: + state = f((), state) return merge(graphdef, state, auto_create_variables=auto_create_variables)