FIX: Issues with saving/loading with accelerate (#1008)
* FIX: Issues with saving/loading with accelerate

Description

There were a few issues with saving and loading parameters for an accelerated net (some only in a multi-GPU setting):

- load_params set the device to CPU if device=None, but we need None
- processes were not synchronized before saving parameters
- all processes saved the parameters, when only the main process should
- an issue with parameter names depending on the wrapping state

Regarding the last point: if the module(s) are wrapped with accelerate, the parameters have an additional prefix, "module.". So e.g. "dense0.weight" becomes "module.dense0.weight". This could result in a key mismatch and an error when loading a model. The solution is to always unwrap the net before saving and before loading. That way, the extra prefix is never present and there is no mismatch.

A test was added to check this behavior, but since the GitHub CI does not offer multi-GPU support, it does not cover all failure cases. Therefore, I added a script, examples/accelerate-multigpu/run-save-load.py, that can be run on a multi-GPU setup to test the issue. The unit test checks the correct behavior on CPU, iterating through all 4 combinations of wrapping/not wrapping the initial/loaded model.

Implementation

The changes needed were often just a few lines that sync the processes or place a guard so that code only runs on the main process. A few of these were quite unintuitive to me, so I added comments for them. The one big change is that the preparation of the components by the accelerator is now moved to a separate method, _initialize_accelerator. This makes it possible to unwrap and re-wrap the model with a single method call each. Without that change, re-wrapping was only possible by calling net.initialize(), which would re-initialize everything, which is not desired.
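The key mismatch described above can be sketched in plain Python. The helper below is hypothetical and for illustration only (plain dicts stand in for real torch state dicts); skorch's actual fix unwraps the module before saving/loading rather than renaming keys:

```python
# Illustration of the "module." prefix that accelerate's wrapping adds to
# state-dict keys, and a hypothetical helper that strips it. Not skorch's
# actual code, which avoids the mismatch by unwrapping the module instead.

def strip_module_prefix(state_dict, prefix="module."):
    """Remove the wrapper prefix from all keys of a state dict."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Keys as they appear while the net is wrapped with accelerate:
wrapped = {"module.dense0.weight": 1.0, "module.dense0.bias": 2.0}

# Keys an unwrapped net expects:
unwrapped = strip_module_prefix(wrapped)
assert set(unwrapped) == {"dense0.weight", "dense0.bias"}
```

This renaming approach is exactly the alternative the commit message rejects as unreliable, which is why the implemented fix unwraps the module instead.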
This change can be backwards incompatible: if a user saved the parameters of an accelerated net while it was still wrapped (not the default) and tries to load them into a wrapped net, it will no longer work. I think this case is rare enough that we can accept it. In the worst case, the user can still apply the state dict manually to the wrapped net.

I did consider an alternative solution that would inspect the names of the keys in the state dict, try to determine from those whether the loaded/current weights come from a wrapped model, and rename the keys of the state dict accordingly. However, this method is unreliable and also not easy to implement with the current code, so I opted for the solution described above.

Also note that this PR does _not_ fix potential issues that might occur when checkpointing the model while it is training. For this, we need to use accelerator.{save_state,load_state}, see here: https://huggingface.co/docs/accelerate/usage_guides/checkpoint. This use case is probably best served by a separate checkpoint callback.

* Changelog entry, references to the GH PR

* Temporarily set device in load_params if necessary

Reviewer feedback: to prevent a confusing warning message, temporarily set the device to 'cpu' when it was None during load_params. After loading has finished, set it back to the original value.
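The reviewer-requested device handling can be sketched as a temporary attribute override. All names here (temporary_device, FakeNet) are hypothetical, for illustration; skorch's real load_params handles this inline and covers more cases:

```python
# Sketch of temporarily setting device='cpu' while it is None, then
# restoring the original value afterwards. Hypothetical helper and class
# names; not skorch's actual implementation.
from contextlib import contextmanager

@contextmanager
def temporary_device(net, fallback="cpu"):
    """If net.device is None, temporarily set it to a fallback so that
    loading does not emit a confusing warning, then restore it."""
    original = net.device
    if original is None:
        net.device = fallback
    try:
        yield net
    finally:
        net.device = original

class FakeNet:
    device = None

net = FakeNet()
with temporary_device(net):
    assert net.device == "cpu"  # parameter loading would happen here
assert net.device is None      # original value is restored
```

The try/finally ensures the original device value is restored even if loading raises, mirroring the "set it back to the original value" requirement from the review.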
1 parent: e6023a1 · commit: 07fc260 · 5 changed files with 303 additions and 29 deletions.
examples/accelerate-multigpu/run-save-load.py (new file, 104 lines):

"""Check that saving and loading works with accelerate.

Especially, pay attention that both the initial model, as well as the loaded
model, could be either wrapped with accelerate or not, i.e. there are 4 possible
combinations.

"""

import numpy as np
import torch
from accelerate import Accelerator
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from torch import nn
from torch.distributed import TCPStore

from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin
from skorch.history import DistributedHistory


PORT = 8080


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense0 = nn.Linear(100, 2)
        self.nonlin = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.dense0(X)
        X = self.nonlin(X)
        return X


# make use of accelerate by creating a class with the AccelerateMixin
class AcceleratedNeuralNetClassifier(AccelerateMixin, NeuralNetClassifier):
    pass


def get_accelerate_model(accelerator):
    global PORT
    PORT += 1

    is_master = accelerator.is_main_process
    world_size = accelerator.num_processes
    rank = accelerator.local_process_index
    store = TCPStore(
        "127.0.0.1", port=PORT, world_size=world_size, is_master=is_master)
    dist_history = DistributedHistory(
        store=store, rank=rank, world_size=world_size)

    return AcceleratedNeuralNetClassifier(
        MyModule,
        criterion=nn.CrossEntropyLoss,
        accelerator=accelerator,
        max_epochs=3,
        lr=0.001,
        history=dist_history,
    )


def get_vanilla_model():
    return NeuralNetClassifier(
        MyModule,
        criterion=nn.CrossEntropyLoss,
        max_epochs=3,
        lr=0.001,
    )


def main(wrap_initial_model=True, wrap_loaded_model=True):
    X, y = make_classification(10000, n_features=100, n_informative=50, random_state=0)
    X = X.astype(np.float32)

    accelerator = Accelerator()
    model = get_accelerate_model(accelerator)
    model.unwrap_after_train = True if wrap_initial_model else False
    model.fit(X, y)

    model.save_params(f_params="model_params.pt")
    y_pred = model.predict(X)
    accuracy_before = accuracy_score(y, y_pred)
    print(f"Accuracy before loading: {accuracy_before}")

    if wrap_loaded_model:
        model_loaded = get_accelerate_model(accelerator).initialize()
    else:
        model_loaded = get_vanilla_model().initialize()

    model_loaded.load_params(f_params="model_params.pt")
    y_pred = model_loaded.predict(X)
    accuracy_after = accuracy_score(y, y_pred)
    print(f"Accuracy after loading: {accuracy_after}")

    assert accuracy_before == accuracy_after


if __name__ == '__main__':
    main(True, True)
    main(True, False)
    main(False, True)
    main(False, False)