Batch size during init #18
Replies: 1 comment
Hi @shashank2000, for initialization the batch size does not matter, since the dimensions of the weights do not depend on it. With `dense = nn.Dense(features=4)` you are telling Flax that your layer should have 4 output neurons. Flax still needs to know how many input neurons you have in order to determine the dimensions of the weights. Calling `params = dense.init(key, jnp.ones((1, 3)))` tells Flax that there will be 3 input neurons. `params = dense.init(key, jnp.ones((10, 3)))` will produce the same parameters, as long as the number of features is the same. Hope that helps!
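To illustrate why the batch dimension drops out, here is a minimal NumPy sketch (not actual Flax code) of a dense layer: the kernel shape is `(in_features, out_features)`, so it is fixed by the feature dimension of the example input, never by its batch size.

```python
import numpy as np

def dense_init(rng, in_features, out_features):
    # The kernel shape depends only on (in_features, out_features);
    # the batch size of the example input plays no role.
    kernel = rng.normal(size=(in_features, out_features))
    bias = np.zeros(out_features)
    return kernel, bias

def dense_apply(params, x):
    kernel, bias = params
    # Matrix multiply broadcasts over the leading batch dimension.
    return x @ kernel + bias

rng = np.random.default_rng(0)
params = dense_init(rng, in_features=3, out_features=4)

# "Init" with batch size 1, then apply to a batch of 10 --
# the same parameters work for any batch size.
y1 = dense_apply(params, np.ones((1, 3)))
y10 = dense_apply(params, np.ones((10, 3)))
print(params[0].shape)        # (3, 4)
print(y1.shape, y10.shape)    # (1, 4) (10, 4)
```

The same reasoning carries over to Flax: `init` only uses the example input to trace shapes, so `jnp.ones((1, 3))` and `jnp.ones((10, 3))` yield identical parameter shapes.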
Hey @matthias-wright
Quick question: I was wondering why the ResNet model was initialized with a batch size of 1. More generally, if I were to use one of the pretrained models here and then add a linear layer, what would the initialization process look like?
I'm assuming I'd define a new linen module consisting of, say, a ResNet and a `Dense` layer. Then I would initialize that module with some `x` with batch size 1? Would this be true even if the batch size in my downstream model is >1? I'm a little confused because in the train loop in the same ResNet file above, batches do seem to be passed through the forward pass of the model with `apply`, but the `init` call for the same model uses an `x` whose first dim is just 1.
Let me know if that wasn't clear, and thanks very much in advance.