Is there a good example of sharing params across linen modules? #1144
marcvanzee asked this question in Q&A.
Original question by @johnpjf: "I have a single encoder model (that's basically a few network layers) and I pass it into K other modules that each have a separate learned image of embeddings; they look up an embedding and then apply the same encoder model."
Answered by marcvanzee on Mar 17, 2021.
Answer by @levskaya: Just pass in an instantiated module. Here's a tiny illustration:

```python
from jax import random
import flax.linen as nn


class UsesModule(nn.Module):
    submodule: nn.Module

    @nn.compact
    def __call__(self, x):
        return self.submodule(x)


class Top(nn.Module):
    depth: int

    @nn.compact
    def __call__(self, x):
        dense = nn.Dense(self.depth)  # instantiated once; owned by Top
        submoduleA = UsesModule(dense)
        submoduleB = UsesModule(dense)
        x = submoduleA(x)
        # Same Dense params again, so depth must equal x's last dim (5 here).
        x = submoduleB(x)
        return x


key1, key2 = random.split(random.PRNGKey(0))
x = random.uniform(key1, (4, 5))
variables = Top(5).init(key2, x)
y = Top(5).apply(variables, x)
# Note that there's only a single place where "Dense_0" is defined, but
# it is used in multiple places.
print(variables)
```