FIX: Issues with saving/loading with accelerate #1008
Conversation
Description

There were a few issues with saving and loading parameters for an accelerated net (some only in a multi-GPU setting):

- `load_params` set the device to CPU if `device=None`, but accelerate needs `None`
- not waiting for all processes to finish before saving parameters
- all processes saving the parameters, when only the main process should
- an issue with parameter names depending on the wrapping state

Regarding the last point, the issue was that if the module(s) are wrapped with accelerate, the parameters have an additional prefix, `"module."`. So e.g. `"dense0.weight"` would become `"module.dense0.weight"`. This could result in a key mismatch and an error when loading a model. The solution is to always unwrap the net before saving and before loading. That way, the extra prefix is never present and there is no mismatch.

A test was added to check this behavior; it verifies the correct behavior on CPU, iterating through all 4 combinations of wrapping/not wrapping the initial/loaded model. Since the GitHub CI does not offer multi-GPU support, the unit test does not cover all failure cases. Therefore, I added a script, `examples/accelerate-multigpu/run-save-load.py`, that can be run on a multi-GPU setup to test the issue.

Implementation

The changes needed were often just a few lines that sync the processes or place a guard to only run on the main process. A few of these were quite unintuitive to me, so I added a comment for them.

The one big change is that the preparation of the components by the accelerator is now moved to a separate method, `_initialize_accelerator`. This makes it possible to unwrap and re-wrap the model with a single method call each. Without that change, re-wrapping was only possible by calling `net.initialize()`, which would re-initialize everything, which is not desired.
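The synchronization and main-process guard mentioned above can be sketched in plain Python. This is a hypothetical stand-in for accelerate's `wait_for_everyone()` and `is_main_process` (the function and variable names here are illustrative, not skorch's actual code); in reality the barrier blocks until every process arrives, which is only simulated here:

```python
# Sketch of the saving protocol: every process must reach the save point
# before the parameters are written, and only the main process writes them.
saved_by = []       # records which process indices actually saved
barrier_calls = []  # records how many processes reached the barrier

def save_params(process_idx, is_main, barrier):
    barrier()        # wait for all processes to finish their work first
    if not is_main:  # guard: only the main process writes the file
        return
    saved_by.append(process_idx)

# Simulate 4 processes; process 0 plays the role of the main process.
for idx in range(4):
    save_params(idx, is_main=(idx == 0), barrier=lambda: barrier_calls.append(1))

assert saved_by == [0]          # only the main process saved
assert len(barrier_calls) == 4  # but every process hit the barrier
```

Without the guard, all four processes would race to write the same file; without the barrier, the main process could save before the others finished updating the shared parameters.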
This change can be backwards incompatible: if a user saved the parameters of an accelerated net while it was still wrapped (not the default) and tries to load them into a wrapped net, it will no longer work. I think this case is rare enough that we can accept it. In the worst case, the user can still apply the state dict manually on the wrapped net.

I did consider an alternative solution that would inspect the names of the keys in the state dict, try to determine from them whether the loaded/current weights come from a wrapped model, and rename the keys of the state dict accordingly. However, this approach is unreliable and not easy to implement with the current code, so I opted for the solution described above.

Out of scope

Also note that this PR does _not_ fix potential issues that might occur when checkpointing the model while it is training on multiple GPUs. For this, we need to use `accelerator.{save_state,load_state}`, see here: https://huggingface.co/docs/accelerate/usage_guides/checkpoint

The usage would be distinct enough from that of the existing `Checkpoint` callback that this use case is probably best served by a separate checkpoint callback; using `Checkpoint` would be discouraged for it.
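To illustrate the key mismatch discussed above, here is a minimal sketch in plain Python (the parameter names are made up for illustration) of how wrapping prefixes the state-dict keys and why stripping the prefix, as the rejected key-renaming alternative would have to do, restores the match:

```python
# Hypothetical illustration: wrapping a model (as accelerate, like torch's
# DistributedDataParallel, does) nests it under a "module" attribute, so
# every state-dict key gains a "module." prefix.

UNWRAPPED_KEYS = ["dense0.weight", "dense0.bias", "output.weight"]

def wrapped_keys(keys):
    """Keys as they appear once the net has been wrapped."""
    return ["module." + k for k in keys]

def strip_prefix(keys, prefix="module."):
    """Rename keys from a wrapped state dict back to the plain names."""
    return [k[len(prefix):] if k.startswith(prefix) else k for k in keys]

# Loading wrapped keys into an unwrapped net fails the name lookup ...
assert wrapped_keys(UNWRAPPED_KEYS) != UNWRAPPED_KEYS
# ... but if both sides are unwrapped before save/load, the names line up.
assert strip_prefix(wrapped_keys(UNWRAPPED_KEYS)) == UNWRAPPED_KEYS
```

The PR avoids this renaming logic entirely by always unwrapping before saving and loading, so the prefix never appears in the saved state dict in the first place.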
This change can be backwards incompatible:
Given that the AccelerateMixin
is experimental, I am okay with breaking changes.
skorch/hf.py (Outdated)

```python
prev_device = self.device
...
if not self._wrapped_with_accelerator:
    super().load_params(*args, **kwargs)
```
From reading the code, setting `device=None` will raise a warning and fall back to CPU. Does it make sense to temporarily set `self.device = "cpu"` before calling `super().load_params`?
That's a good point. I made the change you suggested, and also wrapped the whole `super().load_params(...)` call in a `try ... finally` to undo the change if something fails. It's not beautiful, but better safe than sorry.

The test was updated to check that there is no warning.
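The fix described in this reply can be sketched as follows. This is a simplified stand-in (the `DummyNet` class and helper name are hypothetical, not skorch's actual code); the point is the temporary device swap and the `try ... finally` that restores it even on failure:

```python
class DummyNet:
    """Stand-in for a skorch net; only the attributes used here."""
    def __init__(self, device):
        self.device = device
        self.loaded_with = None

    def _load(self):
        # The real super().load_params would warn and fall back to CPU
        # when self.device is None; here we just record what it saw.
        self.loaded_with = self.device

def load_params_cpu_guard(net):
    """Temporarily set device='cpu' when it is None, then restore it."""
    prev_device = net.device
    if net.device is None:
        net.device = "cpu"
    try:
        net._load()
    finally:
        net.device = prev_device  # runs even if loading raises

net = DummyNet(device=None)
load_params_cpu_guard(net)
assert net.loaded_with == "cpu"  # no None device reached the loader
assert net.device is None        # original value restored afterwards

class FailingNet(DummyNet):
    def _load(self):
        raise ValueError("simulated load failure")

net2 = FailingNet(device=None)
try:
    load_params_cpu_guard(net2)
except ValueError:
    pass
assert net2.device is None  # restored even when loading failed
```

The `finally` clause is what makes the swap safe: whether loading succeeds or raises, the net never ends up stuck with a device it was not configured with.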
Reviewer feedback: To prevent a confusing warning message, temporarily set the device to 'cpu' when it was None during load_params. After loading has finished, set it back to the original value.
LGTM