Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add tabulate #4493

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

[nnx] add tabulate #4493

wants to merge 1 commit into from

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Jan 20, 2025

What does this PR do?

Adds nnx.tabulate with a similar API to linen.tabulate.

Example:

class Block(nnx.Module):
  def __init__(self, din, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)
    self.bn = nnx.BatchNorm(dout, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.dropout(self.bn(self.linear(x))))

class Foo(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.block1 = Block(32, 128, rngs=rngs)
    self.block2 = Block(128, 10, rngs=rngs)

  def __call__(self, x):
    return self.block2(self.block1(x))

foo = Foo(nnx.Rngs(0))

print(nnx.tabulate(foo, jnp.ones((1, 32))))
Screenshot 2025-01-21 at 12 55 36

@cgarciae cgarciae force-pushed the nnx-tabulate-2 branch 3 times, most recently from d47f9d2 to f0dcac7 Compare January 21, 2025 19:42
@cgarciae cgarciae marked this pull request as ready for review January 21, 2025 20:06
@cgarciae cgarciae force-pushed the nnx-tabulate-2 branch 2 times, most recently from 17c6ed7 to e245fab Compare January 24, 2025 00:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants