
FIX: Issues with saving/loading with accelerate #1008

Merged
merged 3 commits into master on Aug 18, 2023

Conversation

BenjaminBossan
Collaborator

Description

There were a few issues with saving and loading parameters for an accelerated net (some only in a multi-GPU setting, though):

  • load_params fell back to CPU when device=None, but the device needs to stay None for accelerate
  • not waiting for all processes to finish before saving parameters
  • all processes saving the parameters, when only the main process should
  • an issue with parameter names depending on the wrapping state

Regarding the last point, the issue was that if the module(s) are wrapped with accelerate, the parameters have an additional prefix, "module.". So e.g. "dense0.weight" would become "module.dense0.weight". This could result in a key mismatch and error when loading a model.

The solution to this is to always unwrap the net before saving and before loading. That way, the extra prefix is never present and there is no mismatch.
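For illustration, here is a minimal sketch of the prefix problem and the unwrapping idea, using the accelerate API directly rather than skorch (the fix applies the same idea inside save_params/load_params; "params.pt" is a placeholder path):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(10, 2)
wrapped = accelerator.prepare(model)

# In a distributed run the wrapper adds a "module." prefix to parameter
# names, so wrapped and unwrapped state dicts no longer share keys.
# Unwrapping before saving (and before loading) keeps the keys prefix-free.
unwrapped = accelerator.unwrap_model(wrapped)
torch.save(unwrapped.state_dict(), "params.pt")

# Loading the prefix-free state dict into a plain module then works.
model.load_state_dict(torch.load("params.pt"))
```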

A test was added to check this behavior, but since the GitHub CI does not offer multi-GPU support, it does not test for all failure cases. Therefore, I added a script, examples/accelerate-multigpu/run-save-load.py, that can be run on a multi-GPU setup to test the issue.

This unit test checks the correct behavior on CPU, iterating through all 4 combinations of wrapping/not wrapping the initial/loaded model.

Implementation

The changes needed were often just a few lines that sync the processes or place a guard to only run on the main process. A few of these were quite unintuitive to me, so I added a comment for them.
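As a rough illustration of those few-line changes (not the exact skorch diff), the typical pattern with the accelerate API looks like this:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(10, 2))

# Sync: wait until every process has reached this point before saving.
accelerator.wait_for_everyone()
# Guard: only the main process should write the file.
if accelerator.is_main_process:
    state = accelerator.unwrap_model(model).state_dict()
    torch.save(state, "params.pt")
```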

The one big change is that the preparation of the components by the accelerator is now moved to a separate method, _initialize_accelerator. This way, it is now possible to unwrap and re-wrap the model with a single method call each. Without that change, re-wrapping was only possible by calling net.initialize(), which would re-initialize everything and is hence not desired.
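A schematic of how the save path can now look; only _initialize_accelerator is named in this PR, the unwrapping helper and class layout below are illustrative, not the actual skorch source:

```python
from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin


class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    def save_params(self, *args, **kwargs):
        self._unwrap()  # hypothetical helper that removes the accelerate wrappers
        try:
            super().save_params(*args, **kwargs)
        finally:
            # Re-wrap the components without re-initializing the whole net.
            self._initialize_accelerator()
```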

This change can be backwards incompatible: If a user saved the parameters of an accelerated net while it's still wrapped (not the default), and tries to load it into a wrapped net, it will no longer work. I think this case is rare enough that we can accept it. In the worst case, the user can still apply the state dict manually on the wrapped net.
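For that worst case, the manual workaround could look roughly like this, assuming the wrapped module of an already-initialized net is accessible as net.module_ and "old_params.pt" is a placeholder for the previously saved file:

```python
import torch

# Load a state dict that was saved while the net was still wrapped and
# apply it directly to the wrapped module, bypassing net.load_params.
state_dict = torch.load("old_params.pt")
net.module_.load_state_dict(state_dict)
```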

I did consider an alternative solution that would inspect the names of the keys in the state dict and try to determine from those if the loaded/current weights are from a wrapped model or not, and consequently rename the keys of the state dict. However, this method is unreliable and also not easy to implement with the current code, so I opted for the solution described above.

Out of scope

Also note that this PR does not fix potential issues that might occur during checkpointing of the model while it's training with multiple GPUs. For this, we need to use accelerator.{save_state,load_state}, see here:

https://huggingface.co/docs/accelerate/usage_guides/checkpoint

The usage would be distinct enough from the one of the existing Checkpoint callback that this use case is probably best served with a separate checkpoint callback. Using Checkpoint would be discouraged for that use case.
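For reference, that usage would look roughly like this with plain accelerate (a sketch based on the linked docs, not part of this PR; "checkpoint-dir" is a placeholder):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

# Save the full training state (model, optimizer, RNG states, ...) ...
accelerator.save_state("checkpoint-dir")
# ... and restore it later to resume training.
accelerator.load_state("checkpoint-dir")
```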

Member

@thomasjpfan thomasjpfan left a comment

This change can be backwards incompatible:

Given that the AccelerateMixin is experimental, I am okay with breaking changes.

skorch/hf.py Outdated
prev_device = self.device

if not self._wrapped_with_accelerator:
    super().load_params(*args, **kwargs)
Member

From reading the code, setting device=None will raise a warning and fall back to CPU. Does it make sense to temporarily set self.device="cpu" before calling super().load_params?

Collaborator Author

That's a good point. I made the change you suggested and also wrapped the whole super().load_params(...) call in a try ... finally to undo the change if something fails. It's not beautiful, but better safe than sorry.

The test was updated to check that there is no warning.

Reviewer feedback: To prevent a confusing warning message, temporarily
set the device to 'cpu' when it was None during load_params. After
loading has finished, set it back to the original value.
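A simplified sketch of the resulting pattern (illustrative, not the exact diff):

```python
def load_params(self, *args, **kwargs):
    prev_device = self.device
    if self.device is None:
        # Avoid the device=None warning by temporarily loading on CPU.
        self.device = "cpu"
    try:
        super().load_params(*args, **kwargs)
    finally:
        # Restore the original device value even if loading fails.
        self.device = prev_device
```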
Member

@thomasjpfan thomasjpfan left a comment

LGTM

@thomasjpfan thomasjpfan merged commit 07fc260 into master Aug 18, 2023
13 checks passed
@BenjaminBossan BenjaminBossan deleted the FIX-save-and-load-with-accelerate branch August 30, 2023 13:43