Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,7 @@ def unflatten( # type: ignore[invalid-annotation]
index_ref: IndexMap | None = None,
outer_index_outer_ref: IndexMap | None = None,
copy_variables: bool = False,
recreate_variables: bool = True,
) -> Node:
"""Unflattens a graphdef into a node with the given state.

Expand Down Expand Up @@ -1157,6 +1158,7 @@ def unflatten( # type: ignore[invalid-annotation]
index_ref,
outer_index_outer_ref,
copy_variables,
recreate_variables
)

try:
Expand All @@ -1178,6 +1180,7 @@ def _graph_unflatten(
index_ref: IndexMap,
outer_index_outer_ref: IndexMap | None,
copy_variables: bool,
recreate_variables: bool
) -> Node:
"""Recursive helper for graph_unflatten.

Expand Down Expand Up @@ -1272,7 +1275,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
variable.set_raw_value(value)
else: # variabledef.index not in index_ref_cache
# variable reference does not exist outside, create a new one
if isinstance(value, Variable):
if isinstance(value, Variable) or not recreate_variables:
variable = value
else:
variable = variabledef.type.from_metadata(
Expand Down Expand Up @@ -1321,6 +1324,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
index_ref,
outer_index_outer_ref,
copy_variables,
recreate_variables
)
else:
raise RuntimeError(f'Unknown node definition: {node_def!r}')
Expand Down Expand Up @@ -2366,6 +2370,7 @@ def merge( # type: ignore[invalid-annotation]
/,
*states: tp.Any,
copy: bool = False,
recreate_variables: bool = True,
) -> A:
"""The inverse of :func:`flax.nnx.split`.

Expand Down Expand Up @@ -2417,7 +2422,7 @@ def merge( # type: ignore[invalid-annotation]
_state = state
else:
_state = _merge_to_flat_state((state, *states))
node = unflatten(graphdef, _state, copy_variables=copy)
node = unflatten(graphdef, _state, copy_variables=copy, recreate_variables=recreate_variables)
return node


Expand Down Expand Up @@ -2541,6 +2546,7 @@ def map(
/,
*,
graph: bool | None = None,
recreate_variables: bool = True,
) -> A:
"""Map a function over the state of a graph node.

Expand Down Expand Up @@ -2574,7 +2580,7 @@ def map(
"""
graphdef, state = split(node, graph=graph)
state = statelib.map_state(f, state)
return merge(graphdef, state)
return merge(graphdef, state, recreate_variables=recreate_variables)


def graphdef(
Expand Down
8 changes: 8 additions & 0 deletions tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,14 @@ def test_map_replace(self):
np.testing.assert_array_equal(new_model.kernel[...], jnp.zeros((2, 3)))
np.testing.assert_array_equal(new_model.bias[...], jnp.zeros((3,)))

def test_map_recreate_variables_false(self):
rngs = nnx.Rngs(0)
new_rngs = nnx.map(
lambda path, x: 0, rngs, recreate_variables=False
)
self.assertNotIsInstance(new_rngs.default.count, nnx.Variable)
self.assertEqual(new_rngs.default.count, 0)


if __name__ == '__main__':
absltest.main()
Loading