Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions docs_nnx/hijax/hijax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"@jax.jit\n",
"def train_step(x, y):\n",
" loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)\n",
" loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad\n",
" loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad\n",
" optimizer.update(model, grads)\n",
" return loss\n",
"\n",
Expand Down Expand Up @@ -297,8 +297,8 @@
"\n",
"model = Linear(1, 3, rngs=nnx.Rngs(0))\n",
"\n",
"print(f\"{nnx.vars_as(model, mutable=False) = !s}\")\n",
"print(f\"{nnx.vars_as(model, mutable=True) = !s}\")"
"print(f\"{nnx.with_vars(model, mutable=False) = !s}\")\n",
"print(f\"{nnx.with_vars(model, mutable=True) = !s}\")"
]
},
{
Expand All @@ -317,7 +317,7 @@
],
"source": [
"v = nnx.Variable(jnp.array(0))\n",
"v_immut = nnx.vars_as(v, mutable=False)\n",
"v_immut = nnx.with_vars(v, mutable=False)\n",
"assert not v_immut.mutable\n",
"\n",
"try:\n",
Expand Down Expand Up @@ -355,7 +355,7 @@
],
"source": [
"v = nnx.Variable(jnp.array(0))\n",
"v_ref = nnx.vars_as(v, ref=True)\n",
"v_ref = nnx.with_vars(v, ref=True)\n",
"assert v_ref.ref\n",
"print(v_ref)\n",
"print(v_ref.get_raw_value())"
Expand Down Expand Up @@ -386,11 +386,11 @@
}
],
"source": [
"v_immut = nnx.vars_as(v_ref, mutable=False)\n",
"v_immut = nnx.with_vars(v_ref, mutable=False)\n",
"assert not v_immut.ref\n",
"print(\"immutable =\", v_immut)\n",
"\n",
"v_ref = nnx.vars_as(v_immut, mutable=True)\n",
"v_ref = nnx.with_vars(v_immut, mutable=True)\n",
"assert v_ref.ref\n",
"print(\"mutable =\", v_ref)"
]
Expand Down Expand Up @@ -458,7 +458,7 @@
" model = nnx.merge(graphdef, params, nondiff)\n",
" return ((model(x) - y) ** 2).mean()\n",
"\n",
" loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad\n",
" loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad\n",
" optimizer.update(model, grads)\n",
"\n",
" return loss\n",
Expand Down Expand Up @@ -563,9 +563,9 @@
"source": [
"@jax.jit\n",
"def create_model(rngs):\n",
" return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)\n",
" return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False)\n",
"\n",
"model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)\n",
"model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True)\n",
"\n",
"print(\"model.linear =\", model.linear)"
]
Expand Down
20 changes: 10 additions & 10 deletions docs_nnx/hijax/hijax.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param)
@jax.jit
def train_step(x, y):
loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad
loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad
optimizer.update(model, grads)
return loss

Expand Down Expand Up @@ -112,13 +112,13 @@ class Linear(nnx.Module):

model = Linear(1, 3, rngs=nnx.Rngs(0))

print(f"{nnx.vars_as(model, mutable=False) = !s}")
print(f"{nnx.vars_as(model, mutable=True) = !s}")
print(f"{nnx.with_vars(model, mutable=False) = !s}")
print(f"{nnx.with_vars(model, mutable=True) = !s}")
```

```{code-cell} ipython3
v = nnx.Variable(jnp.array(0))
v_immut = nnx.vars_as(v, mutable=False)
v_immut = nnx.with_vars(v, mutable=False)
assert not v_immut.mutable

try:
Expand All @@ -131,18 +131,18 @@ except Exception as e:

```{code-cell} ipython3
v = nnx.Variable(jnp.array(0))
v_ref = nnx.vars_as(v, ref=True)
v_ref = nnx.with_vars(v, ref=True)
assert v_ref.ref
print(v_ref)
print(v_ref.get_raw_value())
```

```{code-cell} ipython3
v_immut = nnx.vars_as(v_ref, mutable=False)
v_immut = nnx.with_vars(v_ref, mutable=False)
assert not v_immut.ref
print("immutable =", v_immut)

v_ref = nnx.vars_as(v_immut, mutable=True)
v_ref = nnx.with_vars(v_immut, mutable=True)
assert v_ref.ref
print("mutable =", v_ref)
```
Expand Down Expand Up @@ -176,7 +176,7 @@ def train_step(model, optimizer, x, y):
model = nnx.merge(graphdef, params, nondiff)
return ((model(x) - y) ** 2).mean()

loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad
loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad
optimizer.update(model, grads)

return loss
Expand Down Expand Up @@ -226,9 +226,9 @@ except Exception as e:
```{code-cell} ipython3
@jax.jit
def create_model(rngs):
return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)
return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False)

model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)
model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True)

print("model.linear =", model.linear)
```
Expand Down
2 changes: 1 addition & 1 deletion examples/nnx_toy_examples/hijax_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def loss_fn(params):
model = nnx.merge(graphdef, params, nondiff)
return jnp.mean((y - model(x)) ** 2)

grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False))
grads = jax.grad(loss_fn)(nnx.with_vars(params, mutable=False))
optimizer.update(model, grads)

@jax.jit
Expand Down
2 changes: 1 addition & 1 deletion examples/nnx_toy_examples/hijax_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def loss_fn(params):

# For the time being we have to use 'immutable'
# as 'jax.grad' doesn't support QDD types yet.
grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False))
grads = jax.grad(loss_fn)(nnx.with_vars(params, mutable=False))
# 'update' mutates the optimizer's state and the params in place
# so we don't need to return anything 🚀
optimizer.update(params, grads)
Expand Down
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .graphlib import MergeContext as MergeContext
from .graphlib import merge_context as merge_context
from .graphlib import variables as variables
from .graphlib import with_vars as with_vars
from .graphlib import vars_as as vars_as
from .graphlib import pure as pure
from .graphlib import cached_partial as cached_partial
Expand Down
29 changes: 19 additions & 10 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from flax import config
from flax.nnx import filterlib, reprlib, traversals, variablelib
from flax.nnx import statelib
from flax.nnx.deprecations import deprecated
from flax.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
Expand Down Expand Up @@ -1096,6 +1097,7 @@ def unflatten( # type: ignore[invalid-annotation]
index_ref: IndexMap | None = None,
outer_index_outer_ref: IndexMap | None = None,
copy_variables: bool = False,
auto_create_variables: bool = True,
) -> Node:
"""Unflattens a graphdef into a node with the given state.

Expand Down Expand Up @@ -1157,6 +1159,7 @@ def unflatten( # type: ignore[invalid-annotation]
index_ref,
outer_index_outer_ref,
copy_variables,
auto_create_variables
)

try:
Expand All @@ -1178,6 +1181,7 @@ def _graph_unflatten(
index_ref: IndexMap,
outer_index_outer_ref: IndexMap | None,
copy_variables: bool,
auto_create_variables: bool
) -> Node:
"""Recursive helper for graph_unflatten.

Expand Down Expand Up @@ -1272,7 +1276,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
variable.set_raw_value(value)
else: # variabledef.index not in index_ref_cache
# variable reference does not exist outside, create a new one
if isinstance(value, Variable):
if isinstance(value, Variable) or not auto_create_variables:
variable = value
else:
variable = variabledef.type.from_metadata(
Expand Down Expand Up @@ -1321,6 +1325,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
index_ref,
outer_index_outer_ref,
copy_variables,
auto_create_variables
)
else:
raise RuntimeError(f'Unknown node definition: {node_def!r}')
Expand Down Expand Up @@ -2366,6 +2371,7 @@ def merge( # type: ignore[invalid-annotation]
/,
*states: tp.Any,
copy: bool = False,
auto_create_variables: bool = True,
) -> A:
"""The inverse of :func:`flax.nnx.split`.

Expand Down Expand Up @@ -2417,7 +2423,7 @@ def merge( # type: ignore[invalid-annotation]
_state = state
else:
_state = _merge_to_flat_state((state, *states))
node = unflatten(graphdef, _state, copy_variables=copy)
node = unflatten(graphdef, _state, copy_variables=copy, auto_create_variables=auto_create_variables)
return node


Expand Down Expand Up @@ -2541,6 +2547,7 @@ def map(
/,
*,
graph: bool | None = None,
auto_create_variables: bool = True,
) -> A:
"""Map a function over the state of a graph node.

Expand Down Expand Up @@ -2573,8 +2580,11 @@ def map(
A :class:`State` with the mapped values.
"""
graphdef, state = split(node, graph=graph)
state = statelib.map_state(f, state)
return merge(graphdef, state)
if isinstance(state, statelib.State):
state = statelib.map_state(f, state)
else:
state = f((), state)
return merge(graphdef, state, auto_create_variables=auto_create_variables)


def graphdef(
Expand Down Expand Up @@ -2713,7 +2723,7 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N
return merge(graphdef, state, copy=variables)


def vars_as(
def with_vars(
node: A,
/,
*,
Expand Down Expand Up @@ -2751,18 +2761,17 @@ def _different_vars(path, x):
duplicates_strs += '\n ---'
raise ValueError(f'Found duplicate at paths:{duplicates_strs}')

def _to_refs(jax_path, x):
if predicate(jax_to_nnx_path(jax_path), x):
def _to_refs(path, x):
if predicate(path, x):
assert isinstance(x, Variable)
variable = x.copy(**new_attrs)
return variable
return x

node = jax.tree.map_with_path(
_to_refs, node, is_leaf=lambda x: isinstance(x, Variable)
)
node = map(_to_refs, node, auto_create_variables=False)
return node

vars_as = deprecated(with_vars)

def pure(tree: A) -> A:
"""Returns a new tree with all ``Variable`` objects replaced with inner values.
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/pytreelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,4 +1061,4 @@ def _maybe_int(x):
return x

def _get_str(x):
return x if isinstance(x, str) else str(x)
return x if isinstance(x, str) else str(x)
8 changes: 8 additions & 0 deletions tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,14 @@ def test_map_replace(self):
np.testing.assert_array_equal(new_model.kernel[...], jnp.zeros((2, 3)))
np.testing.assert_array_equal(new_model.bias[...], jnp.zeros((3,)))

def test_map_auto_create_variables_false(self):
rngs = nnx.Rngs(0)
new_rngs = nnx.map(
lambda path, x: 0, rngs, auto_create_variables=False
)
self.assertNotIsInstance(new_rngs.default.count, nnx.Variable)
self.assertEqual(new_rngs.default.count, 0)


if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion tests/nnx/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def loss_fn(params):
return ((model(x) - y) ** 2).mean() # call methods directly

loss, grads = jax.value_and_grad(loss_fn)(
nnx.vars_as(params, hijax=False)
nnx.with_vars(params, hijax=False)
)
optimizer.update(model, grads) # in-place updates

Expand Down
Loading
Loading