How to swap in a new variable tree for a given collection in a submodule in Flax/JAX? #998

Answered by marcvanzee
marcvanzee asked this question in Q&A

Answer by @levskaya:

I don't think we've actually thought (beyond that "mutate" idea) about the best way to expose this "officially". We don't have a perfectly elegant bulk "collection update" call. We can definitely add something like that, but we should also take care of a vjp wrapper that has some smart options for handling state.

One way to do it currently, using the internal API:

import jax.numpy as jnp
import flax.linen as nn

class Bar(nn.Module):
  def setup(self):
    # Two variables stored in a custom collection named 'vars'.
    self.b = self.variable('vars', 'b', lambda: jnp.zeros((2,)))
    self.c = self.variable('vars', 'c', lambda: jnp.zeros((2,)))
  def __call__(self, x):
    return x + self.b.value + self.c.value

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    bar = Bar

marcvanzee
Feb 5, 2021
Maintainer Author

You must be logged in to vote
0 replies
Answer selected by marcvanzee