diff --git a/docs_nnx/hijax/hijax.ipynb b/docs_nnx/hijax/hijax.ipynb index 4b1c8d40f..bcb7c83ce 100644 --- a/docs_nnx/hijax/hijax.ipynb +++ b/docs_nnx/hijax/hijax.ipynb @@ -49,7 +49,7 @@ "@jax.jit\n", "def train_step(x, y):\n", " loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)\n", - " loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad\n", + " loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad\n", " optimizer.update(model, grads)\n", " return loss\n", "\n", @@ -297,8 +297,8 @@ "\n", "model = Linear(1, 3, rngs=nnx.Rngs(0))\n", "\n", - "print(f\"{nnx.vars_as(model, mutable=False) = !s}\")\n", - "print(f\"{nnx.vars_as(model, mutable=True) = !s}\")" + "print(f\"{nnx.with_vars(model, mutable=False) = !s}\")\n", + "print(f\"{nnx.with_vars(model, mutable=True) = !s}\")" ] }, { @@ -317,7 +317,7 @@ ], "source": [ "v = nnx.Variable(jnp.array(0))\n", - "v_immut = nnx.vars_as(v, mutable=False)\n", + "v_immut = nnx.with_vars(v, mutable=False)\n", "assert not v_immut.mutable\n", "\n", "try:\n", @@ -355,7 +355,7 @@ ], "source": [ "v = nnx.Variable(jnp.array(0))\n", - "v_ref = nnx.vars_as(v, ref=True)\n", + "v_ref = nnx.with_vars(v, ref=True)\n", "assert v_ref.ref\n", "print(v_ref)\n", "print(v_ref.get_raw_value())" @@ -386,11 +386,11 @@ } ], "source": [ - "v_immut = nnx.vars_as(v_ref, mutable=False)\n", + "v_immut = nnx.with_vars(v_ref, mutable=False)\n", "assert not v_immut.ref\n", "print(\"immutable =\", v_immut)\n", "\n", - "v_ref = nnx.vars_as(v_immut, mutable=True)\n", + "v_ref = nnx.with_vars(v_immut, mutable=True)\n", "assert v_ref.ref\n", "print(\"mutable =\", v_ref)" ] @@ -458,7 +458,7 @@ " model = nnx.merge(graphdef, params, nondiff)\n", " return ((model(x) - y) ** 2).mean()\n", "\n", - " loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad\n", + " loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad\n", " optimizer.update(model, grads)\n", "\n", " return loss\n", @@ -563,9 +563,9 @@ "source": [ "@jax.jit\n", "def create_model(rngs):\n", - " return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)\n", + " return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False)\n", "\n", - "model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)\n", + "model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True)\n", "\n", "print(\"model.linear =\", model.linear)" ] diff --git a/docs_nnx/hijax/hijax.md b/docs_nnx/hijax/hijax.md index 6d6123022..be1a65796 100644 --- a/docs_nnx/hijax/hijax.md +++ b/docs_nnx/hijax/hijax.md @@ -29,7 +29,7 @@ optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param) @jax.jit def train_step(x, y): loss_fn = lambda m: jnp.mean((m(x) - y) ** 2) - loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad + loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad optimizer.update(model, grads) return loss @@ -112,13 +112,13 @@ class Linear(nnx.Module): model = Linear(1, 3, rngs=nnx.Rngs(0)) -print(f"{nnx.vars_as(model, mutable=False) = !s}") -print(f"{nnx.vars_as(model, mutable=True) = !s}") +print(f"{nnx.with_vars(model, mutable=False) = !s}") +print(f"{nnx.with_vars(model, mutable=True) = !s}") ``` ```{code-cell} ipython3 v = nnx.Variable(jnp.array(0)) -v_immut = nnx.vars_as(v, mutable=False) +v_immut = nnx.with_vars(v, mutable=False) assert not v_immut.mutable try: @@ -131,18 +131,18 @@ except Exception as e: ```{code-cell} ipython3 v = nnx.Variable(jnp.array(0)) -v_ref = nnx.vars_as(v, ref=True) +v_ref = nnx.with_vars(v, ref=True) assert v_ref.ref print(v_ref) print(v_ref.get_raw_value()) ``` ```{code-cell} ipython3 -v_immut = nnx.vars_as(v_ref, mutable=False) +v_immut = nnx.with_vars(v_ref, mutable=False) assert not v_immut.ref print("immutable =", v_immut) -v_ref = nnx.vars_as(v_immut, mutable=True) +v_ref = nnx.with_vars(v_immut, mutable=True) assert v_ref.ref print("mutable =", v_ref) ``` @@ -176,7 +176,7 @@ def train_step(model, optimizer, x, y): model = nnx.merge(graphdef, params, nondiff) return ((model(x) - y) ** 2).mean() - loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad + loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad optimizer.update(model, grads) return loss @@ -226,9 +226,9 @@ except Exception as e: ```{code-cell} ipython3 @jax.jit def create_model(rngs): - return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False) + return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False) -model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True) +model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True) print("model.linear =", model.linear) ``` diff --git a/examples/nnx_toy_examples/hijax_basic.py b/examples/nnx_toy_examples/hijax_basic.py index 9ffe2b67d..fde8a76d2 100644 --- a/examples/nnx_toy_examples/hijax_basic.py +++ b/examples/nnx_toy_examples/hijax_basic.py @@ -68,7 +68,7 @@ def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((y - model(x)) ** 2) - grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False)) + grads = jax.grad(loss_fn)(nnx.with_vars(params, is_mutable=False)) optimizer.update(model, grads) @jax.jit diff --git a/examples/nnx_toy_examples/hijax_demo.py b/examples/nnx_toy_examples/hijax_demo.py index 5380fc592..78863483a 100644 --- a/examples/nnx_toy_examples/hijax_demo.py +++ b/examples/nnx_toy_examples/hijax_demo.py @@ -238,7 +238,7 @@ def loss_fn(params): # For the time being we have to use 'immutable' # as 'jax.grad' doesn't support QDD types yet. - grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False)) + grads = jax.grad(loss_fn)(nnx.with_vars(params, is_mutable=False)) # 'update' mutates the optimizer's state and the params in place # so we don't need to return anything 🚀 optimizer.update(params, grads) diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 84a994dd1..fc7a6c3f0 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -75,6 +75,7 @@ from .graphlib import MergeContext as MergeContext from .graphlib import merge_context as merge_context from .graphlib import variables as variables +from .graphlib import with_vars as with_vars from .graphlib import vars_as as vars_as from .graphlib import pure as pure from .graphlib import cached_partial as cached_partial diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index 15e0ecc8a..c9e973fe9 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -26,6 +26,7 @@ from flax import config from flax.nnx import filterlib, reprlib, traversals, variablelib from flax.nnx import statelib +from flax.nnx.deprecations import deprecated from flax.nnx.proxy_caller import ( ApplyCaller, CallableProxy, @@ -1096,6 +1097,7 @@ def unflatten( # type: ignore[invalid-annotation] index_ref: IndexMap | None = None, outer_index_outer_ref: IndexMap | None = None, copy_variables: bool = False, + auto_create_variables: bool = True, ) -> Node: """Unflattens a graphdef into a node with the given state. @@ -1157,6 +1159,7 @@ def unflatten( # type: ignore[invalid-annotation] index_ref, outer_index_outer_ref, copy_variables, + auto_create_variables ) try: @@ -1178,6 +1181,7 @@ def _graph_unflatten( index_ref: IndexMap, outer_index_outer_ref: IndexMap | None, copy_variables: bool, + auto_create_variables: bool ) -> Node: """Recursive helper for graph_unflatten. @@ -1272,7 +1276,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): variable.set_raw_value(value) else: # variabledef.index not in index_ref_cache # variable reference does not exist outside, create a new one - if isinstance(value, Variable): + if isinstance(value, Variable) or not auto_create_variables: variable = value else: variable = variabledef.type.from_metadata( @@ -1321,6 +1325,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]: index_ref, outer_index_outer_ref, copy_variables, + auto_create_variables ) else: raise RuntimeError(f'Unknown node definition: {node_def!r}') @@ -2366,6 +2371,7 @@ def merge( # type: ignore[invalid-annotation] /, *states: tp.Any, copy: bool = False, + auto_create_variables: bool = True, ) -> A: """The inverse of :func:`flax.nnx.split`. @@ -2417,7 +2423,7 @@ def merge( # type: ignore[invalid-annotation] _state = state else: _state = _merge_to_flat_state((state, *states)) - node = unflatten(graphdef, _state, copy_variables=copy) + node = unflatten(graphdef, _state, copy_variables=copy, auto_create_variables=auto_create_variables) return node @@ -2541,6 +2547,7 @@ def map( /, *, graph: bool | None = None, + auto_create_variables: bool = True, ) -> A: """Map a function over the state of a graph node. @@ -2573,8 +2580,11 @@ def map( A :class:`State` with the mapped values. """ graphdef, state = split(node, graph=graph) - state = statelib.map_state(f, state) - return merge(graphdef, state) + if isinstance(state, statelib.State): + state = statelib.map_state(f, state) + else: + state = f((), state) + return merge(graphdef, state, auto_create_variables=auto_create_variables) def graphdef( @@ -2713,7 +2723,7 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N return merge(graphdef, state, copy=variables) -def vars_as( +def with_vars( node: A, /, *, @@ -2751,18 +2761,17 @@ def _different_vars(path, x): duplicates_strs += '\n ---' raise ValueError(f'Found duplicate at paths:{duplicates_strs}') - def _to_refs(jax_path, x): - if predicate(jax_to_nnx_path(jax_path), x): + def _to_refs(path, x): + if predicate(path, x): assert isinstance(x, Variable) variable = x.copy(**new_attrs) return variable return x - node = jax.tree.map_with_path( - _to_refs, node, is_leaf=lambda x: isinstance(x, Variable) - ) + node = map(_to_refs, node, auto_create_variables=False) return node +vars_as = deprecated(with_vars) def pure(tree: A) -> A: """Returns a new tree with all ``Variable`` objects replaced with inner values. diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index c53a71314..593205779 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -1061,4 +1061,4 @@ def _maybe_int(x): return x def _get_str(x): - return x if isinstance(x, str) else str(x) \ No newline at end of file + return x if isinstance(x, str) else str(x) diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 97bbf4b81..fe1429695 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -1699,6 +1699,14 @@ def test_map_replace(self): np.testing.assert_array_equal(new_model.kernel[...], jnp.zeros((2, 3))) np.testing.assert_array_equal(new_model.bias[...], jnp.zeros((3,))) + def test_map_auto_create_variables_false(self): + rngs = nnx.Rngs(0) + new_rngs = nnx.map( + lambda path, x: 0, rngs, auto_create_variables=False + ) + self.assertNotIsInstance(new_rngs.default.count, nnx.Variable) + self.assertEqual(new_rngs.default.count, 0) + if __name__ == '__main__': absltest.main() diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py index 6f16842cb..e4b3e1f0f 100644 --- a/tests/nnx/integration_test.py +++ b/tests/nnx/integration_test.py @@ -487,7 +487,7 @@ def loss_fn(params): return ((model(x) - y) ** 2).mean() # call methods directly loss, grads = jax.value_and_grad(loss_fn)( - nnx.vars_as(params, hijax=False) + nnx.with_vars(params, hijax=False) ) optimizer.update(model, grads) # in-place updates diff --git a/tests/nnx/mutable_array_test.py b/tests/nnx/mutable_array_test.py index 570d6910b..6597d1f20 100644 --- a/tests/nnx/mutable_array_test.py +++ b/tests/nnx/mutable_array_test.py @@ -29,7 +29,7 @@ def __init__(self): self.node = jnp.array(1) self.meta = 1 - m = nnx.vars_as(Foo(), ref=True) + m = nnx.with_vars(Foo(), ref=True) m = jax.tree.map(lambda x: x + 1, m) @@ -126,17 +126,17 @@ def test_split_mutable_array(self): def test_to_arrays_example(self): node = [nnx.Variable(1.0), nnx.Variable(2.0, mode='ref')] - mutable_node = nnx.vars_as(node, ref=True) + mutable_node = nnx.with_vars(node, ref=True) assert isinstance(mutable_node[0].get_raw_value(), jax.Ref) assert isinstance(mutable_node[1].get_raw_value(), jax.Ref) shared_array = nnx.Variable(1.0, mode='pytree') node = [shared_array, shared_array] with self.assertRaisesRegex(ValueError, 'Found duplicate at path'): - nnx.vars_as(node, ref=True) + nnx.with_vars(node, ref=True) node = [nnx.Variable(1.0), nnx.Variable(2.0)] - mutable_node = nnx.vars_as( + mutable_node = nnx.with_vars( node, ref=True, only=lambda path, x: path[0] == 0 ) assert isinstance(mutable_node[0].get_raw_value(), jax.Ref) @@ -148,18 +148,18 @@ def __init__(self): self.a = nnx.Param(1) self.b = nnx.BatchStat(2) - m = nnx.vars_as(Foo(), hijax=True, ref=True) + m = nnx.with_vars(Foo(), hijax=True, ref=True) self.assertEqual(m.a.ref, True) self.assertEqual(m.b.ref, True) - m2 = nnx.vars_as(m, hijax=False, only=nnx.BatchStat) + m2 = nnx.with_vars(m, hijax=False, only=nnx.BatchStat) self.assertEqual(m2.a.ref, True) self.assertEqual(m2.a.hijax, True) self.assertEqual(m2.b.ref, True) self.assertEqual(m2.b.hijax, False) self.assertIsNot(m, m2) - m3 = nnx.vars_as(m2, hijax=True, only=nnx.BatchStat) + m3 = nnx.with_vars(m2, hijax=True, only=nnx.BatchStat) self.assertEqual(m3.a.ref, True) self.assertEqual(m3.b.ref, True) self.assertEqual(m3.b.hijax, True) @@ -175,7 +175,7 @@ def __init__(self): m = Foo() with self.assertRaisesRegex(ValueError, 'Found duplicate at path'): - nnx.vars_as(m, ref=True) + nnx.with_vars(m, ref=True) def test_mutable_array_split(self): class Foo(nnx.Module): @@ -233,7 +233,7 @@ def test_mutable_example(self): tree = [nnx.Variable(1.0), nnx.Variable(2.0, ref=True)] assert tree[0].ref == False assert tree[1].ref == True - mutable_tree = nnx.vars_as(tree, ref=True) + mutable_tree = nnx.with_vars(tree, ref=True) assert isinstance(mutable_tree[0].get_raw_value(), jax.Ref) assert isinstance(mutable_tree[1].get_raw_value(), jax.Ref) @@ -247,7 +247,7 @@ def __init__(self): ref_map = nnx.graphlib.RefMap() graphdef, state = nnx.graphlib.flatten(m, ref_index=ref_map, graph=True) - state = nnx.vars_as(state, hijax=False) + state = nnx.with_vars(state, hijax=False) self.assertLen(state, 1) m1 = nnx.merge(graphdef, state) @@ -255,7 +255,7 @@ def __init__(self): self.assertIsInstance(m1.a, jax.Ref) def test_update_context(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.split(m1) @@ -263,7 +263,7 @@ def test_update_context(self): with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.split((m2, m_out1, m2)) @@ -290,7 +290,7 @@ def test_update_context(self): self.assertIsNot(m_out2, m_out1) def test_update_context_flatten(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.flatten(m1) @@ -298,7 +298,7 @@ def test_update_context_flatten(self): with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.flatten((m2, m_out1, m2)) @@ -327,13 +327,13 @@ def test_update_context_flatten(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree1(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) @@ -366,13 +366,13 @@ def test_update_context_to_tree1(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree2(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example') as ctx: m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) @@ -405,13 +405,13 @@ def test_update_context_to_tree2(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree_trivial_prefix(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example', prefix=0) (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True, prefix=0) - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) @@ -444,13 +444,13 @@ def test_update_context_to_tree_trivial_prefix(self): self.assertIsNot(m_out2, m_out1) def test_simple_jit(self): - m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) m_out1 = None @nnx.jit def f(m2): nonlocal m_out1 - m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) + m_out1 = nnx.with_vars(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) return m_out1 m_out2 = f(m1) @@ -516,7 +516,7 @@ def __init__(self, din: int, dout: int): self.count = nnx.Variable(jnp.array(0)) params = Params(3, 4) - params = nnx.vars_as(params, ref=True) + params = nnx.with_vars(params, ref=True) paths_leaves, treedef = jax.tree.flatten_with_path(params) paths, leaves = zip(*paths_leaves) @@ -682,7 +682,7 @@ def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((model(x) - y) ** 2) - loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, hijax=False)) + loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, hijax=False)) optimizer.update(params, grads) return loss @@ -693,7 +693,7 @@ def loss_fn(params): class TestHijaxVariables(parameterized.TestCase): def test_variable_to_hijax(self): v_low = nnx.Param(jnp.array(1), a='hi') - v_hi = nnx.vars_as(v_low, hijax=True) + v_hi = nnx.with_vars(v_low, hijax=True) self.assertTrue(v_hi.hijax) self.assertEqual(v_hi[...], 1) @@ -715,7 +715,7 @@ def set(v_hi, a): self.assertEqual(v_hi[...], 15) self.assertEqual(y, 17) - v_low = nnx.vars_as(v_hi, hijax=False) + v_low = nnx.with_vars(v_hi, hijax=False) self.assertIsInstance(v_low, nnx.Param) self.assertFalse(v_low.hijax) self.assertEqual(v_low[...], 15) @@ -741,7 +741,7 @@ def test_variable_to_hijax_clean(self): print() print(v_low) assert not v_low.hijax - v_hi = nnx.vars_as(v_low, hijax=True) + v_hi = nnx.with_vars(v_low, hijax=True) v_hi[...] = jnp.array([2]) assert v_hi.hijax print(v_hi) @@ -757,7 +757,7 @@ def set(v_hi, a): assert v_hi[...] == 10 - v_low = nnx.vars_as(v_hi, hijax=False) + v_low = nnx.with_vars(v_hi, hijax=False) assert not v_low.hijax assert v_low[...] == 10 @@ -801,7 +801,7 @@ def __init__(self, din, dout, rngs: nnx.Rngs): assert not foo.w.hijax assert not foo.b.hijax - foo = nnx.vars_as(foo, hijax=True) + foo = nnx.with_vars(foo, hijax=True) assert foo.w.hijax assert foo.b.hijax @@ -949,7 +949,7 @@ def test_variable_copy_properties(self, hijax, ref): def test_variable_vars_as_properties(self, hijax, ref): v_original = nnx.Variable(jnp.array(1)) - v = nnx.vars_as(v_original, hijax=hijax, ref=ref) + v = nnx.with_vars(v_original, hijax=hijax, ref=ref) self.assertEqual(v.hijax, hijax) self.assertEqual(v.ref, ref) if hijax: @@ -1213,4 +1213,3 @@ def f(v, v2): if __name__ == '__main__': absltest.main() -