
Is there a good example of sharing params across linen modules? #1144

Answered by marcvanzee
marcvanzee asked this question in Q&A

Answer by @levskaya:

Just pass in an instantiated module. Here's a tiny illustration:

from jax import random
from flax import linen as nn

class UsesModule(nn.Module):
  submodule: nn.Module  # an already-instantiated module, called as-is

  @nn.compact
  def __call__(self, x):
    return self.submodule(x)

class Top(nn.Module):
  depth: int
  @nn.compact
  def __call__(self, x):
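    dense = nn.Dense(self.depth)  # instantiated once, in Top's scope
    submoduleA = UsesModule(dense)  # both wrappers receive the same instance,
    submoduleB = UsesModule(dense)  # so they share its kernel and bias
    x = submoduleA(x)
    x = submoduleB(x)
    return x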

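key1, key2 = random.split(random.PRNGKey(0))
x = random.uniform(key1, (4, 5))

# depth must equal the input's feature size (5 here): the shared Dense is
# applied to its own output, so its kernel has to be square.
variables = Top(5).init(key2, x)
y = Top(5).apply(variables, x)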

# Note that "Dense_0" is defined in only one place (Top's scope) but is
# used in several places: both UsesModule instances call the same parameters.

Answer selected by marcvanzee