-
Digging in the docs, it seems like I may be using
But my function above is a lambda that accepts input as its first argument. So a better question than the one above may be "how do I thread
-
Answering my own question :)

```python
import flax.linen as nn


class Simple(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Lift nn.Dense so it is mapped over the second-to-last axis of x.
        DenseVMapped = nn.vmap(
            nn.Dense,
            variable_axes={"params": 0},  # stack per-slice params along axis 0
            split_rngs={"params": True},  # give each slice its own init RNG
            in_axes=-2,
        )
        x = DenseVMapped(features=42, name="foo")(x)
        return x
```

I couldn't find something like this in the docs. Maybe updating one of the lifted-transform examples to show how to thread arguments would help? At any rate, sorry for the noise!
-
Hi all!
I'm a newb and I'm using `nn.vmap` for the first time. When I run:

I'm getting:
As this is my first time working with `vmap` and `flax` in general, I'm probably missing something obvious, but I can't quite figure it out from the error message. Any tips would be much appreciated! 🙏