Skip to content

Commit 6656042

Browse files
committed
Add split and key to nnx.Rngs
1 parent 32d5374 commit 6656042

File tree

7 files changed

+69
-63
lines changed

7 files changed

+69
-63
lines changed

flax/nnx/rnglib.py

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,15 @@ def __init__(
112112

113113
count = jnp.zeros(key.shape, dtype=jnp.uint32)
114114
self.tag = tag
115-
self.key_ = RngKey(key, tag=tag)
115+
self.base_key = RngKey(key, tag=tag)
116116
self.count = RngCount(count, tag=tag)
117117

118118
def __call__(self) -> jax.Array:
119119
if not self.count.has_ref and not self.count._trace_state.is_valid():
120120
raise errors.TraceContextError(
121121
f'Cannot mutate {type(self).__name__} from a different trace level'
122122
)
123-
key = random.fold_in(self.key_[...], self.count[...])
123+
key = random.fold_in(self.base_key[...], self.count[...])
124124
self.count[...] += 1
125125
return key
126126

@@ -325,7 +325,7 @@ class Rngs(Pytree):
325325
``counter``. Every time a key is requested, the counter is incremented and the key is
326326
generated from the seed key and the counter by using ``jax.random.fold_in``.
327327
328-
To create an ``Rngs`` pass in an integer or ``jax.random.key_`` to the
328+
To create an ``Rngs`` pass in an integer or ``jax.random.base_key`` to the
329329
constructor as a keyword argument with the name of the stream. The key will be used as the
330330
starting seed for the stream, and the counter will be initialized to zero. Then call the
331331
stream to get a key::
@@ -378,7 +378,7 @@ def __init__(
378378
Args:
379379
default: the starting seed for the ``default`` stream, defaults to None.
380380
**rngs: keyword arguments specifying the starting seed for each stream.
381-
The key can be an integer or a ``jax.random.key_``.
381+
The key can be an integer or a ``jax.random.base_key``.
382382
"""
383383
if default is not None:
384384
if isinstance(default, tp.Mapping):
@@ -388,7 +388,7 @@ def __init__(
388388

389389
for tag, key in rngs.items():
390390
if isinstance(key, RngStream):
391-
key = key.key_[...]
391+
key = key.base_key[...]
392392
stream = RngStream(
393393
key=key,
394394
tag=tag,
@@ -415,6 +415,9 @@ def __getattr__(self, name: str):
415415
def __call__(self):
416416
return self.default()
417417

418+
def key(self):
419+
return self.default()
420+
418421
def __iter__(self) -> tp.Iterator[str]:
419422
for name, stream in vars(self).items():
420423
if isinstance(stream, RngStream):
@@ -433,6 +436,9 @@ def items(self):
433436
if isinstance(stream, RngStream):
434437
yield name, stream
435438

439+
def split(self, splits: int):
440+
return self.fork(split=splits)
441+
436442
def fork(
437443
self,
438444
/,
@@ -457,8 +463,8 @@ def fork(
457463
>>> rngs = nnx.Rngs(params=1, dropout=2)
458464
>>> new_rngs = rngs.fork(split=5)
459465
...
460-
>>> assert new_rngs.params.key_.shape == (5,)
461-
>>> assert new_rngs.dropout.key_.shape == (5,)
466+
>>> assert new_rngs.params.base_key.shape == (5,)
467+
>>> assert new_rngs.dropout.base_key.shape == (5,)
462468
463469
``split`` also accepts a mapping of
464470
`Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ to
@@ -471,9 +477,9 @@ def fork(
471477
... ...: (2, 5), # split anything else into 2x5 keys
472478
... })
473479
...
474-
>>> assert new_rngs.params.key_.shape == (5,)
475-
>>> assert new_rngs.dropout.key_.shape == ()
476-
>>> assert new_rngs.noise.key_.shape == (2, 5)
480+
>>> assert new_rngs.params.base_key.shape == (5,)
481+
>>> assert new_rngs.dropout.base_key.shape == ()
482+
>>> assert new_rngs.noise.base_key.shape == (2, 5)
477483
"""
478484
if split is None:
479485
split = {}
@@ -734,18 +740,18 @@ def split_rngs(
734740
...
735741
>>> rngs = nnx.Rngs(params=0, dropout=1)
736742
>>> _ = nnx.split_rngs(rngs, splits=5)
737-
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
743+
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
738744
((5,), (5,))
739745
740746
>>> rngs = nnx.Rngs(params=0, dropout=1)
741747
>>> _ = nnx.split_rngs(rngs, splits=(2, 5))
742-
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
748+
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
743749
((2, 5), (2, 5))
744750
745751
746752
>>> rngs = nnx.Rngs(params=0, dropout=1)
747753
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
748-
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
754+
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
749755
((5,), ())
750756
751757
Once split, random state can be used with transforms like :func:`nnx.vmap`::
@@ -765,7 +771,7 @@ def split_rngs(
765771
... return Model(rngs)
766772
...
767773
>>> model = create_model(rngs)
768-
>>> model.dropout.rngs.key_.shape
774+
>>> model.dropout.rngs.base_key.shape
769775
()
770776
771777
``split_rngs`` returns a SplitBackups object that can be used to restore the
@@ -778,7 +784,7 @@ def split_rngs(
778784
>>> model = create_model(rngs)
779785
>>> nnx.restore_rngs(backups)
780786
...
781-
>>> model.dropout.rngs.key_.shape
787+
>>> model.dropout.rngs.base_key.shape
782788
()
783789
784790
SplitBackups can also be used as a context manager to automatically restore
@@ -789,7 +795,7 @@ def split_rngs(
789795
>>> with nnx.split_rngs(rngs, splits=5, only='params'):
790796
... model = create_model(rngs)
791797
...
792-
>>> model.dropout.rngs.key_.shape
798+
>>> model.dropout.rngs.base_key.shape
793799
()
794800
795801
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -801,7 +807,7 @@ def split_rngs(
801807
...
802808
>>> rngs = nnx.Rngs(params=0, dropout=1)
803809
>>> model = create_model(rngs)
804-
>>> model.dropout.rngs.key_.shape
810+
>>> model.dropout.rngs.base_key.shape
805811
()
806812
807813
@@ -828,18 +834,18 @@ def split_rngs_wrapper(*args, **kwargs):
828834
for path, stream in graph.iter_graph(node):
829835
if (
830836
isinstance(stream, RngStream)
831-
and predicate((*path, 'key'), stream.key_)
837+
and predicate((*path, 'key'), stream.base_key)
832838
and predicate((*path, 'count'), stream.count)
833839
):
834840
key = stream()
835-
backups.append((stream, stream.key_.raw_value, stream.count.raw_value))
841+
backups.append((stream, stream.base_key.raw_value, stream.count.raw_value))
836842
key = random.split(key, splits)
837843
if squeeze:
838844
key = key[0]
839-
if variablelib.is_array_ref(stream.key_.raw_value):
840-
stream.key_.raw_value = variablelib.new_ref(key) # type: ignore[assignment]
845+
if variablelib.is_array_ref(stream.base_key.raw_value):
846+
stream.base_key.raw_value = variablelib.new_ref(key) # type: ignore[assignment]
841847
else:
842-
stream.key_.value = key
848+
stream.base_key.value = key
843849
if squeeze:
844850
counts_shape = stream.count.shape
845851
elif isinstance(splits, int):
@@ -898,18 +904,18 @@ def fork_rngs(
898904
...
899905
>>> rngs = nnx.Rngs(params=0, dropout=1)
900906
>>> _ = nnx.fork_rngs(rngs, split=5)
901-
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
907+
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
902908
((5,), (5,))
903909
904910
>>> rngs = nnx.Rngs(params=0, dropout=1)
905911
>>> _ = nnx.fork_rngs(rngs, split=(2, 5))
906-
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
912+
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
907913
((2, 5), (2, 5))
908914
909915
910916
>>> rngs = nnx.Rngs(params=0, dropout=1)
911917
>>> _ = nnx.fork_rngs(rngs, split={'params': 5})
912-
>>> rngs.params.key_.shape, rngs.dropout.key_.shape
918+
>>> rngs.params.base_key.shape, rngs.dropout.base_key.shape
913919
((5,), ())
914920
915921
Once forked, random state can be used with transforms like :func:`nnx.vmap`::
@@ -929,7 +935,7 @@ def fork_rngs(
929935
... return Model(rngs)
930936
...
931937
>>> model = create_model(rngs)
932-
>>> model.dropout.rngs.key_.shape
938+
>>> model.dropout.rngs.base_key.shape
933939
()
934940
935941
``fork_rngs`` returns a SplitBackups object that can be used to restore the
@@ -942,7 +948,7 @@ def fork_rngs(
942948
>>> model = create_model(rngs)
943949
>>> nnx.restore_rngs(backups)
944950
...
945-
>>> model.dropout.rngs.key_.shape
951+
>>> model.dropout.rngs.base_key.shape
946952
()
947953
948954
SplitBackups can also be used as a context manager to automatically restore
@@ -953,7 +959,7 @@ def fork_rngs(
953959
>>> with nnx.fork_rngs(rngs, split={'params': 5}):
954960
... model = create_model(rngs)
955961
...
956-
>>> model.dropout.rngs.key_.shape
962+
>>> model.dropout.rngs.base_key.shape
957963
()
958964
959965
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -965,7 +971,7 @@ def fork_rngs(
965971
...
966972
>>> rngs = nnx.Rngs(params=0, dropout=1)
967973
>>> model = create_model(rngs)
968-
>>> model.dropout.rngs.key_.shape
974+
>>> model.dropout.rngs.base_key.shape
969975
()
970976
"""
971977
if isinstance(node, Missing):
@@ -993,14 +999,14 @@ def fork_rngs_wrapper(*args, **kwargs):
993999
for predicate, splits in predicate_splits.items():
9941000
if (
9951001
isinstance(stream, RngStream)
996-
and predicate((*path, 'key'), stream.key_)
1002+
and predicate((*path, 'key'), stream.base_key)
9971003
and predicate((*path, 'count'), stream.count)
9981004
):
9991005
forked_stream = stream.fork(split=splits)
10001006
# backup the original stream state
1001-
backups.append((stream, stream.key_.raw_value, stream.count.raw_value))
1007+
backups.append((stream, stream.base_key.raw_value, stream.count.raw_value))
10021008
# apply the forked key and count to the original stream
1003-
stream.key_.raw_value = forked_stream.key_.raw_value
1009+
stream.base_key.raw_value = forked_stream.base_key.raw_value
10041010
stream.count.raw_value = forked_stream.count.raw_value
10051011

10061012
return SplitBackups(backups)
@@ -1010,7 +1016,7 @@ def backup_keys(node: tp.Any, /):
10101016
backups: list[StreamBackup] = []
10111017
for _, stream in graph.iter_graph(node):
10121018
if isinstance(stream, RngStream):
1013-
backups.append((stream, stream.key_.raw_value))
1019+
backups.append((stream, stream.base_key.raw_value))
10141020
return backups
10151021

10161022
def _scalars_only(
@@ -1055,7 +1061,7 @@ def reseed(
10551061
of the form ``(path, scalar_key, target_shape) -> new_key`` can be passed to
10561062
define a custom reseeding policy.
10571063
**stream_keys: a mapping of stream names to new keys. The keys can be
1058-
either integers or ``jax.random.key_``.
1064+
either integers or ``jax.random.base_key``.
10591065
10601066
Example::
10611067
@@ -1093,16 +1099,16 @@ def reseed(
10931099
rngs = Rngs(**stream_keys)
10941100
for path, stream in graph.iter_graph(node):
10951101
if isinstance(stream, RngStream):
1096-
if stream.key_.tag in stream_keys:
1097-
key = rngs[stream.key_.tag]()
1098-
key = policy(path, key, stream.key_.shape)
1099-
stream.key_.value = key
1102+
if stream.base_key.tag in stream_keys:
1103+
key = rngs[stream.base_key.tag]()
1104+
key = policy(path, key, stream.base_key.shape)
1105+
stream.base_key.value = key
11001106
stream.count.value = jnp.zeros(key.shape, dtype=jnp.uint32)
11011107

11021108

11031109
def restore_rngs(backups: tp.Iterable[StreamBackup], /):
11041110
for backup in backups:
11051111
stream = backup[0]
1106-
stream.key_.raw_value = backup[1]
1112+
stream.base_key.raw_value = backup[1]
11071113
if len(backup) == 3:
11081114
stream.count.raw_value = backup[2] # count

tests/nnx/bridge/module_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __call__(self):
149149
scope = bar.apply({}, rngs=1)
150150
self.assertIsNone(bar.scope)
151151

152-
self.assertEqual(scope.rngs.default.key_[...], jax.random.key(1))
152+
self.assertEqual(scope.rngs.default.base_key[...], jax.random.key(1))
153153
self.assertEqual(scope.rngs.default.count[...], 0)
154154

155155
class Baz(bridge.Module):

tests/nnx/module_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def test_create_abstract(self):
556556
def test_create_abstract_stateful(self):
557557
linear = nnx.eval_shape(lambda: nnx.Dropout(0.5, rngs=nnx.Rngs(0)))
558558

559-
assert linear.rngs.key_.value == jax.ShapeDtypeStruct(
559+
assert linear.rngs.base_key.value == jax.ShapeDtypeStruct(
560560
(), jax.random.key(0).dtype
561561
)
562562

tests/nnx/mutable_array_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def test_rngs_create(self):
662662
paths[1],
663663
(
664664
jax.tree_util.GetAttrKey('default'),
665-
jax.tree_util.GetAttrKey('key_'),
665+
jax.tree_util.GetAttrKey('base_key'),
666666
jax.tree_util.GetAttrKey('value'),
667667
),
668668
)

tests/nnx/nn/attention_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test_keep_rngs(self, keep_rngs):
128128
if keep_rngs:
129129
_, _, nondiff = nnx.split(module, nnx.Param, ...)
130130
assert isinstance(nondiff['rngs']['count'], nnx.RngCount)
131-
assert isinstance(nondiff['rngs']['key_'], nnx.RngKey)
131+
assert isinstance(nondiff['rngs']['base_key'], nnx.RngKey)
132132
else:
133133
nnx.split(module, nnx.Param)
134134

tests/nnx/rngs_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ def test_rng_stream(self):
4545

4646
key1 = rngs.params()
4747
self.assertEqual(rngs.params.count[...], 1)
48-
self.assertIs(rngs.params.key_[...], key0)
48+
self.assertIs(rngs.params.base_key[...], key0)
4949
self.assertFalse(jnp.allclose(key0, key1))
5050

5151
key2 = rngs.params()
5252
self.assertEqual(rngs.params.count[...], 2)
53-
self.assertIs(rngs.params.key_[...], key0)
53+
self.assertIs(rngs.params.base_key[...], key0)
5454
self.assertFalse(jnp.allclose(key1, key2))
5555

5656
def test_rng_trace_level_constraints(self):

0 commit comments

Comments
 (0)