diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index 15e0ecc8a..ca1159b68 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -1096,6 +1096,7 @@ def unflatten( # type: ignore[invalid-annotation] index_ref: IndexMap | None = None, outer_index_outer_ref: IndexMap | None = None, copy_variables: bool = False, + recreate_variables: bool = True, ) -> Node: """Unflattens a graphdef into a node with the given state. @@ -1157,6 +1158,7 @@ def unflatten( # type: ignore[invalid-annotation] index_ref, outer_index_outer_ref, copy_variables, + recreate_variables ) try: @@ -1178,6 +1180,7 @@ def _graph_unflatten( index_ref: IndexMap, outer_index_outer_ref: IndexMap | None, copy_variables: bool, + recreate_variables: bool ) -> Node: """Recursive helper for graph_unflatten. @@ -1272,7 +1275,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): variable.set_raw_value(value) else: # variabledef.index not in index_ref_cache # variable reference does not exist outside, create a new one - if isinstance(value, Variable): + if isinstance(value, Variable) or not recreate_variables: variable = value else: variable = variabledef.type.from_metadata( @@ -1321,6 +1324,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]: index_ref, outer_index_outer_ref, copy_variables, + recreate_variables ) else: raise RuntimeError(f'Unknown node definition: {node_def!r}') @@ -2366,6 +2370,7 @@ def merge( # type: ignore[invalid-annotation] /, *states: tp.Any, copy: bool = False, + recreate_variables: bool = True, ) -> A: """The inverse of :func:`flax.nnx.split`. @@ -2417,7 +2422,7 @@ def merge( # type: ignore[invalid-annotation] _state = state else: _state = _merge_to_flat_state((state, *states)) - node = unflatten(graphdef, _state, copy_variables=copy) + node = unflatten(graphdef, _state, copy_variables=copy, recreate_variables=recreate_variables) return node @@ -2541,6 +2546,7 @@ def map( /, *, graph: bool | None = None, + recreate_variables: bool = True, ) -> A: """Map a function over the state of a graph node. @@ -2574,7 +2580,7 @@ def map( """ graphdef, state = split(node, graph=graph) state = statelib.map_state(f, state) - return merge(graphdef, state) + return merge(graphdef, state, recreate_variables=recreate_variables) def graphdef( diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 97bbf4b81..b176f0361 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -1699,6 +1699,14 @@ def test_map_replace(self): np.testing.assert_array_equal(new_model.kernel[...], jnp.zeros((2, 3))) np.testing.assert_array_equal(new_model.bias[...], jnp.zeros((3,))) + def test_map_recreate_variables_false(self): + rngs = nnx.Rngs(0) + new_rngs = nnx.map( + lambda path, x: 0, rngs, recreate_variables=False + ) + self.assertNotIsInstance(new_rngs.default.count, nnx.Variable) + self.assertEqual(new_rngs.default.count, 0) + if __name__ == '__main__': absltest.main()