diff --git a/docs_nnx/guides/checkpointing.ipynb b/docs_nnx/guides/checkpointing.ipynb
index 243fc404a..f368efccf 100644
--- a/docs_nnx/guides/checkpointing.ipynb
+++ b/docs_nnx/guides/checkpointing.ipynb
@@ -35,7 +35,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -44,13 +44,14 @@
"import jax\n",
"from jax import numpy as jnp\n",
"import numpy as np\n",
+ "import flax\n",
"\n",
"ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -82,13 +83,13 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -100,7 +101,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -223,31 +224,67 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- " The abstract NNX state (all leaves are abstract arrays):\n",
- "\n",
- "\n",
- "\n",
- "
\n",
- "\n",
- "\n",
- " NNX State restored: \n",
- "\n",
- "\n",
- " /Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
- " warnings.warn(\n",
- "\n",
- "\n",
- "\n",
- "
\n",
- "\n",
- "\n",
- "\n",
- "
\n",
- "\n",
- "\n",
+ "Note that due to a bug in the Orbax library, this approach will not work if your state contains `nnx.Rngs` objects. Instead, you must restore the checkpoint to a pure dictionary first. You can safely ignore the scary warning that comes with this process."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Save the state of the RNG-carrying model as usual\n",
+ "model_with_rng = nnx.Dropout(0.5, rngs=nnx.Rngs(0))\n",
+ "graphdef_with_rng, state_with_rng = nnx.split(model_with_rng)\n",
+ "checkpointer.save(ckpt_dir / 'rand_state', state_with_rng)\n",
+ "\n",
+ "# But restore it without a target tree, which yields a pure nested dictionary\n",
+ "restored_pure_dict = checkpointer.restore(ckpt_dir / 'rand_state')\n",
+ "model_with_rng = nnx.merge(graphdef_with_rng, restored_pure_dict)\n",
+ "assert model_with_rng(x).shape == (3, 4)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that if you saved a checkpoint with a Flax version older than 0.13 and wish to restore it with a newer version of Flax, you will need to process the restored pure dictionary slightly before merging it. Specifically, due to a change in how the RNG key variable is named, paths ending in `key` need to be changed to end in `base_key`. You can do this using `jax.tree.leaves_with_path` and its inverse, `flax.jax_utils.build_tree_from_paths`, as follows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def transform_path(path):\n",
+ "  # Keep the renamed entry a DictKey so that build_tree_from_paths can read it.\n",
+ "  if isinstance(path[-1], jax.tree_util.DictKey) and path[-1].key == \"key\":\n",
+ "    return (*path[:-1], jax.tree_util.DictKey(\"base_key\"))\n",
+ "  return path\n",
+ "\n",
+ "if tuple(map(int, flax.__version__.split(\".\")[:2])) >= (0, 13):  # numeric, not string, comparison\n",
+ "  restored_pure_dict = flax.jax_utils.build_tree_from_paths([\n",
+ "    (transform_path(path), leaf) for path, leaf in jax.tree.leaves_with_path(restored_pure_dict)\n",
+ "  ])\n",
+ "model_with_rng = nnx.merge(graphdef_with_rng, restored_pure_dict)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
"## Save and restore as pure dictionaries\n",
"\n",
- "When interacting with checkpoint libraries (like Orbax), you may prefer to work with Python built-in container types. In this case, you can use the [`nnx.State.to_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L170) and [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179) API to convert an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) to and from pure nested dictionaries."
+ "To use Python built-in container types instead of flax `State` objects, you can use the [`nnx.State.to_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L170) and [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179) API to convert an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) to and from pure nested dictionaries."
]
},
{
@@ -306,16 +343,6 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "
\n",
- "\n",
- "\n",
- "\n",
- "
\n",
- "\n",
- "\n",
- " WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n",
- "\n",
- "\n",
"## Restore when checkpoint structures differ\n",
"\n",
"The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer match with your current model code. Check out this simple example below.\n",
@@ -440,9 +467,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.9"
+ "version": "3.13.7"
}
},
"nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
}
diff --git a/docs_nnx/guides/checkpointing.md b/docs_nnx/guides/checkpointing.md
index 3cd828bb1..66227d1c6 100644
--- a/docs_nnx/guides/checkpointing.md
+++ b/docs_nnx/guides/checkpointing.md
@@ -39,6 +39,7 @@ import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np
+import flax
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
```
@@ -105,31 +106,39 @@ model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)
```
- The abstract NNX state (all leaves are abstract arrays):
-
-
-
-
-
-
- NNX State restored:
-
-
- /Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
- warnings.warn(
-
-
-
-
-
+Note that due to a bug in the Orbax library, this approach will not work if your state contains `nnx.Rngs` objects. Instead, you must restore the checkpoint to a pure dictionary first. You can safely ignore the scary warning that comes with this process.
+```{code-cell} ipython3
+# Save the state of the RNG-carrying model as usual
+model_with_rng = nnx.Dropout(0.5, rngs=nnx.Rngs(0))
+graphdef_with_rng, state_with_rng = nnx.split(model_with_rng)
+checkpointer.save(ckpt_dir / 'rand_state', state_with_rng)
+
+# But restore it without a target tree, which yields a pure nested dictionary
+restored_pure_dict = checkpointer.restore(ckpt_dir / 'rand_state')
+model_with_rng = nnx.merge(graphdef_with_rng, restored_pure_dict)
+assert model_with_rng(x).shape == (3, 4)
+```
-
+Note that if you saved a checkpoint with a Flax version older than 0.13 and wish to restore it with a newer version of Flax, you will need to process the restored pure dictionary slightly before merging it. Specifically, due to a change in how the RNG key variable is named, paths ending in `key` need to be changed to end in `base_key`. You can do this using `jax.tree.leaves_with_path` and its inverse, `flax.jax_utils.build_tree_from_paths`, as follows:
+```{code-cell} ipython3
+def transform_path(path):
+  # Keep the renamed entry a DictKey so that build_tree_from_paths can read it.
+  if isinstance(path[-1], jax.tree_util.DictKey) and path[-1].key == "key":
+    return (*path[:-1], jax.tree_util.DictKey("base_key"))
+  return path
+
+if tuple(map(int, flax.__version__.split(".")[:2])) >= (0, 13):  # numeric, not string, comparison
+  restored_pure_dict = flax.jax_utils.build_tree_from_paths([
+    (transform_path(path), leaf) for path, leaf in jax.tree.leaves_with_path(restored_pure_dict)
+  ])
+model_with_rng = nnx.merge(graphdef_with_rng, restored_pure_dict)
+```
## Save and restore as pure dictionaries
-When interacting with checkpoint libraries (like Orbax), you may prefer to work with Python built-in container types. In this case, you can use the [`nnx.State.to_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L170) and [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179) API to convert an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) to and from pure nested dictionaries.
+To use Python built-in container types instead of flax `State` objects, you can use the [`nnx.State.to_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L170) and [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179) API to convert an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) to and from pure nested dictionaries.
```{code-cell} ipython3
# Save as pure dict
@@ -146,16 +155,6 @@ model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The model still works!
```
-
-
-
-
-
-
-
- WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
-
-
## Restore when checkpoint structures differ
The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer match with your current model code. Check out this simple example below.
diff --git a/flax/jax_utils.py b/flax/jax_utils.py
index bfe6849f3..6aae031a6 100644
--- a/flax/jax_utils.py
+++ b/flax/jax_utils.py
@@ -25,6 +25,7 @@
from jax import core, lax
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe
+from typing import Any
def _pmap_device_order():
@@ -316,3 +317,41 @@ def unpad(x):
return out if static_return else jax.tree_util.tree_map(unpad, out)
return pad_shard_unpad_wrapper
+
+
+class _DictOrList(dict):
+ """Dict used as a scratch node while rebuilding a tree from key paths.
+
+ ``_to_pytree`` turns a node into a list (values ordered by sorted key) when
+ ``is_list`` is true, and into a plain ``dict`` otherwise.
+ """
+ # Class-level default; ``build_tree_from_paths`` overrides it per instance.
+ is_list: bool = False
+
+def _to_pytree(a):
+  """Recursively converts ``_DictOrList`` scaffolding into plain containers.
+
+  Anything that is not a ``_DictOrList`` is returned untouched. A node flagged
+  ``is_list`` becomes a list of its converted values in sorted-item order; any
+  other node becomes a regular ``dict`` with converted values.
+  """
+  if isinstance(a, _DictOrList):
+    if not a.is_list:
+      return {name: _to_pytree(child) for name, child in a.items()}
+    return [_to_pytree(child) for _, child in sorted(a.items())]
+  return a
+
+def _path_ix(a):
+  """Returns the raw dict key, sequence index or attribute name of a path entry.
+
+  Args:
+    a: one element of a key path: a ``jax.tree_util`` key object (``DictKey``,
+      ``FlattenedIndexKey``, ``SequenceKey``, ``GetAttrKey``) or an
+      already-unwrapped key such as a plain ``str`` or ``int``.
+
+  Returns:
+    ``a.key``, ``a.name`` or ``a.idx`` for the respective key objects; a value
+    without an ``idx`` attribute is returned unchanged.
+  """
+  if isinstance(a, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)):
+    return a.key
+  if isinstance(a, jax.tree_util.GetAttrKey):
+    return a.name
+  # SequenceKey (and anything else exposing ``idx``) keeps the previous
+  # behaviour; bare keys from hand-built paths are passed through instead of
+  # raising AttributeError.
+  return getattr(a, 'idx', a)
+
+def build_tree_from_paths(paths_and_leaves: list[tuple[jax.tree_util.KeyPath, Any]]):
+  """Builds a PyTree from a list of ``(path, leaf)`` pairs.
+
+  Inverse of ``jax.tree.leaves_with_path`` for trees made of nested dicts and
+  sequences.
+
+  Args:
+    paths_and_leaves: ``(path, leaf)`` pairs. Each path is a sequence of key
+      entries; a level addressed through ``jax.tree_util.SequenceKey`` entries
+      is rebuilt as a list, every other level as a dict.
+
+  Returns:
+    The rebuilt tree made of plain ``dict`` and ``list`` containers. Sequence
+    levels are always rebuilt as lists, with items ordered by index. A single
+    pair with an empty path returns its leaf as-is; an empty input returns
+    ``{}``.
+  """
+  paths_and_leaves = list(paths_and_leaves)
+  # A lone empty path means the flattened tree was a bare leaf: hand it back
+  # instead of dropping it and returning an empty dict.
+  if len(paths_and_leaves) == 1 and not paths_and_leaves[0][0]:
+    return paths_and_leaves[0][1]
+
+  root = _DictOrList()
+  for path, leaf in paths_and_leaves:
+    if not path:
+      continue  # an empty path next to other entries has no place in the tree
+    current = root
+
+    # Navigate/create structure following the path
+    for key_entry in path[:-1]:
+      k = _path_ix(key_entry)
+      if k not in current:
+        current[k] = _DictOrList()
+      # The kind of key used to index a node decides whether it becomes a list.
+      current.is_list = isinstance(key_entry, jax.tree_util.SequenceKey)
+      current = current[k]
+
+    # Set the leaf value
+    current.is_list = isinstance(path[-1], jax.tree_util.SequenceKey)
+    current[_path_ix(path[-1])] = leaf
+  return _to_pytree(root)
diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py
index f67fb440c..573df4020 100644
--- a/flax/nnx/rnglib.py
+++ b/flax/nnx/rnglib.py
@@ -112,15 +112,21 @@ def __init__(
count = jnp.zeros(key.shape, dtype=jnp.uint32)
self.tag = tag
- self.key = RngKey(key, tag=tag)
+ self.base_key = RngKey(key, tag=tag)
self.count = RngCount(count, tag=tag)
def __call__(self) -> jax.Array:
self.count._check_can_update()
- key = random.fold_in(self.key[...], self.count[...])
+ key = random.fold_in(self.base_key[...], self.count[...])
self.count[...] += 1
return key
+ def key(self) -> jax.Array:
+ """Returns a new key; same as calling the stream (``self()``)."""
+ return self()
+
+ def split(self, k: int):
+ """Shorthand for ``self.fork(split=k)``."""
+ return self.fork(split=k)
+
def fork(self, *, split: int | tuple[int, ...] | None = None):
key = self()
if split is not None:
@@ -316,7 +322,7 @@ class Rngs(Pytree):
``counter``. Every time a key is requested, the counter is incremented and the key is
generated from the seed key and the counter by using ``jax.random.fold_in``.
- To create an ``Rngs`` pass in an integer or ``jax.random.key`` to the
+ To create an ``Rngs`` pass in an integer or ``jax.random.key`` to the
constructor as a keyword argument with the name of the stream. The key will be used as the
starting seed for the stream, and the counter will be initialized to zero. Then call the
stream to get a key::
@@ -369,7 +375,7 @@ def __init__(
Args:
default: the starting seed for the ``default`` stream, defaults to None.
**rngs: keyword arguments specifying the starting seed for each stream.
- The key can be an integer or a ``jax.random.key``.
+ The key can be an integer or a ``jax.random.key``.
"""
if default is not None:
if isinstance(default, tp.Mapping):
@@ -379,7 +385,7 @@ def __init__(
for tag, key in rngs.items():
if isinstance(key, RngStream):
- key = key.key.get_value()
+ key = key.base_key.get_value()
stream = RngStream(
key=key,
tag=tag,
@@ -406,6 +412,9 @@ def __getattr__(self, name: str):
def __call__(self):
return self.default()
+ def key(self):
+ """Returns a new key from the ``default`` stream (``self.default()``)."""
+ return self.default()
+
def __iter__(self) -> tp.Iterator[str]:
for name, stream in vars(self).items():
if isinstance(stream, RngStream):
@@ -424,6 +433,9 @@ def items(self):
if isinstance(stream, RngStream):
yield name, stream
+ def split(self, splits: int):
+ """Shorthand for ``self.fork(split=splits)``."""
+ return self.fork(split=splits)
+
def fork(
self,
/,
@@ -448,8 +460,8 @@ def fork(
>>> rngs = nnx.Rngs(params=1, dropout=2)
>>> new_rngs = rngs.fork(split=5)
...
- >>> assert new_rngs.params.key.shape == (5,)
- >>> assert new_rngs.dropout.key.shape == (5,)
+ >>> assert new_rngs.params.base_key.shape == (5,)
+ >>> assert new_rngs.dropout.base_key.shape == (5,)
``split`` also accepts a mapping of
`Filters `__ to
@@ -462,9 +474,9 @@ def fork(
... ...: (2, 5), # split anything else into 2x5 keys
... })
...
- >>> assert new_rngs.params.key.shape == (5,)
- >>> assert new_rngs.dropout.key.shape == ()
- >>> assert new_rngs.noise.key.shape == (2, 5)
+ >>> assert new_rngs.params.base_key.shape == (5,)
+ >>> assert new_rngs.dropout.base_key.shape == ()
+ >>> assert new_rngs.noise.base_key.shape == (2, 5)
"""
if split is None:
split = {}
@@ -725,18 +737,18 @@ def split_rngs(
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5)
- >>> rngs.params.key.shape, rngs.dropout.key.shape
+ >>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((5,), (5,))
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=(2, 5))
- >>> rngs.params.key.shape, rngs.dropout.key.shape
+ >>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((2, 5), (2, 5))
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
- >>> rngs.params.key.shape, rngs.dropout.key.shape
+ >>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((5,), ())
Once split, random state can be used with transforms like :func:`nnx.vmap`::
@@ -756,7 +768,7 @@ def split_rngs(
... return Model(rngs)
...
>>> model = create_model(rngs)
- >>> model.dropout.rngs.key.shape
+ >>> model.dropout.rngs.base_key.shape
()
``split_rngs`` returns a SplitBackups object that can be used to restore the
@@ -769,7 +781,7 @@ def split_rngs(
>>> model = create_model(rngs)
>>> nnx.restore_rngs(backups)
...
- >>> model.dropout.rngs.key.shape
+ >>> model.dropout.rngs.base_key.shape
()
SplitBackups can also be used as a context manager to automatically restore
@@ -780,7 +792,7 @@ def split_rngs(
>>> with nnx.split_rngs(rngs, splits=5, only='params'):
... model = create_model(rngs)
...
- >>> model.dropout.rngs.key.shape
+ >>> model.dropout.rngs.base_key.shape
()
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -792,7 +804,7 @@ def split_rngs(
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> model = create_model(rngs)
- >>> model.dropout.rngs.key.shape
+ >>> model.dropout.rngs.base_key.shape
()
@@ -819,18 +831,18 @@ def split_rngs_wrapper(*args, **kwargs):
for path, stream in graph.iter_graph(node):
if (
isinstance(stream, RngStream)
- and predicate((*path, 'key'), stream.key)
+ and predicate((*path, 'key'), stream.base_key)
and predicate((*path, 'count'), stream.count)
):
key = stream()
- backups.append((stream, stream.key.raw_value, stream.count.raw_value))
+ backups.append((stream, stream.base_key.raw_value, stream.count.raw_value))
key = random.split(key, splits)
if squeeze:
key = key[0]
- if variablelib.is_array_ref(stream.key.raw_value):
- stream.key.raw_value = variablelib.new_ref(key) # type: ignore[assignment]
+ if variablelib.is_array_ref(stream.base_key.raw_value):
+ stream.base_key.raw_value = variablelib.new_ref(key) # type: ignore[assignment]
else:
- stream.key.value = key
+ stream.base_key.value = key
if squeeze:
counts_shape = stream.count.shape
elif isinstance(splits, int):
@@ -889,18 +901,18 @@ def fork_rngs(
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split=5)
- >>> rngs.params.key.shape, rngs.dropout.key.shape
+ >>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((5,), (5,))
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split=(2, 5))
- >>> rngs.params.key.shape, rngs.dropout.key.shape
+ >>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((2, 5), (2, 5))
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs, split={'params': 5})
- >>> rngs.params.key.shape, rngs.dropout.key.shape
+ >>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
((5,), ())
Once forked, random state can be used with transforms like :func:`nnx.vmap`::
@@ -920,7 +932,7 @@ def fork_rngs(
... return Model(rngs)
...
>>> model = create_model(rngs)
- >>> model.dropout.rngs.key.shape
+ >>> model.dropout.rngs.base_key.shape
()
``fork_rngs`` returns a SplitBackups object that can be used to restore the
@@ -933,7 +945,7 @@ def fork_rngs(
>>> model = create_model(rngs)
>>> nnx.restore_rngs(backups)
...
- >>> model.dropout.rngs.key.shape
+ >>> model.dropout.rngs.base_key.shape
()
SplitBackups can also be used as a context manager to automatically restore
@@ -944,7 +956,7 @@ def fork_rngs(
>>> with nnx.fork_rngs(rngs, split={'params': 5}):
... model = create_model(rngs)
...
- >>> model.dropout.rngs.key.shape
+ >>> model.dropout.rngs.base_key.shape
()
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -956,7 +968,7 @@ def fork_rngs(
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> model = create_model(rngs)
- >>> model.dropout.rngs.key.shape
+ >>> model.dropout.rngs.base_key.shape
()
"""
if isinstance(node, Missing):
@@ -984,14 +996,14 @@ def fork_rngs_wrapper(*args, **kwargs):
for predicate, splits in predicate_splits.items():
if (
isinstance(stream, RngStream)
- and predicate((*path, 'key'), stream.key)
+ and predicate((*path, 'key'), stream.base_key)
and predicate((*path, 'count'), stream.count)
):
forked_stream = stream.fork(split=splits)
# backup the original stream state
- backups.append((stream, stream.key.raw_value, stream.count.raw_value))
+ backups.append((stream, stream.base_key.raw_value, stream.count.raw_value))
# apply the forked key and count to the original stream
- stream.key.raw_value = forked_stream.key.raw_value
+ stream.base_key.raw_value = forked_stream.base_key.raw_value
stream.count.raw_value = forked_stream.count.raw_value
return SplitBackups(backups)
@@ -1001,7 +1013,7 @@ def backup_keys(node: tp.Any, /):
backups: list[StreamBackup] = []
for _, stream in graph.iter_graph(node):
if isinstance(stream, RngStream):
- backups.append((stream, stream.key.raw_value))
+ backups.append((stream, stream.base_key.raw_value))
return backups
def _scalars_only(
@@ -1046,7 +1058,7 @@ def reseed(
of the form ``(path, scalar_key, target_shape) -> new_key`` can be passed to
define a custom reseeding policy.
**stream_keys: a mapping of stream names to new keys. The keys can be
- either integers or ``jax.random.key``.
+ either integers or ``jax.random.key``.
Example::
@@ -1084,16 +1096,16 @@ def reseed(
rngs = Rngs(**stream_keys)
for path, stream in graph.iter_graph(node):
if isinstance(stream, RngStream):
- if stream.key.tag in stream_keys:
- key = rngs[stream.key.tag]()
- key = policy(path, key, stream.key.shape)
- stream.key.value = key
+ if stream.base_key.tag in stream_keys:
+ key = rngs[stream.base_key.tag]()
+ key = policy(path, key, stream.base_key.shape)
+ stream.base_key.value = key
stream.count.value = jnp.zeros(key.shape, dtype=jnp.uint32)
def restore_rngs(backups: tp.Iterable[StreamBackup], /):
for backup in backups:
stream = backup[0]
- stream.key.raw_value = backup[1]
+ stream.base_key.raw_value = backup[1]
if len(backup) == 3:
stream.count.raw_value = backup[2] # count
diff --git a/pyproject.toml b/pyproject.toml
index ebf5dcd90..a9a94dadf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -237,3 +237,6 @@ quote-style = "single"
[tool.uv]
# Ignore uv.lock and always upgrade the package to the latest
upgrade-package = ["jax", "jaxlib", "orbax-checkpoint"]
+
+# NOTE(review): this resolves jax from a sibling ``../jax`` directory outside
+# this repository -- presumably a local development override; confirm it is
+# meant to be committed.
+[tool.uv.sources]
+jax = { path = "../jax" }
diff --git a/tests/nnx/bridge/module_test.py b/tests/nnx/bridge/module_test.py
index 45a658e3e..7a5a5e9c9 100644
--- a/tests/nnx/bridge/module_test.py
+++ b/tests/nnx/bridge/module_test.py
@@ -149,7 +149,7 @@ def __call__(self):
scope = bar.apply({}, rngs=1)
self.assertIsNone(bar.scope)
- self.assertEqual(scope.rngs.default.key[...], jax.random.key(1))
+ self.assertEqual(scope.rngs.default.base_key[...], jax.random.key(1))
self.assertEqual(scope.rngs.default.count[...], 0)
class Baz(bridge.Module):
diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py
index 5866f5879..b218a6228 100644
--- a/tests/nnx/module_test.py
+++ b/tests/nnx/module_test.py
@@ -557,7 +557,7 @@ def test_create_abstract(self):
def test_create_abstract_stateful(self):
linear = nnx.eval_shape(lambda: nnx.Dropout(0.5, rngs=nnx.Rngs(0)))
- assert linear.rngs.key.value == jax.ShapeDtypeStruct(
+ assert linear.rngs.base_key.value == jax.ShapeDtypeStruct(
(), jax.random.key(0).dtype
)
diff --git a/tests/nnx/mutable_array_test.py b/tests/nnx/mutable_array_test.py
index 823539e38..232a6b300 100644
--- a/tests/nnx/mutable_array_test.py
+++ b/tests/nnx/mutable_array_test.py
@@ -623,7 +623,7 @@ def test_rngs_create(self):
paths[0],
(
jax.tree_util.GetAttrKey('default'),
- jax.tree_util.GetAttrKey('count'),
+ jax.tree_util.GetAttrKey('base_key'),
jax.tree_util.GetAttrKey('value'),
),
)
@@ -631,7 +631,7 @@ def test_rngs_create(self):
paths[1],
(
jax.tree_util.GetAttrKey('default'),
- jax.tree_util.GetAttrKey('key'),
+ jax.tree_util.GetAttrKey('count'),
jax.tree_util.GetAttrKey('value'),
),
)
diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py
index c8a9d55a7..69ffc0f87 100644
--- a/tests/nnx/nn/attention_test.py
+++ b/tests/nnx/nn/attention_test.py
@@ -128,7 +128,7 @@ def test_keep_rngs(self, keep_rngs):
if keep_rngs:
_, _, nondiff = nnx.split(module, nnx.Param, ...)
assert isinstance(nondiff['rngs']['count'], nnx.RngCount)
- assert isinstance(nondiff['rngs']['key'], nnx.RngKey)
+ assert isinstance(nondiff['rngs']['base_key'], nnx.RngKey)
else:
nnx.split(module, nnx.Param)
diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py
index fc8efba20..6b3da070c 100644
--- a/tests/nnx/rngs_test.py
+++ b/tests/nnx/rngs_test.py
@@ -45,12 +45,12 @@ def test_rng_stream(self):
key1 = rngs.params()
self.assertEqual(rngs.params.count[...], 1)
- self.assertIs(rngs.params.key[...], key0)
+ self.assertIs(rngs.params.base_key[...], key0)
self.assertFalse(jnp.allclose(key0, key1))
key2 = rngs.params()
self.assertEqual(rngs.params.count[...], 2)
- self.assertIs(rngs.params.key[...], key0)
+ self.assertIs(rngs.params.base_key[...], key0)
self.assertFalse(jnp.allclose(key1, key2))
def test_rng_trace_level_constraints(self):
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index 3b13dfe09..48769111f 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -2137,14 +2137,14 @@ def create_block(rngs: nnx.Rngs):
return Block(rngs)
rngs = nnx.Rngs(0)
- initial_key = rngs.default.key[...]
+ initial_key = rngs.default.base_key[...]
backups = nnx.split_rngs(rngs, splits=5)
module = create_block(rngs)
nnx.restore_rngs(backups)
assert rngs.default.count[...] == 1
- assert rngs.default.key[...] == initial_key
+ assert rngs.default.base_key[...] == initial_key
assert not jnp.allclose(
module.linear.kernel[0],
module.linear.kernel[1],
@@ -2166,7 +2166,7 @@ def forward_block(module, x):
assert y.shape == (5, 1, 3)
assert rngs.default.count[...] == 2
- assert rngs.default.key[...] == initial_key
+ assert rngs.default.base_key[...] == initial_key
y2 = forward_block(module, x)
@@ -2191,12 +2191,12 @@ def create_block(rngs: nnx.Rngs):
return Block(rngs)
rngs = nnx.Rngs(0)
- initial_key = rngs.default.key[...]
+ initial_key = rngs.default.base_key[...]
module = create_block(rngs.fork(split=5))
assert rngs.default.count[...] == 1
- assert rngs.default.key[...] == initial_key
+ assert rngs.default.base_key[...] == initial_key
assert not jnp.allclose(
module.linear.kernel[0],
module.linear.kernel[1],
@@ -2213,7 +2213,7 @@ def forward_block(module, x):
y = forward_block(module, x)
assert y.shape == (5, 1, 3)
- assert rngs.default.key[...] == initial_key
+ assert rngs.default.base_key[...] == initial_key
y2 = forward_block(module, x)
@@ -2239,12 +2239,12 @@ def create_block(rngs: nnx.Rngs):
return Block(rngs)
rngs = nnx.Rngs(0)
- initial_key = rngs.default.key[...]
+ initial_key = rngs.default.base_key[...]
module = create_block(rngs)
assert rngs.default.count[...] == 1
- assert rngs.default.key[...] == initial_key
+ assert rngs.default.base_key[...] == initial_key
assert not jnp.allclose(
module.linear.kernel[0],
module.linear.kernel[1],
@@ -2262,7 +2262,7 @@ def forward_block(module, x):
y = forward_block(module, x)
assert y.shape == (5, 1, 3)
- assert rngs.default.key[...] == initial_key
+ assert rngs.default.base_key[...] == initial_key
y2 = forward_block(module, x)
@@ -2327,19 +2327,19 @@ def create_block(rngs: nnx.Rngs):
assert module.bn.scale.shape == (3,)
assert module.bn.mean.shape == (5, 3)
assert module.dropout.rngs is not None
- self.assertEqual(module.dropout.rngs.key.shape, (5,))
+ self.assertEqual(module.dropout.rngs.base_key.shape, (5,))
@nnx.vmap(in_axes=(state_axes, 0), out_axes=0)
def forward_block(module: Block, x):
assert module.dropout.rngs is not None
- self.assertEqual(module.dropout.rngs.key.shape, ())
+ self.assertEqual(module.dropout.rngs.base_key.shape, ())
return module(x)
x = jnp.ones((5, 1, 2))
y = forward_block(module, x)
assert module.dropout.rngs is not None
- self.assertEqual(module.dropout.rngs.key.shape, (5,))
+ self.assertEqual(module.dropout.rngs.base_key.shape, (5,))
assert y.shape == (5, 1, 3)
def test_state_axes_super_simple(self):
@@ -2398,7 +2398,7 @@ def forward_block(module: Block, x):
rngs = nnx.Rngs(0)
module = create_block(rngs)
- initial_key = module.dropout.rngs.key[...]
+ initial_key = module.dropout.rngs.base_key[...]
assert module.dropout.rngs.count[...] == 0
assert module.linear.kernel.shape == (din, dout)
@@ -2418,7 +2418,7 @@ def forward_block(module: Block, x):
# dropout is working!
assert not jnp.allclose(y, y2)
- assert module.dropout.rngs.key[...] == initial_key
+ assert module.dropout.rngs.base_key[...] == initial_key
def test_consistent_aliasing_inputs(self):
class Foo(nnx.Module):
@@ -2713,7 +2713,7 @@ def create_block(rngs: nnx.Rngs):
rngs = nnx.Rngs(0)
module = create_block(rngs)
- initial_key = module.dropout.rngs.key[...]
+ initial_key = module.dropout.rngs.base_key[...]
assert module.dropout.rngs.count[0] == 0
assert module.linear.kernel.shape == (1, 3, 10)
@@ -2729,7 +2729,7 @@ def forward_block(module, x):
assert y.shape == (1, 1, 10)
assert module.dropout.rngs.count[0] == 1
- assert module.dropout.rngs.key[...] == initial_key
+ assert module.dropout.rngs.base_key[...] == initial_key
y2 = forward_block(module, x)
@@ -2796,7 +2796,7 @@ def forward_block(module: Block, x):
rngs = nnx.Rngs(0)
module = create_block(rngs)
- initial_key = module.dropout.rngs.key[...]
+ initial_key = module.dropout.rngs.base_key[...]
assert module.dropout.rngs.count[...] == 0
assert module.linear.kernel.shape == (din, dout)
@@ -2814,7 +2814,7 @@ def forward_block(module: Block, x):
# dropout is working!
assert not jnp.allclose(y, y2)
- assert module.dropout.rngs.key[...] == initial_key
+ assert module.dropout.rngs.base_key[...] == initial_key
class TestCond(absltest.TestCase):