Merge pull request #37 from kozistr/feature/fp16
[Feature] Support FP16 for all optimizers by utilizing wrapper class
Showing 5 changed files with 363 additions and 9 deletions.
@@ -0,0 +1,257 @@
from typing import Dict, Optional

import torch
from torch.optim import Optimizer

from pytorch_optimizer.types import CLOSURE
from pytorch_optimizer.utils import clip_grad_norm, has_overflow

__AUTHOR__ = 'Facebook'
__REFERENCE__ = 'https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/fp16.py'


class DynamicLossScaler:
    """Dynamically adjusts the loss scaling factor.
    Dynamic loss scalers are important in mixed-precision training.
    They help us avoid underflows and overflows in low-precision gradients.
    See here for information:
    <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling>
    Shamelessly stolen and adapted from FairSeq.
    <https://github.com/pytorch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py>
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 15,
        scale_factor: float = 2.0,
        scale_window: int = 2000,
        tolerance: float = 0.00,
        threshold: Optional[float] = None,
    ):
        """
        :param init_scale: Initial loss scale.
        :param scale_factor: Factor by which to increase or decrease the loss scale.
        :param scale_window: If we do not experience an overflow for scale_window iterations,
            the loss scale is increased by scale_factor.
        :param tolerance: Percentage of iterations that may overflow before the loss scale is decreased.
        :param threshold: If not None, the loss scale will not decrease below this threshold.
        """
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold

        self.iter: int = 0
        self.last_overflow_iter: int = -1
        self.last_rescale_iter: int = -1
        self.overflows_since_rescale: int = 0

    def update_scale(self, overflow: bool):
        """Update the loss scale.
        If the fraction of overflowing iterations exceeds our tolerance, we decrease the loss scale.
        If the number of iterations since the last overflow exceeds the scale window, we increase the loss scale.
        """
        iter_since_rescale: int = self.iter - self.last_rescale_iter

        if overflow:
            # calculate how often we overflowed already
            self.last_overflow_iter = self.iter
            self.overflows_since_rescale += 1

            pct_overflow: float = self.overflows_since_rescale / float(iter_since_rescale)
            if pct_overflow >= self.tolerance:
                # decrease the loss scale by the scale factor
                self.decrease_loss_scale()

                # reset iterations
                self.last_rescale_iter = self.iter
                self.overflows_since_rescale = 0
        elif (self.iter - self.last_overflow_iter) % self.scale_window == 0:
            # increase the loss scale by the scale factor
            self.loss_scale *= self.scale_factor
            self.last_rescale_iter = self.iter

        self.iter += 1

    def decrease_loss_scale(self):
        """Decrease the loss scale by self.scale_factor.
        NOTE: the loss scale will not go below self.threshold.
        """
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            self.loss_scale = max(self.loss_scale, self.threshold)


class SafeFP16Optimizer(Optimizer):
    def __init__(self, optimizer, aggregate_gnorms: bool = False):
        self.optimizer = optimizer
        self.aggregate_gnorms = aggregate_gnorms

        self.fp16_params = self.get_parameters(optimizer)
        self.fp32_params = self.build_fp32_params(self.fp16_params, flatten=False)

        # we want the optimizer to be tracking the fp32 parameters
        if len(optimizer.param_groups) != 1:
            # future implementers: this should hopefully be a matter of just
            # iterating through the param groups and keeping track of the pointer
            # through the fp32_params
            raise NotImplementedError('[-] Need to implement the parameter group transfer.')

        optimizer.param_groups[0]['params'] = self.fp32_params

        self.scaler: DynamicLossScaler = DynamicLossScaler(2.0 ** 15)
        self.min_loss_scale: float = 2 ** -5
        self.needs_sync: bool = True

    @classmethod
    def get_parameters(cls, optimizer: Optimizer):
        params = []
        for pg in optimizer.param_groups:
            params += list(pg['params'])
        return params

    @classmethod
    def build_fp32_params(cls, parameters, flatten: bool = True):
        # create FP32 copy of parameters and grads
        if flatten:
            total_param_size = sum(p.data.numel() for p in parameters)
            fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=parameters[0].device)

            offset: int = 0
            for p in parameters:
                numel = p.data.numel()
                fp32_params[offset : offset + numel].copy_(p.data.view(-1))
                offset += numel

            fp32_params = torch.nn.Parameter(fp32_params)
            fp32_params.grad = fp32_params.data.new(total_param_size)
            return fp32_params

        fp32_params = []
        for p in parameters:
            p32 = torch.nn.Parameter(p.data.float())
            p32.grad = torch.zeros_like(p32.data)
            fp32_params.append(p32)

        return fp32_params

    def state_dict(self) -> Dict:
        """Return the optimizer's state dict."""
        state_dict = self.optimizer.state_dict()
        if self.scaler is not None:
            state_dict['loss_scaler'] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict: Dict):
        """Load an optimizer state dict.
        In general we should prefer the configuration of the existing optimizer instance
        (e.g., learning rate) over that found in the state_dict. This allows us to
        resume training from a checkpoint using a new set of optimizer args.
        """
        if 'loss_scaler' in state_dict and self.scaler is not None and isinstance(state_dict['loss_scaler'], float):
            self.scaler.loss_scale = state_dict['loss_scaler']
        self.optimizer.load_state_dict(state_dict)

    def backward(self, loss, update_main_grads: bool = False):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
        additionally dynamically scales the loss to avoid gradient underflow.
        """
        if self.scaler is not None:
            loss = loss * self.scaler.loss_scale

        loss.backward()

        self.needs_sync = True
        if update_main_grads:
            self.update_main_grads()

    def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
        if self.needs_sync:
            if self.scaler is not None:
                # correct for dynamic loss scaler
                multiply_grads /= self.scaler.loss_scale

            # copy FP16 grads to FP32
            for p, p32 in zip(self.fp16_params, self.fp32_params):
                if not p.requires_grad:
                    continue

                if p.grad is not None:
                    p32.grad.data.copy_(p.grad.data)
                    p32.grad.data.mul_(multiply_grads)
                else:
                    p32.grad = torch.zeros_like(p.data, dtype=torch.float)

            self.needs_sync = False

    def multiply_grads(self, c):
        """Multiplies grads by a constant c."""
        if self.needs_sync:
            self.sync_fp16_grads_to_fp32(c)
        else:
            for p32 in self.fp32_params:
                p32.grad.data.mul_(c)

    def update_main_grads(self):
        self.sync_fp16_grads_to_fp32()

    def clip_main_grads(self, max_norm):
        """Clips gradient norm and updates dynamic loss scaler."""
        self.sync_fp16_grads_to_fp32()

        grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_gnorms)

        # detect overflow and adjust loss scale
        if self.scaler is not None:
            overflow: bool = has_overflow(grad_norm)
            prev_scale = self.scaler.loss_scale
            self.scaler.update_scale(overflow)
            if overflow:
                self.zero_grad()
                if self.scaler.loss_scale <= self.min_loss_scale:
                    # Use FloatingPointError as an uncommon error that parent
                    # functions can safely catch to stop training.
                    self.scaler.loss_scale = prev_scale

                    raise FloatingPointError(
                        f'Minimum loss scale reached ({self.min_loss_scale}). Your loss is probably exploding. '
                        'Try lowering the learning rate, using gradient clipping or '
                        'increasing the batch size.\n'
                        f'Overflow: setting loss scale to {self.scaler.loss_scale}'
                    )

        return grad_norm

    def step(self, closure: CLOSURE = None):
        """Performs a single optimization step."""
        self.sync_fp16_grads_to_fp32()
        self.optimizer.step(closure)

        # copy FP32 params back into FP16 model
        for p, p32 in zip(self.fp16_params, self.fp32_params):
            if not p.requires_grad:
                continue
            p.data.copy_(p32.data)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad = None
        for p32 in self.fp32_params:
            p32.grad.zero_()
        self.needs_sync = False

    def get_lr(self) -> float:
        return self.optimizer.get_lr()

    def set_lr(self, lr: float):
        self.optimizer.set_lr(lr)

    @property
    def loss_scale(self) -> float:
        """Convenience function which TorchAgent calls to get current scale value."""
        return self.scaler.loss_scale
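For context, the call pattern the wrapper expects looks roughly like the sketch below. This is a minimal sketch and not part of the diff: it assumes a CUDA device, a model cast to FP16, and the AdamP optimizer from this package; the model, data, and hyperparameters are illustrative only.

# minimal FP16 training-loop sketch (illustrative, not part of the diff)
import torch
import torch.nn.functional as F

from pytorch_optimizer.adamp import AdamP
from pytorch_optimizer.fp16 import SafeFP16Optimizer

model = torch.nn.Linear(16, 2).cuda().half()  # keep the model weights in FP16
optimizer = SafeFP16Optimizer(AdamP(model.parameters(), lr=1e-3))  # FP32 master copies live inside the wrapper

x = torch.randn(8, 16, device='cuda', dtype=torch.half)  # dummy inputs/targets
y = torch.randint(0, 2, (8,), device='cuda')

for _ in range(10):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    optimizer.backward(loss)        # scales the loss by the current loss scale, then calls loss.backward()
    optimizer.clip_main_grads(5.0)  # syncs FP16 grads to FP32, clips them, and updates the dynamic loss scaler
    optimizer.step()                # steps on the FP32 master params and copies them back into the FP16 model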
@@ -0,0 +1,48 @@
from pytorch_optimizer.adabelief import AdaBelief
from pytorch_optimizer.adabound import AdaBound
from pytorch_optimizer.adahessian import AdaHessian
from pytorch_optimizer.adamp import AdamP
from pytorch_optimizer.diffgrad import DiffGrad
from pytorch_optimizer.diffrgrad import DiffRGrad
from pytorch_optimizer.fp16 import SafeFP16Optimizer
from pytorch_optimizer.madgrad import MADGRAD
from pytorch_optimizer.radam import RAdam
from pytorch_optimizer.ranger import Ranger
from pytorch_optimizer.ranger21 import Ranger21
from pytorch_optimizer.sgdp import SGDP


def load_optimizers(optimizer: str, use_fp16: bool = False):
    optimizer: str = optimizer.lower()

    if optimizer == 'adamp':
        opt = AdamP
    elif optimizer == 'ranger':
        opt = Ranger
    elif optimizer == 'ranger21':
        opt = Ranger21
    elif optimizer == 'sgdp':
        opt = SGDP
    elif optimizer == 'radam':
        opt = RAdam
    elif optimizer == 'adabelief':
        opt = AdaBelief
    elif optimizer == 'adabound':
        opt = AdaBound
    elif optimizer == 'madgrad':
        opt = MADGRAD
    elif optimizer == 'diffrgrad':
        opt = DiffRGrad
    elif optimizer == 'diffgrad':
        opt = DiffGrad
    elif optimizer == 'adahessian':
        opt = AdaHessian
    else:
        raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')

    if use_fp16:
        opt = SafeFP16Optimizer(opt)

    return opt
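As a usage note (not part of the diff): load_optimizers resolves a name to an optimizer class, which is then instantiated as usual. Since SafeFP16Optimizer's constructor reads optimizer.param_groups, the sketch below instantiates the returned class first and wraps the instance afterwards; the import path for the helper and the model/hyperparameters are assumptions for illustration.

# minimal loading sketch (illustrative, not part of the diff)
import torch

from pytorch_optimizer import load_optimizers  # assumed import path for this helper
from pytorch_optimizer.fp16 import SafeFP16Optimizer

model = torch.nn.Linear(16, 2)  # for real FP16 training the model would be cast with model.half(), as in the earlier sketch

optimizer_class = load_optimizers('adamp')                 # resolves the name to the AdamP class
optimizer = optimizer_class(model.parameters(), lr=1e-3)   # instantiate as usual
optimizer = SafeFP16Optimizer(optimizer)                    # wrap the instance for FP16 training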