diff --git a/README.rst b/README.rst
index 14a220593..d5cdd21d9 100644
--- a/README.rst
+++ b/README.rst
@@ -27,18 +27,24 @@ Simple Usage
 
 ::
 
-    from pytorch_optimizer import Ranger21
+    from pytorch_optimizer import AdamP
 
     ...
     model = YourModel()
-    optimizer = Ranger21(model.parameters())
+    optimizer = AdamP(model.parameters())
     ...
 
-    for input, output in data:
-        optimizer.zero_grad()
-        loss = loss_function(output, model(input))
-        loss.backward()
-        optimizer.step()
+Or you can use the optimizer loader by simply passing the name of the optimizer.
+
+::
+
+    from pytorch_optimizer import load_optimizers
+
+    ...
+    model = YourModel()
+    opt = load_optimizers(optimizer='adamp', use_fp16=True)
+    optimizer = opt(model.parameters())
+    ...
 
 Supported Optimizers
 --------------------
diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py
index 95b4226e4..30fc13589 100644
--- a/pytorch_optimizer/__init__.py
+++ b/pytorch_optimizer/__init__.py
@@ -7,15 +7,17 @@
 from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
 from pytorch_optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.diffrgrad import DiffRGrad
+from pytorch_optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.lookahead import Lookahead
 from pytorch_optimizer.madgrad import MADGRAD
+from pytorch_optimizer.optimizers import load_optimizers
 from pytorch_optimizer.pcgrad import PCGrad
 from pytorch_optimizer.radam import RAdam
 from pytorch_optimizer.ranger import Ranger
 from pytorch_optimizer.ranger21 import Ranger21
 from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP
-from pytorch_optimizer.utils import get_optimizer_parameters, normalize_gradient, unit_norm
+from pytorch_optimizer.utils import clip_grad_norm, get_optimizer_parameters, normalize_gradient, unit_norm
 
-__VERSION__ = '0.1.1'
+__VERSION__ = '0.2.0'
diff --git a/pytorch_optimizer/fp16.py b/pytorch_optimizer/fp16.py
new file mode 100644
index 000000000..3164c6072
--- /dev/null
+++ b/pytorch_optimizer/fp16.py
@@ -0,0 +1,257 @@
+from typing import Dict, Optional
+
+import torch
+from torch.optim import Optimizer
+
+from pytorch_optimizer.types import CLOSURE
+from pytorch_optimizer.utils import clip_grad_norm, has_overflow
+
+__AUTHOR__ = 'Facebook'
+__REFERENCE__ = 'https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/fp16.py'
+
+
+class DynamicLossScaler:
+    """Dynamically adjusts the loss scaling factor.
+    Dynamic loss scalers are important in mixed-precision training.
+    They help us avoid underflows and overflows in low-precision gradients.
+
+    See the NVIDIA mixed-precision training documentation for more information on loss scaling.
+
+
+    Shamelessly stolen and adapted from FairSeq.
+
+    """
+
+    def __init__(
+        self,
+        init_scale: float = 2.0 ** 15,
+        scale_factor: float = 2.0,
+        scale_window: int = 2000,
+        tolerance: float = 0.00,
+        threshold: Optional[float] = None,
+    ):
+        """
+        :param init_scale: Initial loss scale.
+        :param scale_factor: Factor by which to increase or decrease the loss scale.
+        :param scale_window: If we do not experience overflow in scale_window iterations,
+            the loss scale will increase by scale_factor.
+        :param tolerance: Percentage of iterations that have overflowed, after which we decrease the loss scale.
+        :param threshold: If not None, the loss scale will not drop below this threshold.
+        """
+        self.loss_scale = init_scale
+        self.scale_factor = scale_factor
+        self.scale_window = scale_window
+        self.tolerance = tolerance
+        self.threshold = threshold
+
+        self.iter: int = 0
+        self.last_overflow_iter: int = -1
+        self.last_rescale_iter: int = -1
+        self.overflows_since_rescale: int = 0
+
+    def update_scale(self, overflow: bool):
+        """Update the loss scale.
+        If the overflow rate exceeds our tolerance, we decrease the loss scale. If the number of
+        iterations since the last overflow exceeds the scale window, we increase the loss scale.
+        """
+        iter_since_rescale: int = self.iter - self.last_rescale_iter
+
+        if overflow:
+            # calculate how often we overflowed already
+            self.last_overflow_iter = self.iter
+            self.overflows_since_rescale += 1
+
+            pct_overflow: float = self.overflows_since_rescale / float(iter_since_rescale)
+            if pct_overflow >= self.tolerance:
+                # decrease loss scale by the scale factor
+                self.decrease_loss_scale()
+
+                # reset iterations
+                self.last_rescale_iter = self.iter
+                self.overflows_since_rescale = 0
+        elif (self.iter - self.last_overflow_iter) % self.scale_window == 0:
+            # increase the loss scale by scale factor
+            self.loss_scale *= self.scale_factor
+            self.last_rescale_iter = self.iter
+
+        self.iter += 1
+
+    def decrease_loss_scale(self):
+        """Decrease the loss scale by self.scale_factor.
+        NOTE: the loss_scale will not go below self.threshold.
+        """
+        self.loss_scale /= self.scale_factor
+        if self.threshold is not None:
+            self.loss_scale = max(self.loss_scale, self.threshold)
+
+
+class SafeFP16Optimizer(Optimizer):
+    def __init__(self, optimizer, aggregate_gnorms: bool = False):
+        self.optimizer = optimizer
+        self.aggregate_gnorms = aggregate_gnorms
+
+        self.fp16_params = self.get_parameters(optimizer)
+        self.fp32_params = self.build_fp32_params(self.fp16_params, flatten=False)
+
+        # we want the optimizer to be tracking the fp32 parameters
+        if len(optimizer.param_groups) != 1:
+            # future implementers: this should hopefully be a matter of just
+            # iterating through the param groups and keeping track of the pointer
+            # through the fp32_params
+            raise NotImplementedError('[-] Need to implement the parameter group transfer.')
+
+        optimizer.param_groups[0]['params'] = self.fp32_params
+
+        self.scaler: DynamicLossScaler = DynamicLossScaler(2.0 ** 15)
+        self.min_loss_scale: float = 2 ** -5
+        self.needs_sync: bool = True
+
+    @classmethod
+    def get_parameters(cls, optimizer: Optimizer):
+        params = []
+        for pg in optimizer.param_groups:
+            params += list(pg['params'])
+        return params
+
+    @classmethod
+    def build_fp32_params(cls, parameters, flatten: bool = True):
+        # create FP32 copy of parameters and grads
+        if flatten:
+            total_param_size = sum(p.data.numel() for p in parameters)
+            fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=parameters[0].device)
+
+            offset: int = 0
+            for p in parameters:
+                numel = p.data.numel()
+                fp32_params[offset : offset + numel].copy_(p.data.view(-1))
+                offset += numel
+
+            fp32_params = torch.nn.Parameter(fp32_params)
+            fp32_params.grad = fp32_params.data.new(total_param_size)
+            return fp32_params
+
+        fp32_params = []
+        for p in parameters:
+            p32 = torch.nn.Parameter(p.data.float())
+            p32.grad = torch.zeros_like(p32.data)
+            fp32_params.append(p32)
+
+        return fp32_params
+
+    def state_dict(self) -> Dict:
+        """Return the optimizer's state dict."""
+        state_dict = self.optimizer.state_dict()
+        if self.scaler is not None:
+            state_dict['loss_scaler'] = self.scaler.loss_scale
+        return state_dict
+
+    def load_state_dict(self, state_dict: Dict):
+        """Load an optimizer state dict.
+        In general we should prefer the configuration of the existing optimizer instance
+        (e.g., learning rate) over that found in the state_dict. This allows us to
+        resume training from a checkpoint using a new set of optimizer args.
+        """
+        if 'loss_scaler' in state_dict and self.scaler is not None and isinstance(state_dict['loss_scaler'], float):
+            self.scaler.loss_scale = state_dict['loss_scaler']
+        self.optimizer.load_state_dict(state_dict)
+
+    def backward(self, loss, update_main_grads: bool = False):
+        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
+        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
+        additionally dynamically scales the loss to avoid gradient underflow.
+        """
+        if self.scaler is not None:
+            loss = loss * self.scaler.loss_scale
+
+        loss.backward()
+
+        self.needs_sync = True
+        if update_main_grads:
+            self.update_main_grads()
+
+    def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
+        if self.needs_sync:
+            if self.scaler is not None:
+                # correct for dynamic loss scaler
+                multiply_grads /= self.scaler.loss_scale
+
+            # copy FP16 grads to FP32
+            for p, p32 in zip(self.fp16_params, self.fp32_params):
+                if not p.requires_grad:
+                    continue
+
+                if p.grad is not None:
+                    p32.grad.data.copy_(p.grad.data)
+                    p32.grad.data.mul_(multiply_grads)
+                else:
+                    p32.grad = torch.zeros_like(p.data, dtype=torch.float)
+
+            self.needs_sync = False
+
+    def multiply_grads(self, c):
+        """Multiplies grads by a constant c."""
+        if self.needs_sync:
+            self.sync_fp16_grads_to_fp32(c)
+        else:
+            for p32 in self.fp32_params:
+                p32.grad.data.mul_(c)
+
+    def update_main_grads(self):
+        self.sync_fp16_grads_to_fp32()
+
+    def clip_main_grads(self, max_norm):
+        """Clips gradient norm and updates dynamic loss scaler."""
+        self.sync_fp16_grads_to_fp32()
+
+        grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_gnorms)
+
+        # detect overflow and adjust loss scale
+        if self.scaler is not None:
+            overflow: bool = has_overflow(grad_norm)
+            prev_scale = self.scaler.loss_scale
+            self.scaler.update_scale(overflow)
+            if overflow:
+                self.zero_grad()
+                if self.scaler.loss_scale <= self.min_loss_scale:
+                    # Use FloatingPointError as an uncommon error that parent
+                    # functions can safely catch to stop training.
+                    self.scaler.loss_scale = prev_scale
+
+                    raise FloatingPointError(
+                        f'Minimum loss scale reached ({self.min_loss_scale}). Your loss is probably exploding. '
+                        'Try lowering the learning rate, using gradient clipping or '
+                        'increasing the batch size.\n'
+                        f'Overflow: setting loss scale to {self.scaler.loss_scale}'
+                    )
+
+        return grad_norm
+
+    def step(self, closure: CLOSURE = None):
+        """Performs a single optimization step."""
+        self.sync_fp16_grads_to_fp32()
+        self.optimizer.step(closure)
+
+        # copy FP32 params back into FP16 model
+        for p, p32 in zip(self.fp16_params, self.fp32_params):
+            if not p.requires_grad:
+                continue
+            p.data.copy_(p32.data)
+
+    def zero_grad(self):
+        """Clears the gradients of all optimized parameters."""
+        for p in self.fp16_params:
+            p.grad = None
+        for p32 in self.fp32_params:
+            p32.grad.zero_()
+        self.needs_sync = False
+
+    def get_lr(self) -> float:
+        return self.optimizer.get_lr()
+
+    def set_lr(self, lr: float):
+        self.optimizer.set_lr(lr)
+
+    @property
+    def loss_scale(self) -> float:
+        """Convenience property for getting the current loss scale value."""
+        return self.scaler.loss_scale
diff --git a/pytorch_optimizer/optimizers.py b/pytorch_optimizer/optimizers.py
new file mode 100644
index 000000000..701e2bce0
--- /dev/null
+++ b/pytorch_optimizer/optimizers.py
@@ -0,0 +1,47 @@
+from pytorch_optimizer.adabelief import AdaBelief
+from pytorch_optimizer.adabound import AdaBound
+from pytorch_optimizer.adahessian import AdaHessian
+from pytorch_optimizer.adamp import AdamP
+from pytorch_optimizer.diffgrad import DiffGrad
+from pytorch_optimizer.diffrgrad import DiffRGrad
+from pytorch_optimizer.fp16 import SafeFP16Optimizer
+from pytorch_optimizer.madgrad import MADGRAD
+from pytorch_optimizer.radam import RAdam
+from pytorch_optimizer.ranger import Ranger
+from pytorch_optimizer.ranger21 import Ranger21
+from pytorch_optimizer.sgdp import SGDP
+
+
+def load_optimizers(optimizer: str, use_fp16: bool = False):
+    optimizer: str = optimizer.lower()
+
+    if optimizer == 'adamp':
+        opt = AdamP
+    elif optimizer == 'ranger':
+        opt = Ranger
+    elif optimizer == 'ranger21':
+        opt = Ranger21
+    elif optimizer == 'sgdp':
+        opt = SGDP
+    elif optimizer == 'radam':
+        opt = RAdam
+    elif optimizer == 'adabelief':
+        opt = AdaBelief
+    elif optimizer == 'adabound':
+        opt = AdaBound
+    elif optimizer == 'madgrad':
+        opt = MADGRAD
+    elif optimizer == 'diffrgrad':
+        opt = DiffRGrad
+    elif optimizer == 'diffgrad':
+        opt = DiffGrad
+    elif optimizer == 'adahessian':
+        opt = AdaHessian
+    else:
+        raise NotImplementedError(f'[-] not implemented optimizer: {optimizer}')
+
+    if use_fp16:
+        # wrap with SafeFP16Optimizer only after the optimizer has been instantiated with parameters
+        return lambda *args, **kwargs: SafeFP16Optimizer(opt(*args, **kwargs))
+
+    return opt
diff --git a/pytorch_optimizer/utils.py b/pytorch_optimizer/utils.py
index 9a58cfe55..8cf4fd9be 100644
--- a/pytorch_optimizer/utils.py
+++ b/pytorch_optimizer/utils.py
@@ -2,6 +2,7 @@
 
 import torch
 from torch import nn
+from torch.distributed import all_reduce
 
 from pytorch_optimizer.types import PARAMETERS
 
@@ -10,6 +11,11 @@ def is_valid_parameters(parameters: PARAMETERS) -> bool:
     return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict)
 
 
+def has_overflow(grad_norm: torch.Tensor) -> bool:
+    """Detect inf and NaN in grad_norm."""
+    return grad_norm == float('inf') or grad_norm != grad_norm  # pylint: disable=comparison-with-itself
+
+
 def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8) -> torch.Tensor:
     """normalize gradient with stddev
     :param x: torch.Tensor. gradient.
@@ -27,6 +33,41 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
     return x
 
 
+def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> torch.Tensor:
+    """Clips grad norms.
+    When used with FSDP, this will also ensure that grad norms are aggregated
+    across all workers, since each worker only holds its shard of the gradients.
+    :param parameters: Parameters whose gradients we wish to clip
+    :param max_norm: Maximum norm we wish the gradients to have. If non-positive, then
+        we will not perform clipping.
+    :param sync: Boolean indicating whether we should aggregate across the distributed
+        group. Used only in combination with FSDP.
+    :returns: The gradient norm across all parameters, before clipping.
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+
+    # make sure any generators are expanded
+    parameters = list(parameters)
+
+    # if syncing we need to manually perform the clipping so that we aggregate properly
+    if max_norm > 0 and not sync:
+        return torch.nn.utils.clip_grad_norm_(parameters, max_norm)
+
+    norm_sq = sum(p.grad.data.norm() ** 2 for p in parameters if p.grad is not None)
+    if sync:
+        # also need to get the norms from all the other sharded workers in FSDP
+        all_reduce(norm_sq)
+
+    grad_norm = norm_sq.sqrt()
+    if max_norm > 0:
+        clip_coef = (max_norm / (grad_norm + 1e-6)).clamp(max=1.0)
+        for p in parameters:
+            p.grad.detach().mul_(clip_coef)
+
+    return grad_norm
+
+
 def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor:
     keep_dim: bool = True
     dim: Optional[Union[int, Tuple[int, ...]]] = None
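# Hedged usage sketch (not part of the diff above): how the SafeFP16Optimizer added in this change
# is meant to be driven end to end, based only on the methods shown in pytorch_optimizer/fp16.py.
# The model, criterion, and dummy batch are placeholders, and a CUDA device is assumed since fp16
# training is a GPU feature; AdamP and SafeFP16Optimizer are the exports added to __init__.py.
import torch

from pytorch_optimizer import AdamP, SafeFP16Optimizer

model = torch.nn.Linear(16, 4).half().cuda()                # fp16 copies of the parameters live in the model
criterion = torch.nn.CrossEntropyLoss()
optimizer = SafeFP16Optimizer(AdamP(model.parameters(), lr=1e-3))

x = torch.randn(8, 16).half().cuda()                        # dummy batch, assumed fp16 inputs
y = torch.randint(0, 4, (8,)).cuda()

for _ in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x).float(), y)                   # compute the loss in fp32 for stability
    optimizer.backward(loss)                                # scales the loss by the current loss scale
    try:
        # sync the fp16 grads into the fp32 copies, clip them, and update the dynamic loss scaler
        optimizer.clip_main_grads(max_norm=5.0)
    except FloatingPointError:                              # raised once the loss scale bottoms out
        break
    optimizer.step()                                        # fp32 update, then copied back into the fp16 params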