-
Notifications
You must be signed in to change notification settings - Fork 390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FIX: Issues with saving/loading with accelerate #1008
Merged
Merged
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
"""Check that saving and loading works with accelerate. | ||
|
||
Especially, pay attention that both the initial model, as well as the loaded | ||
model, could be either wrapped with accelerate or not, i.e. there are 4 possible | ||
combinations. | ||
|
||
""" | ||
|
||
import numpy as np | ||
import torch | ||
from accelerate import Accelerator | ||
from sklearn.datasets import make_classification | ||
from sklearn.metrics import accuracy_score | ||
from torch import nn | ||
from torch.distributed import TCPStore | ||
|
||
from skorch import NeuralNetClassifier | ||
from skorch.hf import AccelerateMixin | ||
from skorch.history import DistributedHistory | ||
|
||
|
||
PORT = 8080 | ||
|
||
|
||
class MyModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.dense0 = nn.Linear(100, 2) | ||
self.nonlin = nn.LogSoftmax(dim=-1) | ||
|
||
def forward(self, X): | ||
X = self.dense0(X) | ||
X = self.nonlin(X) | ||
return X | ||
|
||
|
||
# make use of accelerate by creating a class with the AccelerateMixin | ||
class AcceleratedNeuralNetClassifier(AccelerateMixin, NeuralNetClassifier): | ||
pass | ||
|
||
|
||
def get_accelerate_model(accelerator): | ||
global PORT | ||
PORT += 1 | ||
|
||
is_master = accelerator.is_main_process | ||
world_size = accelerator.num_processes | ||
rank = accelerator.local_process_index | ||
store = TCPStore( | ||
"127.0.0.1", port=PORT, world_size=world_size, is_master=is_master) | ||
dist_history = DistributedHistory( | ||
store=store, rank=rank, world_size=world_size) | ||
|
||
return AcceleratedNeuralNetClassifier( | ||
MyModule, | ||
criterion=nn.CrossEntropyLoss, | ||
accelerator=accelerator, | ||
max_epochs=3, | ||
lr=0.001, | ||
history=dist_history, | ||
) | ||
|
||
|
||
def get_vanilla_model(): | ||
return NeuralNetClassifier( | ||
MyModule, | ||
criterion=nn.CrossEntropyLoss, | ||
max_epochs=3, | ||
lr=0.001, | ||
) | ||
|
||
|
||
def main(wrap_initial_model=True, wrap_loaded_model=True): | ||
X, y = make_classification(10000, n_features=100, n_informative=50, random_state=0) | ||
X = X.astype(np.float32) | ||
|
||
accelerator = Accelerator() | ||
model = get_accelerate_model(accelerator) | ||
model.unwrap_after_train = True if wrap_initial_model else False | ||
model.fit(X, y) | ||
|
||
model.save_params(f_params="model_params.pt") | ||
y_pred = model.predict(X) | ||
accuracy_before = accuracy_score(y, y_pred) | ||
print(f"Accuracy before loading: {accuracy_before}") | ||
|
||
if wrap_loaded_model: | ||
model_loaded = get_accelerate_model(accelerator).initialize() | ||
else: | ||
model_loaded = get_vanilla_model().initialize() | ||
|
||
model_loaded.load_params(f_params="model_params.pt") | ||
y_pred = model_loaded.predict(X) | ||
accuracy_after = accuracy_score(y, y_pred) | ||
print(f"Accuracy after loading: {accuracy_after}") | ||
|
||
assert accuracy_before == accuracy_after | ||
|
||
|
||
if __name__ == '__main__': | ||
main(True, True) | ||
main(True, False) | ||
main(False, True) | ||
main(False, False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From reading the code, setting
device=None
will raise a warning and fallback tocpu
. Does it make sense to temporary setself.device="cpu"
before callingsuper().load_params
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good point. I made the change you suggested, and also wrapped the whole
super().load_params(...)
in atry ... finally
to undo the change if something fails. It's not beautiful but better be safe than sorry.The test was updated to check that there is no warning.