Remat (aka checkpoint) with tied variables from custom variable collections #3673
Replies: 1 comment 1 reply
-
Hey! Sorry for the late reply. It seems you managed to solve the issue?
-
Hi.
I am working on a custom model in JAX/Flax, and I use a custom variable collection to pass layer-specific, input-dependent configuration to the model. To be more precise, I pass an attention mask that can be defined layer by layer and is input-specific.
To save memory in cases where I want this info to be shared between layers, I tied the variables using `nn.map_variables` and a tie method inspired by this comment. It works as expected in most cases, but when I activate gradient checkpointing with the `nn.partitioning.remat` function, a warning is emitted whenever I train the model. I checked the code in Flax and couldn't find the origin of this warning (it doesn't throw an error). I am not sure whether this affects the variable values during gradient checkpointing, and if it does, whether there is a way to fix it.
Do you have any ideas to help me please?