Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Issues with saving/loading with accelerate #1008

Merged
merged 3 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
### Fixed

- Fixed a couple of issues when saving and loading parameters while using accelerate (via `AccelerateMixin`) in a multi-GPU setting (#1008)

## [0.14.0] - 2023-06-24

### Added
Expand Down
16 changes: 16 additions & 0 deletions examples/accelerate-multigpu/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Testing skorch with accelerate in multi GPU setting

This directory contains a couple of script used to test skorch with accelerate in a multi-GPU setting. The scripts cannot run as unit tests because they require a specific hardware setup not provided by the GitHub Action runners.

## `run-with-skorch.py`

The full history of this can be found here: https://github.com/skorch-dev/skorch/issues/944

There was an issue with using skorch in a multi-GPU setting with accelerate. After some searching, it turns out there were two problems:
Expand Down Expand Up @@ -36,3 +40,15 @@ tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

## `run-save-load.py`

The context of this script is that there were issues with saving and loading when using `AccelerateMixin`. The provided script is to ensure that everything works as expected. Same as the first one, for a proper test, this script needs to run in a multi-GPU setting. For more information, check PR #1008.

Run the scripts like this:

```sh
accelerate launch run-save-load.py
```

The accelerate config is the same.
104 changes: 104 additions & 0 deletions examples/accelerate-multigpu/run-save-load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Check that saving and loading works with accelerate.

Especially, pay attention that both the initial model, as well as the loaded
model, could be either wrapped with accelerate or not, i.e. there are 4 possible
combinations.

"""

import numpy as np
import torch
from accelerate import Accelerator
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from torch import nn
from torch.distributed import TCPStore

from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin
from skorch.history import DistributedHistory


PORT = 8080


class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.dense0 = nn.Linear(100, 2)
self.nonlin = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = self.dense0(X)
X = self.nonlin(X)
return X


# make use of accelerate by creating a class with the AccelerateMixin
class AcceleratedNeuralNetClassifier(AccelerateMixin, NeuralNetClassifier):
pass


def get_accelerate_model(accelerator):
global PORT
PORT += 1

is_master = accelerator.is_main_process
world_size = accelerator.num_processes
rank = accelerator.local_process_index
store = TCPStore(
"127.0.0.1", port=PORT, world_size=world_size, is_master=is_master)
dist_history = DistributedHistory(
store=store, rank=rank, world_size=world_size)

return AcceleratedNeuralNetClassifier(
MyModule,
criterion=nn.CrossEntropyLoss,
accelerator=accelerator,
max_epochs=3,
lr=0.001,
history=dist_history,
)


def get_vanilla_model():
return NeuralNetClassifier(
MyModule,
criterion=nn.CrossEntropyLoss,
max_epochs=3,
lr=0.001,
)


def main(wrap_initial_model=True, wrap_loaded_model=True):
X, y = make_classification(10000, n_features=100, n_informative=50, random_state=0)
X = X.astype(np.float32)

accelerator = Accelerator()
model = get_accelerate_model(accelerator)
model.unwrap_after_train = True if wrap_initial_model else False
model.fit(X, y)

model.save_params(f_params="model_params.pt")
y_pred = model.predict(X)
accuracy_before = accuracy_score(y, y_pred)
print(f"Accuracy before loading: {accuracy_before}")

if wrap_loaded_model:
model_loaded = get_accelerate_model(accelerator).initialize()
else:
model_loaded = get_vanilla_model().initialize()

model_loaded.load_params(f_params="model_params.pt")
y_pred = model_loaded.predict(X)
accuracy_after = accuracy_score(y, y_pred)
print(f"Accuracy after loading: {accuracy_after}")

assert accuracy_before == accuracy_after


if __name__ == '__main__':
main(True, True)
main(True, False)
main(False, True)
main(False, False)
126 changes: 97 additions & 29 deletions skorch/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ def __init__(
)
self.accelerator = accelerator
self.unwrap_after_train = unwrap_after_train
self._wrapped_with_accelerator = False

def _validate_params(self):
super()._validate_params()
Expand All @@ -934,53 +935,60 @@ def _validate_params(self):
"When device placement is performed by the accelerator, set device=None"
)

def _initialize_callbacks(self):
if self.callbacks__print_log__sink == 'auto':
print_func = getattr(self.accelerator, 'print', print)
self.callbacks__print_log__sink = print_func
super()._initialize_callbacks()
return self

def _initialize_criterion(self, *args, **kwargs):
super()._initialize_criterion(*args, **kwargs)
def _initialize_accelerator(self):
"""Prepare everything for use with accelerate"""
if self._wrapped_with_accelerator:
return self

with self._current_init_context('criterion'):
for name in self._criteria:
criterion = getattr(self, name + '_')
if isinstance(criterion, torch.nn.Module):
setattr(self, name + '_', self.accelerator.prepare(criterion))

return self

def _initialize_module(self, *args, **kwargs):
super()._initialize_module(*args, **kwargs)

with self._current_init_context('module'):
for name in self._modules:
module = getattr(self, name + '_')
if isinstance(module, torch.nn.Module):
setattr(self, name + '_', self.accelerator.prepare(module))

return self

def _initialize_optimizer(self, *args, **kwargs):
super()._initialize_optimizer(*args, **kwargs)

with self._current_init_context('optimizer'):
for name in self._optimizers:
optimizer = getattr(self, name + '_')
if isinstance(optimizer, torch.optim.Optimizer):
setattr(self, name + '_', self.accelerator.prepare(optimizer))

return self

def initialize_callbacks(self, *args, **kwargs):
super().initialize_callbacks(*args, **kwargs)

for _, callback in self.callbacks_:
if isinstance(callback, LRScheduler):
callback.policy_ = self.accelerator.prepare(callback.policy_)

self._wrapped_with_accelerator = True
return self

def initialize(self):
"""Initializes all of its components and returns self."""
# this should be the same as the parent class, except for the one marked
# line
self.check_training_readiness()

self._initialize_virtual_params()
self._initialize_callbacks()
self._initialize_module()
self._initialize_criterion()
self._initialize_optimizer()
self._initialize_history()
self._initialize_accelerator() # <= added

self._validate_params()

self.initialized_ = True
return self

def _initialize_callbacks(self):
if self.callbacks__print_log__sink == 'auto':
print_func = getattr(self.accelerator, 'print', print)
self.callbacks__print_log__sink = print_func
super()._initialize_callbacks()
return self

def train_step(self, batch, **fit_params):
Expand Down Expand Up @@ -1021,17 +1029,23 @@ def _step_optimizer(self, step_fn):
optimizer = getattr(self, name + '_')
optimizer.step()

# pylint: disable=unused-argument
def on_train_end(self, net, X=None, y=None, **kwargs):
super().on_train_end(net, X=X, y=y, **kwargs)
if not self.unwrap_after_train:
return self
def _unwrap_accelerator(self):
if not self._wrapped_with_accelerator:
return

for name in self._modules + self._criteria:
module = getattr(self, name + '_')
if isinstance(module, torch.nn.Module):
orig = self.accelerator.unwrap_model(module, keep_fp32_wrapper=False)
setattr(self, name + '_', orig)
self._wrapped_with_accelerator = False

# pylint: disable=unused-argument
def on_train_end(self, net, X=None, y=None, **kwargs):
self.accelerator.wait_for_everyone()
super().on_train_end(net, X=X, y=y, **kwargs)
if self.unwrap_after_train:
self._unwrap_accelerator()
return self

def evaluation_step(self, batch, training=False):
Expand All @@ -1042,6 +1056,59 @@ def evaluation_step(self, batch, training=False):
y_pred = self.accelerator.gather_for_metrics(output)
return y_pred

# pylint: disable=missing-function-docstring
def save_params(self, *args, **kwargs):
# has to be called even if not main process, or else there is a dead lock
self.accelerator.wait_for_everyone()

if not self._wrapped_with_accelerator:
if self.accelerator.is_main_process:
super().save_params(*args, **kwargs)
else:
# A potential issue with using accelerate is that a model that has
# been prepared with accelerate is wrapped, so that the keys of the
# state dict have an additional prefix, "module.". Therefore, when
# the model is unwrapped when saving and wrapped when loading, or
# vice versa, there will be a mismatch in the state dict keys. To
# prevent this, always unwrap before saving. During loading, in case
# the model is wrapped, this would result in an error, but we take
# care of unwrapping the model in that case during loading.
self._unwrap_accelerator()
try:
# note: although saving is only done on the main process,
# unwrapping+wrapping has to be done on all processes, or else
# there is an error, not sure why
if self.accelerator.is_main_process:
super().save_params(*args, **kwargs)
finally:
self._initialize_accelerator()

# pylint: disable=missing-function-docstring
def load_params(self, *args, **kwargs):
self.accelerator.wait_for_everyone()
prev_device = self.device

if not self._wrapped_with_accelerator:
super().load_params(*args, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From reading the code, setting device=None will raise a warning and fallback to cpu. Does it make sense to temporary set self.device="cpu" before calling super().load_params?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. I made the change you suggested, and also wrapped the whole super().load_params(...) in a try ... finally to undo the change if something fails. It's not beautiful but better be safe than sorry.

The test was updated to check that there is no warning.

else:
# A potential issue with using accelerate is that a model that has
# been prepared with accelerate is wrapped, so that the keys of the
# state dict have an additional prefix, "module.". Therefore, when
# the model is unwrapped when saving and wrapped when loading, or
# vice versa, there will be a mismatch in the state dict keys. Here,
# we always unwrap the model first before loading (1st case). This
# would still result in an error in the 2nd case, but we take care
# of unwrapping the model in that case during saving.
self._unwrap_accelerator()
try:
super().load_params(*args, **kwargs)
finally:
self._initialize_accelerator()

# ensure that the device remains unchanged in case it was None before
# calling load_params
self.device = prev_device


class HfHubStorage:
"""Helper class that allows writing data to the Hugging Face Hub.
Expand Down Expand Up @@ -1213,6 +1280,7 @@ def flush(self):
if self.verbose:
self.sink(f"Uploaded file to {return_url}")

# pylint: disable=unused-argument
def close(self, *args):
self.flush()

Expand Down
Loading