Accelerate: some adjustments for mixed precision (#1009)
* Use accelerator.autocast when computing loss

According to the accelerate docs, loss computation should be performed
within the accelerator.autocast context manager:

https://huggingface.co/docs/accelerate/v0.21.0/en/quicktour#mixed-precision-training

I tested if this makes a difference by running the following notebook
with fp16 precision:

https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Hugging_Face_Finetuning.ipynb

I found no difference at all: the runtime was practically the same and
the losses were identical. Still, I think it's better to have this than
not, as it is recommended by the accelerate docs.
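
To illustrate the pattern, here is a stand-alone sketch of what the
change to train_step_single amounts to (the signature is simplified for
the example and is not skorch's actual API):

def train_step_single(accelerator, module, criterion, Xi, yi):
    # Run the forward pass and the loss computation under autocast ...
    module.train()
    with accelerator.autocast():
        y_pred = module(Xi)
        loss = criterion(y_pred, yi)
    # ... and keep delegating the backward pass to the accelerator.
    accelerator.backward(loss)
    return {'loss': loss, 'y_pred': y_pred}

# Usage, assuming module, criterion, and the batch were already prepared:
# accelerator = accelerate.Accelerator(mixed_precision='fp16')
# out = train_step_single(accelerator, module, criterion, Xi, yi)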

* Update LR scheduler callback to work w/ accelerate

According to the accelerate docs:

https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training

the LR scheduler step should sometimes be skipped when using mixed
precision training, because accelerate may skip optimizer update steps
internally. Therefore, I updated the LR scheduler callback to check
whether the net has an accelerator and, if it does, whether a scheduler
step is actually necessary.

This is actually quite hard to test because the necessity of stepping
depends on accelerate's internal logic, which we don't want to test, and
which might change in the future. Therefore, the added test just runs
training with accelerate, mixed precision, and some lr schedulers,
verifying that there is no error.

When running these tests + the normal lr scheduler tests locally on a
machine that supports fp16, I get 100% line coverage of lr_scheduler.py.
I think this is good enough.
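
Condensed, the added guard looks roughly like this (a stand-alone
sketch of the new _step helper shown in the diff below; the function
name is made up for the example):

def maybe_step(net, lr_scheduler, score=None):
    # Skip the LR scheduler step whenever accelerate skipped the
    # corresponding optimizer step (e.g. because the grad scaler found
    # inf/NaN gradients during fp16 training).
    accelerator = getattr(net, 'accelerator', None)
    if accelerator is not None and accelerator.optimizer_step_was_skipped:
        return
    if score is None:
        lr_scheduler.step()
    else:
        lr_scheduler.step(score)  # ReduceLROnPlateau expects the metric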

* Non-functional clean ups related to lr schedulers

While working on the fixes in this PR, I also cleaned up some lr
scheduler code. These clean ups are non-functional.

1. We imported CyclicLR as TorchCyclicLR inside a try/except block. This
   was only needed for very old PyTorch versions that we no longer
   support, so I removed it.
2. Fixed the indentation of some conditional checks to improve
   readability.

* Reviewer comment: Simplify conditional code
BenjaminBossan authored Aug 18, 2023
1 parent 07fc260 commit 312daaa
Showing 4 changed files with 102 additions and 29 deletions.
63 changes: 44 additions & 19 deletions skorch/callbacks/lr_scheduler.py
@@ -9,17 +9,12 @@
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import CyclicLR
from torch.optim.lr_scheduler import ExponentialLR
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import StepLR

try:
from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR
except ImportError:
# Backward compatibility with torch >= 1.0 && < 1.1
TorchCyclicLR = None
from torch.optim.optimizer import Optimizer
from skorch.callbacks import Callback

@@ -142,6 +137,31 @@ def on_train_begin(self, net, **kwargs):
net, self.policy_, **self.kwargs
)

def _step(self, net, lr_scheduler, score=None):
"""Helper method to step the lr scheduler.
This takes care of two things:
1. If the lr scheduler is ReduceLROnPlateau, we need to pass the score.
2. If the net uses AccelerateMixin, stepping has to be skipped in
certain conditions.
For more info on the latter, see:
https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training
"""
accelerator_maybe = getattr(net, 'accelerator', None)
accelerator_step_skipped = (
accelerator_maybe and accelerator_maybe.optimizer_step_was_skipped
)
if accelerator_step_skipped:
return

if score is None:
lr_scheduler.step()
else:
lr_scheduler.step(score)

def on_epoch_end(self, net, **kwargs):
if self.step_every != 'epoch':
return
@@ -158,31 +178,36 @@ def on_epoch_end(self, net, **kwargs):
"should be placed before the LRScheduler callback"
) from e

self.lr_scheduler_.step(score)
self._step(net, self.lr_scheduler_, score=score)
# ReduceLROnPlateau does not expose the current lr so it can't be recorded
else:
if self.event_name is not None and hasattr(
self.lr_scheduler_, "get_last_lr"):
net.history.record(self.event_name,
self.lr_scheduler_.get_last_lr()[0])
self.lr_scheduler_.step()
if (
(self.event_name is not None)
and hasattr(self.lr_scheduler_, "get_last_lr")
):
net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
self._step(net, self.lr_scheduler_)

def on_batch_end(self, net, training, **kwargs):
if not training or self.step_every != 'batch':
return
if self.event_name is not None and hasattr(
self.lr_scheduler_, "get_last_lr"):
net.history.record_batch(self.event_name,
self.lr_scheduler_.get_last_lr()[0])
self.lr_scheduler_.step()
if (
(self.event_name is not None)
and hasattr(self.lr_scheduler_, "get_last_lr")
):
net.history.record_batch(
self.event_name, self.lr_scheduler_.get_last_lr()[0])
self._step(net, self.lr_scheduler_)
self.batch_idx_ += 1

def _get_scheduler(self, net, policy, **scheduler_kwargs):
"""Return scheduler, based on indicated policy, with appropriate
parameters.
"""
if policy not in [ReduceLROnPlateau] and \
'last_epoch' not in scheduler_kwargs:
if (
(policy not in [ReduceLROnPlateau])
and ('last_epoch' not in scheduler_kwargs)
):
last_epoch = len(net.history) - 1
scheduler_kwargs['last_epoch'] = last_epoch

7 changes: 4 additions & 3 deletions skorch/hf.py
@@ -1005,9 +1005,10 @@ def train_step(self, batch, **fit_params):
def train_step_single(self, batch, **fit_params):
self._set_training(True)
Xi, yi = unpack_data(batch)
y_pred = self.infer(Xi, **fit_params)
loss = self.get_loss(y_pred, yi, X=Xi, training=True)
self.accelerator.backward(loss)
with self.accelerator.autocast():
y_pred = self.infer(Xi, **fit_params)
loss = self.get_loss(y_pred, yi, X=Xi, training=True)
self.accelerator.backward(loss)
return {
'loss': loss,
'y_pred': y_pred,
13 changes: 6 additions & 7 deletions skorch/tests/callbacks/test_lr_scheduler.py
@@ -3,7 +3,6 @@

import numpy as np
import pytest
import torch
from sklearn.base import clone
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
@@ -12,7 +11,7 @@
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR
from torch.optim.lr_scheduler import CyclicLR

from skorch import NeuralNetClassifier
from skorch.callbacks.lr_scheduler import WarmRestartLR, LRScheduler
@@ -28,7 +27,7 @@ def test_simulate_lrs_epoch_step(self, policy):
expected = np.array([1.0, 1.0, 0.1, 0.1, 0.01, 0.01])
assert np.allclose(expected, lrs)

@pytest.mark.parametrize('policy', [TorchCyclicLR])
@pytest.mark.parametrize('policy', [CyclicLR])
def test_simulate_lrs_batch_step(self, policy):
lr_sch = LRScheduler(
policy, base_lr=1, max_lr=5, step_size_up=4, step_every='batch')
@@ -96,7 +95,7 @@ def test_lr_callback_steps_correctly(
assert lr_policy.lr_scheduler_.last_epoch == max_epochs

@pytest.mark.parametrize('policy, kwargs', [
(TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
(CyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
])
def test_lr_callback_batch_steps_correctly(
self,
@@ -125,7 +124,7 @@ def test_lr_callback_batch_steps_correctly(
assert lr_policy.batch_idx_ == expected

@pytest.mark.parametrize('policy, kwargs', [
(TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
(CyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
])
def test_lr_callback_batch_steps_correctly_fallback(
self,
@@ -177,7 +176,7 @@ def test_lr_scheduler_cloneable(self):

def test_lr_scheduler_set_params(self, classifier_module, classifier_data):
scheduler = LRScheduler(
TorchCyclicLR, base_lr=123, max_lr=999, step_every='batch')
CyclicLR, base_lr=123, max_lr=999, step_every='batch')
net = NeuralNetClassifier(
classifier_module,
max_epochs=0,
@@ -212,7 +211,7 @@ def test_lr_scheduler_record_batch_step(self, classifier_module, classifier_data
batch_size = 128

scheduler = LRScheduler(
TorchCyclicLR,
CyclicLR,
base_lr=1,
max_lr=5,
step_size_up=4,
48 changes: 48 additions & 0 deletions skorch/tests/test_hf.py
@@ -802,6 +802,7 @@ class MockAccelerator:
def __init__(self):
self.device_placement = True
self.print = print
self.optimizer_step_was_skipped = False

def prepare(self, *args):
for arg in args:
@@ -826,6 +827,11 @@ def wait_for_everyone(self):
def accumulate(self, model):
yield

# pylint: disable=unused-argument
@contextmanager
def autocast(self, cache_enabled=False, autocast_handler=None):
yield

# pylint: disable=missing-docstring,arguments-differ
class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
def get_iterator(self, *args, **kwargs):
@@ -950,6 +956,48 @@ def train_step(self, *args, **kwargs):
updated_expected = [False, False, True, False, False, True, True] * max_epochs
assert updated == updated_expected

@pytest.mark.parametrize('mixed_precision', ['no', 'fp16', 'bf16'])
@pytest.mark.parametrize('scheduler', ['ReduceLROnPlateau', 'StepLR'])
def test_lr_scheduler_with_accelerate(
self, net_cls, accelerator_cls, data, mixed_precision, scheduler
):
# This test only checks that lr schedulers work with accelerate mixed
# precision. The reason why this requires special handling is explained
# here:
# https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training
# There is no test for whether the lr scheduler actually steps correctly
# or not, as that would require knowledge of accelerate internals, which
# we don't want to rely on.
from accelerate.utils import is_bf16_available
from skorch.callbacks import LRScheduler

if (mixed_precision != 'no') and not torch.cuda.is_available():
pytest.skip('skipping AMP test because device does not support it')
if (mixed_precision == 'bf16') and not is_bf16_available():
pytest.skip('skipping bf16 test because device does not support it')

X, y = data[0][:100], data[1][:100]
max_epochs = 10

if scheduler == 'ReduceLROnPlateau':
lr_scheduler = LRScheduler(
policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
)
else:
lr_scheduler = LRScheduler(
policy=torch.optim.lr_scheduler.StepLR,
step_size=2,
step_every='batch',
)

accelerator = accelerator_cls()
net = net_cls(
accelerator=accelerator,
max_epochs=max_epochs,
callbacks=[lr_scheduler],
)
net.fit(X, y)


class MockHfApi:
"""Mock of huggingface_hub.HfAPI"""
