A Global Average Pooling Function? #2617
-
Having recently implemented a copy of the ResNet-RS model, I noticed that Flax does not have a function for global average pooling, so I ended up implementing my own function for use in Flax models. I'm just wondering whether it might be a useful addition to Flax?
Replies: 2 comments 4 replies
-
I am not very familiar with global average pooling, but doesn't this just reduce to taking the average over all inputs for each feature dimension? Could you please share your code?
-
Hey @codymlewis, global average pooling is just mean over the spatial/temporal dims, e.g.:
x = jnp.mean(x, axis=(1, 2)) # assuming (batch, height, width, channels)
I don't see a real benefit of adding a GlobalAveragePooling Module. Currently nn.max_pool and nn.avg_pool are just functions, so it would make sense for this operation to also be a function. jnp.mean is already that function, so there would need to be some additional motivation to create a dedicated wrapper.
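For anyone who still wants a named helper in their own code, here is a minimal sketch of what such a wrapper could look like. The name global_avg_pool and its spatial_axes parameter are hypothetical (they are not part of Flax), and the input is assumed to be laid out as (batch, height, width, channels):

```python
import jax.numpy as jnp


def global_avg_pool(x, spatial_axes=(1, 2)):
    """Hypothetical helper: average over the spatial axes of x.

    Assumes x has layout (batch, height, width, channels), so the
    default axes (1, 2) are height and width.
    """
    return jnp.mean(x, axis=spatial_axes)


# Toy input: batch of 2 feature maps, 4x4 spatial extent, 3 channels.
x = jnp.arange(2 * 4 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 4, 3)

pooled = global_avg_pool(x)
print(pooled.shape)  # (2, 3): one mean per example and channel
```

For a different layout, such as (batch, length, channels) for 1D inputs, the same helper should work with spatial_axes=(1,). Since the body is a single jnp.mean call, calling jnp.mean directly inside a Module's __call__ is equivalent.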