How to swap in a new variable tree for a given collection in a submodule in Flax/JAX? #998
Answered. Asked by marcvanzee in Q&A.
Original question by @mts42000: How can a parent module modify a variable (for a whole VariableDict and nested inner modules) in one of its submodules?
Answered by marcvanzee on Feb 5, 2021. Replies: 1 comment.
Answer by @levskaya: I don't think we've actually thought (beyond that "mutate" idea) about the best way to expose this "officially". We don't have a perfectly elegant bulk "collection update" call. We can definitely add something like that, but we should also just take care of a `vjp` wrapper that has some smart options for handling state. One way to currently do it with the internal API:

```python
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn


class Bar(nn.Module):
  def setup(self):
    # Two variables in the non-"params" collection 'vars'.
    self.b = self.variable('vars', 'b', lambda: jnp.zeros((2,)))
    self.c = self.variable('vars', 'c', lambda: jnp.zeros((2,)))

  def __call__(self, x):
    return x + self.b.value + self.c.value


class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    bar = Bar()
    # Grab the submodule's collection via the private Scope API; it must be mutable.
    col = bar.scope._mutable_collection('vars')
    # Do something to it, here: add 1 to every leaf.
    newcol = jax.tree_util.tree_map(lambda v: v + 1, col)
    # Write the mutated collection back into the submodule's scope.
    for k in newcol:
      bar.scope._variables['vars'][k] = newcol[k]
    return bar(x)


k = random.PRNGKey(0)
x = jnp.zeros((2,))
foo = Foo()
p = foo.init(k, x)
# 'vars' must be mutable during apply so the collection can be rewritten.
y, vs = foo.apply(p, x, mutable=['vars'])
```
Answer selected by marcvanzee.