@@ -112,15 +112,15 @@ def __init__(
112112
113113 count = jnp .zeros (key .shape , dtype = jnp .uint32 )
114114 self .tag = tag
115- self .key_ = RngKey (key , tag = tag )
115+ self .base_key = RngKey (key , tag = tag )
116116 self .count = RngCount (count , tag = tag )
117117
118118 def __call__ (self ) -> jax .Array :
119119 if not self .count .has_ref and not self .count ._trace_state .is_valid ():
120120 raise errors .TraceContextError (
121121 f'Cannot mutate { type (self ).__name__ } from a different trace level'
122122 )
123- key = random .fold_in (self .key_ [...], self .count [...])
123+ key = random .fold_in (self .base_key [...], self .count [...])
124124 self .count [...] += 1
125125 return key
126126
@@ -325,7 +325,7 @@ class Rngs(Pytree):
325325 ``counter``. Every time a key is requested, the counter is incremented and the key is
326326 generated from the seed key and the counter by using ``jax.random.fold_in``.
327327
328- To create an ``Rngs`` pass in an integer or ``jax.random.key_ `` to the
328+ To create an ``Rngs`` pass in an integer or ``jax.random.base_key `` to the
329329 constructor as a keyword argument with the name of the stream. The key will be used as the
330330 starting seed for the stream, and the counter will be initialized to zero. Then call the
331331 stream to get a key::
@@ -378,7 +378,7 @@ def __init__(
378378 Args:
379379 default: the starting seed for the ``default`` stream, defaults to None.
380380 **rngs: keyword arguments specifying the starting seed for each stream.
381- The key can be an integer or a ``jax.random.key_ ``.
381+ The key can be an integer or a ``jax.random.base_key ``.
382382 """
383383 if default is not None :
384384 if isinstance (default , tp .Mapping ):
@@ -388,7 +388,7 @@ def __init__(
388388
389389 for tag , key in rngs .items ():
390390 if isinstance (key , RngStream ):
391- key = key .key_ [...]
391+ key = key .base_key [...]
392392 stream = RngStream (
393393 key = key ,
394394 tag = tag ,
@@ -415,6 +415,9 @@ def __getattr__(self, name: str):
415415 def __call__ (self ):
416416 return self .default ()
417417
418+ def key (self ):
419+ return self .default ()
420+
418421 def __iter__ (self ) -> tp .Iterator [str ]:
419422 for name , stream in vars (self ).items ():
420423 if isinstance (stream , RngStream ):
@@ -433,6 +436,9 @@ def items(self):
433436 if isinstance (stream , RngStream ):
434437 yield name , stream
435438
439+ def split (self , splits : int ):
440+ return self .fork (split = splits )
441+
436442 def fork (
437443 self ,
438444 / ,
@@ -457,8 +463,8 @@ def fork(
457463 >>> rngs = nnx.Rngs(params=1, dropout=2)
458464 >>> new_rngs = rngs.fork(split=5)
459465 ...
460- >>> assert new_rngs.params.key_ .shape == (5,)
461- >>> assert new_rngs.dropout.key_ .shape == (5,)
466+ >>> assert new_rngs.params.base_key .shape == (5,)
467+ >>> assert new_rngs.dropout.base_key .shape == (5,)
462468
463469 ``split`` also accepts a mapping of
464470 `Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ to
@@ -471,9 +477,9 @@ def fork(
471477 ... ...: (2, 5), # split anything else into 2x5 keys
472478 ... })
473479 ...
474- >>> assert new_rngs.params.key_ .shape == (5,)
475- >>> assert new_rngs.dropout.key_ .shape == ()
476- >>> assert new_rngs.noise.key_ .shape == (2, 5)
480+ >>> assert new_rngs.params.base_key .shape == (5,)
481+ >>> assert new_rngs.dropout.base_key .shape == ()
482+ >>> assert new_rngs.noise.base_key .shape == (2, 5)
477483 """
478484 if split is None :
479485 split = {}
@@ -734,18 +740,18 @@ def split_rngs(
734740 ...
735741 >>> rngs = nnx.Rngs(params=0, dropout=1)
736742 >>> _ = nnx.split_rngs(rngs, splits=5)
737- >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
743+ >>> rngs.params.base_key .shape, rngs.dropout.base_key .shape
738744 ((5,), (5,))
739745
740746 >>> rngs = nnx.Rngs(params=0, dropout=1)
741747 >>> _ = nnx.split_rngs(rngs, splits=(2, 5))
742- >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
748+ >>> rngs.params.base_key .shape, rngs.dropout.base_key .shape
743749 ((2, 5), (2, 5))
744750
745751
746752 >>> rngs = nnx.Rngs(params=0, dropout=1)
747753 >>> _ = nnx.split_rngs(rngs, splits=5, only='params')
748- >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
754+ >>> rngs.params.base_key .shape, rngs.dropout.base_key .shape
749755 ((5,), ())
750756
751757 Once split, random state can be used with transforms like :func:`nnx.vmap`::
@@ -765,7 +771,7 @@ def split_rngs(
765771 ... return Model(rngs)
766772 ...
767773 >>> model = create_model(rngs)
768- >>> model.dropout.rngs.key_ .shape
774+ >>> model.dropout.rngs.base_key .shape
769775 ()
770776
771777 ``split_rngs`` returns a SplitBackups object that can be used to restore the
@@ -778,7 +784,7 @@ def split_rngs(
778784 >>> model = create_model(rngs)
779785 >>> nnx.restore_rngs(backups)
780786 ...
781- >>> model.dropout.rngs.key_ .shape
787+ >>> model.dropout.rngs.base_key .shape
782788 ()
783789
784790 SplitBackups can also be used as a context manager to automatically restore
@@ -789,7 +795,7 @@ def split_rngs(
789795 >>> with nnx.split_rngs(rngs, splits=5, only='params'):
790796 ... model = create_model(rngs)
791797 ...
792- >>> model.dropout.rngs.key_ .shape
798+ >>> model.dropout.rngs.base_key .shape
793799 ()
794800
795801 >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -801,7 +807,7 @@ def split_rngs(
801807 ...
802808 >>> rngs = nnx.Rngs(params=0, dropout=1)
803809 >>> model = create_model(rngs)
804- >>> model.dropout.rngs.key_ .shape
810+ >>> model.dropout.rngs.base_key .shape
805811 ()
806812
807813
@@ -828,18 +834,18 @@ def split_rngs_wrapper(*args, **kwargs):
828834 for path , stream in graph .iter_graph (node ):
829835 if (
830836 isinstance (stream , RngStream )
831- and predicate ((* path , 'key' ), stream .key_ )
837+ and predicate ((* path , 'key' ), stream .base_key )
832838 and predicate ((* path , 'count' ), stream .count )
833839 ):
834840 key = stream ()
835- backups .append ((stream , stream .key_ .raw_value , stream .count .raw_value ))
841+ backups .append ((stream , stream .base_key .raw_value , stream .count .raw_value ))
836842 key = random .split (key , splits )
837843 if squeeze :
838844 key = key [0 ]
839- if variablelib .is_array_ref (stream .key_ .raw_value ):
840- stream .key_ .raw_value = variablelib .new_ref (key ) # type: ignore[assignment]
845+ if variablelib .is_array_ref (stream .base_key .raw_value ):
846+ stream .base_key .raw_value = variablelib .new_ref (key ) # type: ignore[assignment]
841847 else :
842- stream .key_ .value = key
848+ stream .base_key .value = key
843849 if squeeze :
844850 counts_shape = stream .count .shape
845851 elif isinstance (splits , int ):
@@ -898,18 +904,18 @@ def fork_rngs(
898904 ...
899905 >>> rngs = nnx.Rngs(params=0, dropout=1)
900906 >>> _ = nnx.fork_rngs(rngs, split=5)
901- >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
907+ >>> rngs.params.base_key .shape, rngs.dropout.base_key .shape
902908 ((5,), (5,))
903909
904910 >>> rngs = nnx.Rngs(params=0, dropout=1)
905911 >>> _ = nnx.fork_rngs(rngs, split=(2, 5))
906- >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
912+ >>> rngs.params.base_key .shape, rngs.dropout.base_key .shape
907913 ((2, 5), (2, 5))
908914
909915
910916 >>> rngs = nnx.Rngs(params=0, dropout=1)
911917 >>> _ = nnx.fork_rngs(rngs, split={'params': 5})
912- >>> rngs.params.key_ .shape, rngs.dropout.key_ .shape
918+ >>> rngs.params.base_key .shape, rngs.dropout.base_key .shape
913919 ((5,), ())
914920
915921 Once forked, random state can be used with transforms like :func:`nnx.vmap`::
@@ -929,7 +935,7 @@ def fork_rngs(
929935 ... return Model(rngs)
930936 ...
931937 >>> model = create_model(rngs)
932- >>> model.dropout.rngs.key_ .shape
938+ >>> model.dropout.rngs.base_key .shape
933939 ()
934940
935941 ``fork_rngs`` returns a SplitBackups object that can be used to restore the
@@ -942,7 +948,7 @@ def fork_rngs(
942948 >>> model = create_model(rngs)
943949 >>> nnx.restore_rngs(backups)
944950 ...
945- >>> model.dropout.rngs.key_ .shape
951+ >>> model.dropout.rngs.base_key .shape
946952 ()
947953
948954 SplitBackups can also be used as a context manager to automatically restore
@@ -953,7 +959,7 @@ def fork_rngs(
953959 >>> with nnx.fork_rngs(rngs, split={'params': 5}):
954960 ... model = create_model(rngs)
955961 ...
956- >>> model.dropout.rngs.key_ .shape
962+ >>> model.dropout.rngs.base_key .shape
957963 ()
958964
959965 >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
@@ -965,7 +971,7 @@ def fork_rngs(
965971 ...
966972 >>> rngs = nnx.Rngs(params=0, dropout=1)
967973 >>> model = create_model(rngs)
968- >>> model.dropout.rngs.key_ .shape
974+ >>> model.dropout.rngs.base_key .shape
969975 ()
970976 """
971977 if isinstance (node , Missing ):
@@ -993,14 +999,14 @@ def fork_rngs_wrapper(*args, **kwargs):
993999 for predicate , splits in predicate_splits .items ():
9941000 if (
9951001 isinstance (stream , RngStream )
996- and predicate ((* path , 'key' ), stream .key_ )
1002+ and predicate ((* path , 'key' ), stream .base_key )
9971003 and predicate ((* path , 'count' ), stream .count )
9981004 ):
9991005 forked_stream = stream .fork (split = splits )
10001006 # backup the original stream state
1001- backups .append ((stream , stream .key_ .raw_value , stream .count .raw_value ))
1007+ backups .append ((stream , stream .base_key .raw_value , stream .count .raw_value ))
10021008 # apply the forked key and count to the original stream
1003- stream .key_ .raw_value = forked_stream .key_ .raw_value
1009+ stream .base_key .raw_value = forked_stream .base_key .raw_value
10041010 stream .count .raw_value = forked_stream .count .raw_value
10051011
10061012 return SplitBackups (backups )
@@ -1010,7 +1016,7 @@ def backup_keys(node: tp.Any, /):
10101016 backups : list [StreamBackup ] = []
10111017 for _ , stream in graph .iter_graph (node ):
10121018 if isinstance (stream , RngStream ):
1013- backups .append ((stream , stream .key_ .raw_value ))
1019+ backups .append ((stream , stream .base_key .raw_value ))
10141020 return backups
10151021
10161022def _scalars_only (
@@ -1055,7 +1061,7 @@ def reseed(
10551061 of the form ``(path, scalar_key, target_shape) -> new_key`` can be passed to
10561062 define a custom reseeding policy.
10571063 **stream_keys: a mapping of stream names to new keys. The keys can be
1058- either integers or ``jax.random.key_ ``.
1064+ either integers or ``jax.random.base_key ``.
10591065
10601066 Example::
10611067
@@ -1093,16 +1099,16 @@ def reseed(
10931099 rngs = Rngs (** stream_keys )
10941100 for path , stream in graph .iter_graph (node ):
10951101 if isinstance (stream , RngStream ):
1096- if stream .key_ .tag in stream_keys :
1097- key = rngs [stream .key_ .tag ]()
1098- key = policy (path , key , stream .key_ .shape )
1099- stream .key_ .value = key
1102+ if stream .base_key .tag in stream_keys :
1103+ key = rngs [stream .base_key .tag ]()
1104+ key = policy (path , key , stream .base_key .shape )
1105+ stream .base_key .value = key
11001106 stream .count .value = jnp .zeros (key .shape , dtype = jnp .uint32 )
11011107
11021108
11031109def restore_rngs (backups : tp .Iterable [StreamBackup ], / ):
11041110 for backup in backups :
11051111 stream = backup [0 ]
1106- stream .key_ .raw_value = backup [1 ]
1112+ stream .base_key .raw_value = backup [1 ]
11071113 if len (backup ) == 3 :
11081114 stream .count .raw_value = backup [2 ] # count
0 commit comments