From 4bb541ec09c26125c66d1e8de406991c140bad5b Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 19 Feb 2025 03:03:03 +0400 Subject: [PATCH 1/9] prodigy is ready, add mars and adopt in the morning --- megatron/core/optimizer/__init__.py | 28 ++ megatron/core/optimizer/adopt.py | 116 +++++++ megatron/core/optimizer/distrib_optimizer.py | 8 +- megatron/core/optimizer/mars.py | 316 +++++++++++++++++++ megatron/core/optimizer/optimizer_config.py | 19 ++ megatron/core/optimizer/prodigy.py | 274 ++++++++++++++++ megatron/training/arguments.py | 12 +- 7 files changed, 771 insertions(+), 2 deletions(-) create mode 100644 megatron/core/optimizer/adopt.py create mode 100644 megatron/core/optimizer/mars.py create mode 100644 megatron/core/optimizer/prodigy.py diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 7821094702f..cffaf27cc05 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -24,6 +24,9 @@ from torch.optim import AdamW as Adam, SGD from .ademamix import AdEMAMix +from .prodigy import Prodigy +from .mars import MARS +from .adopt import ADOPT from megatron.core import mpu @@ -326,6 +329,31 @@ def init_state_fn(opt, config=None): else: opt.initialize_state(p) + elif config.optimizer == 'prodigy': + kwargs = { + "params": param_groups, + "lr": config.lr, + "weight_decay": config.weight_decay, + "betas": (config.adam_beta1, config.adam_beta2), + "beta3": config.prodigy_beta3, + "decouple": config.prodigy_decouple, + "use_bias_correction": config.prodigy_use_bias_correction, + "safeguard_warmup": config.prodigy_safeguard_warmup, + "fsdp_in_use": config.prodigy_fsdp_in_use, + } + + optimizer = Prodigy(**kwargs) + + def init_state_fn(opt, config=None): + for group in opt.param_groups: + for p in group['params']: + if 'step' not in opt.state[p]: + opt.state[p]['step'] = 0 + opt.state[p]['s'] = torch.zeros_like(p.data).detach() + opt.state[p]['p0'] = p.detach().clone() + opt.state[p]['exp_avg'] = torch.zeros_like(p.data).detach() + opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data).detach() + elif config.optimizer == 'sgd': optimizer = SGD( diff --git a/megatron/core/optimizer/adopt.py b/megatron/core/optimizer/adopt.py new file mode 100644 index 00000000000..f5db8021661 --- /dev/null +++ b/megatron/core/optimizer/adopt.py @@ -0,0 +1,116 @@ +""" +Here is an original implementation of ADOPT. 
+Source: https://github.com/iShohei220/adopt +""" + +import torch + + +def exists(val): + return val is not None + + +from typing import Callable, Optional, Tuple + +import torch + + +class ADOPT(torch.optim.Optimizer): + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.9999), + eps: float = 1e-6, + clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25, + weight_decay: float = 0.0, + decouple: bool = True, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.clip_lambda = clip_lambda + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, decouple=decouple + ) + super().__init__(params, defaults) + + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group["betas"] + + for p in group["params"]: + if p.grad is None: + continue + + if p.grad.is_sparse: + raise RuntimeError("ADOPT does not support sparse gradients") + + params_with_grad.append(p) + grads.append(p.grad) + + state = self.state[p] + + if len(state) == 0: + state["step"] = torch.tensor(0.0) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + state_steps.append(state["step"]) + + for i, param in enumerate(params_with_grad): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + step = int(step_t.item()) + + if step == 0: + exp_avg_sq.addcmul_(grad, grad) + step_t += 1 + continue + + if group["weight_decay"] != 0: + if group["decouple"]: + param.data.mul_(1 - group["lr"] * group["weight_decay"]) + else: + grad = grad.add(param, alpha=group["weight_decay"]) + + denom = torch.clamp(exp_avg_sq.sqrt(), group["eps"]) + normed_grad = grad.div(denom) + + if self.clip_lambda is not None: + clip = self.clip_lambda(step) + normed_grad.clamp_(-clip, clip) + + exp_avg.lerp_(normed_grad, 1 - beta1) + param.data.add_(exp_avg, alpha=-group["lr"]) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + step_t += 1 + + return loss \ No newline at end of file diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 252ee9646cc..f935eb39dbd 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -22,6 +22,7 @@ HAVE_APEX_OR_TE = False from .ademamix import AdEMAMix +from .prodigy import Prodigy from .. 
import tensor_parallel from ..config_logger import has_config_logger_enabled, log_config_to_disk @@ -501,8 +502,11 @@ def __init__( self.optimizer_keys = ("param", "exp_avg_slow", "exp_avg_fast", "exp_avg_sq") else: self.optimizer_keys = ("param", "exp_avg_slow", "exp_avg_sq") + elif isinstance(optimizer, Prodigy): + self.optimizer_name = 'prodigy' + self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq", "s", "p0") else: - raise Exception(f"Unrecognized optimizer {type(optimizer)}, only Adam and AdEMAMix are supported for now.") + raise Exception(f"Unrecognized optimizer {type(optimizer)}.") # when freezing sub-models we have no real optimizer # but still need a stub DistributedOptimizer class @@ -714,6 +718,8 @@ def load_state_dict(self, state_dict): tensors = {"exp_avg_slow": init_shard(), "exp_avg_fast": init_shard(), "exp_avg_sq": init_shard()} else: # beta1 == 0 tensors = {"exp_avg_slow": init_shard(), "exp_avg_sq": init_shard()} + elif self.optimizer_name == 'prodigy': + tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard(), "s": init_shard(), "p0": init_shard()} if self.config.use_precision_aware_optimizer: tensors["master_param"] = init_shard() state_dict_state.append((state_order, tensors)) diff --git a/megatron/core/optimizer/mars.py b/megatron/core/optimizer/mars.py new file mode 100644 index 00000000000..8adb528524e --- /dev/null +++ b/megatron/core/optimizer/mars.py @@ -0,0 +1,316 @@ +""" +Here is an original implementation of MARS. +Source: https://github.com/AGI-Arena/MARS +""" + +# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 +import math + +import torch + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps # ensure top singular value <= 1 + if G.size(0) > G.size(1): + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if G.size(0) > G.size(1): + X = X.T + return X + + +def exists(val): + return val is not None + + +def update_fn( + p, + grad, + exp_avg, + exp_avg_sq, + lr, + wd, + beta1, + beta2, + last_grad, + eps, + amsgrad, + max_exp_avg_sq, + step, + gamma, + mars_type, + is_grad_2d, + optimize_1d, + lr_1d_factor, + betas_1d, + weight_decay_1d, +): + # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para + if optimize_1d or is_grad_2d: + c_t = (grad - last_grad).mul(gamma * (beta1 / (1.0 - beta1))).add(grad) + c_t_norm = torch.norm(c_t) + if c_t_norm > 1.0: + c_t = c_t / c_t_norm + exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1) + if (mars_type == "mars-adamw") or ( + mars_type == "mars-shampoo" and not is_grad_2d + ): + exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2) + bias_correction1 = 1.0 - beta1**step + bias_correction2 = 1.0 - beta2**step + if amsgrad: + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + denom = ( + max_exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + else: + denom = ( + exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.div(denom)) + elif mars_type == "mars-lion": + real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.sign()) + elif mars_type == "mars-shampoo" and is_grad_2d: + factor = max(1, grad.size(0) / grad.size(1)) ** 0.5 + real_update_tmp = ( + zeropower_via_newtonschulz5(exp_avg.mul(1.0 / (1.0 - beta1)), eps=eps) + .mul(factor) + .add(wd, p.data) + .mul(-lr) + ) + p.data.add_(real_update_tmp) + else: + beta1_1d, beta2_1d = betas_1d + exp_avg.mul_(beta1_1d).add_(grad, alpha=1 - beta1_1d) + exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1 - beta2_1d) + bias_correction1 = 1.0 - beta1_1d**step + bias_correction2 = 1.0 - beta2_1d**step + if amsgrad: + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + denom = ( + max_exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + else: + denom = ( + exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + real_update_tmp = ( + -lr + * lr_1d_factor + * torch.mul(p.data, weight_decay_1d).add(exp_avg.div(denom)) + ) + p.data.add_(real_update_tmp) + return exp_avg, exp_avg_sq + + +class MARS(torch.optim.Optimizer): + def __init__( + self, + params, + lr=3e-3, + betas=(0.95, 0.99), + eps=1e-8, + weight_decay=0.0, + amsgrad=False, + gamma=0.025, + is_approx=True, + mars_type="mars-adamw", + optimize_1d=False, + lr_1d=3e-3, + betas_1d=(0.9, 0.95), + weight_decay_1d=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + assert mars_type in [ + "mars-adamw", + "mars-lion", + "mars-shampoo", + ], "MARS type not supported" + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + mars_type=mars_type, + 
gamma=gamma, + optimize_1d=optimize_1d, + weight_decay_1d=weight_decay_1d, + ) + super(MARS, self).__init__(params, defaults) + self.eps = eps + self.update_fn = update_fn + self.lr = lr + self.weight_decay = weight_decay + self.amsgrad = amsgrad + self.step_num = 0 + self.is_approx = is_approx + self.gamma = gamma + self.mars_type = mars_type + self.optimize_1d = optimize_1d + self.lr_1d_factor = lr_1d / lr + self.weight_decay_1d = weight_decay_1d + self.betas_1d = betas_1d + + @torch.no_grad() + def update_last_grad(self): + if not self.is_approx: + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + if "last_grad" not in state: + state["last_grad"] = torch.zeros_like(p) + state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0) + + @torch.no_grad() + def update_previous_grad(self): + if not self.is_approx: + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + print(p, "grad is none") + continue + state = self.state[p] + if "previous_grad" not in state: + state["previous_grad"] = torch.zeros_like(p) + state["previous_grad"].zero_().add_(p.grad, alpha=1.0) + + def __setstate__(self, state): + super(MARS, self).__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + + @torch.no_grad() + def step( + self, + closure=None, + grads=None, + output_params=None, + scale=None, + grad_norms=None, + grad_scaler=None, + ): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError( + "FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." + ) + + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + gamma = self.gamma + for group in self.param_groups: + for p in filter(lambda p: exists(p.grad), group["params"]): + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsgrad = group["amsgrad"] + + state = self.state[p] + # ('----- starting a parameter state', state.keys(), 'Length of state', len(state)) + # State initialization + if len(state) <= 1: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Last Gradient + state["last_grad"] = torch.zeros_like(p) + # state['previous_grad'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. 
values + state["max_exp_avg_sq"] = torch.zeros_like(p.data) + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + last_grad = state["last_grad"] + lr, wd, beta1, beta2 = ( + group["lr"], + group["weight_decay"], + *group["betas"], + ) + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + else: + max_exp_avg_sq = 0 + + if "step" in state: + state["step"] += 1 + else: + state["step"] = 1 + step = state["step"] + is_grad_2d = len(grad.shape) == 2 + exp_avg, exp_avg_sq = self.update_fn( + p, + grad, + exp_avg, + exp_avg_sq, + lr, + wd, + beta1, + beta2, + last_grad, + self.eps, + amsgrad, + max_exp_avg_sq, + step, + gamma, + mars_type=self.mars_type, + is_grad_2d=is_grad_2d, + optimize_1d=self.optimize_1d, + lr_1d_factor=self.lr_1d_factor, + betas_1d=self.betas_1d, + weight_decay_1d=( + self.weight_decay if self.optimize_1d else self.weight_decay_1d + ), + ) + if self.is_approx: + state["last_grad"] = grad + self.step_num = step + + return loss \ No newline at end of file diff --git a/megatron/core/optimizer/optimizer_config.py b/megatron/core/optimizer/optimizer_config.py index 9f09b372266..fe0c12395bb 100644 --- a/megatron/core/optimizer/optimizer_config.py +++ b/megatron/core/optimizer/optimizer_config.py @@ -118,6 +118,25 @@ class OptimizerConfig: ademamix_alpha_warmup: Optional[int] = None """Number of warmup steps used to increase alpha.""" + # Prodigy + prodigy_beta3: Optional[float] = None + """coefficients for computing the Prodidy stepsize using running averages. + If set to None, uses the value of square root of beta2.""" + + prodigy_decouple: bool = True + """Use AdamW style decoupled weight decay.""" + + prodigy_use_bias_correction: bool = False + """Turn on Adam's bias correction. Off by default.""" + + prodigy_safeguard_warmup: bool = False + """Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default.""" + + prodigy_fsdp_in_use: bool = False + """If you're using sharded parameters, this should be set to True. The optimizer + will attempt to auto-detect this, but if you're using an implementation other + than PyTorch's builtin version, the auto-detection won't work. + """ ####################### # Distributed optimizer diff --git a/megatron/core/optimizer/prodigy.py b/megatron/core/optimizer/prodigy.py new file mode 100644 index 00000000000..c6bcb38d42d --- /dev/null +++ b/megatron/core/optimizer/prodigy.py @@ -0,0 +1,274 @@ +""" +Here is an original implementation of Prodigy. +Source: https://github.com/konstmish/prodigy +""" + +import math + +import torch +import torch.distributed as dist + + +class Prodigy(torch.optim.Optimizer): + r""" + Implements Adam with Prodigy step-sizes. + Leave LR set to 1 unless you encounter instability. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate. + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + beta3 (float): + coefficients for computing the Prodidy stepsize using running averages. + If set to None, uses the value of square root of beta2 (default: None). + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). 
+ decouple (boolean): + Use AdamW style decoupled weight decay + use_bias_correction (boolean): + Turn on Adam's bias correction. Off by default. + safeguard_warmup (boolean): + Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. + d0 (float): + Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. + d_coef (float): + Coefficient in the expression for the estimate of d (default 1.0). + Values such as 0.5 and 2.0 typically work as well. + Changing this parameter is the preferred way to tune the method. + growth_rate (float): + prevent the D estimate from growing faster than this multiplicative rate. + Default is inf, for unrestricted. Values like 1.02 give a kind of learning + rate warmup effect. + fsdp_in_use (bool): + If you're using sharded parameters, this should be set to True. The optimizer + will attempt to auto-detect this, but if you're using an implementation other + than PyTorch's builtin version, the auto-detection won't work. + """ + + def __init__( + self, + params, + lr=1.0, + betas=(0.9, 0.999), + beta3=None, + eps=1e-8, + weight_decay=0, + decouple=True, + use_bias_correction=False, + safeguard_warmup=False, + d0=1e-6, + d_coef=1.0, + growth_rate=float("inf"), + fsdp_in_use=False, + ): + if not 0.0 < d0: + raise ValueError("Invalid d0 value: {}".format(d0)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + if decouple and weight_decay > 0: + print(f"Using decoupled weight decay") + + defaults = dict( + lr=lr, + betas=betas, + beta3=beta3, + eps=eps, + weight_decay=weight_decay, + d=d0, + d0=d0, + d_max=d0, + d_numerator=0.0, + d_coef=d_coef, + k=0, + growth_rate=growth_rate, + use_bias_correction=use_bias_correction, + decouple=decouple, + safeguard_warmup=safeguard_warmup, + fsdp_in_use=fsdp_in_use, + ) + self.d0 = d0 + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + d_denom = 0.0 + + group = self.param_groups[0] + use_bias_correction = group["use_bias_correction"] + beta1, beta2 = group["betas"] + beta3 = group["beta3"] + if beta3 is None: + beta3 = math.sqrt(beta2) + k = group["k"] + + d = group["d"] + d_max = group["d_max"] + d_coef = group["d_coef"] + lr = max(group["lr"] for group in self.param_groups) + + if use_bias_correction: + bias_correction = ((1 - beta2 ** (k + 1)) ** 0.5) / (1 - beta1 ** (k + 1)) + else: + bias_correction = 1 + + dlr = d * lr * bias_correction + + growth_rate = group["growth_rate"] + decouple = group["decouple"] + fsdp_in_use = group["fsdp_in_use"] + + d_numerator = group["d_numerator"] + d_numerator *= beta3 + + for group in self.param_groups: + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + group_lr = group["lr"] + d0 = group["d0"] + safeguard_warmup = group["safeguard_warmup"] + + if group_lr not in [lr, 0.0]: + raise RuntimeError( + f"Setting different lr values in different parameter groups is only supported for values of 0" + ) + + for p in group["params"]: + if p.grad is None: + continue + if hasattr(p, "_fsdp_flattened"): + fsdp_in_use = True + + grad = p.grad.data + + # Apply weight decay (coupled variant) + if decay != 0 and not decouple: + grad.add_(p.data, alpha=decay) + + state = self.state[p] + + # State initialization + if "step" not in state: + state["step"] = 0 + state["s"] = torch.zeros_like(p.data).detach() + state["p0"] = p.detach().clone() + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data).detach() + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data).detach() + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + s = state["s"] + p0 = state["p0"] + + if group_lr > 0.0: + # we use d / d0 instead of just d to avoid getting values that are too small + d_numerator += ( + (d / d0) + * dlr + * torch.dot(grad.flatten(), (p0.data - p.data).flatten()).item() + ) + + # Adam EMA updates + exp_avg.mul_(beta1).add_(grad, alpha=d * (1 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=d * d * (1 - beta2) + ) + + if safeguard_warmup: + s.mul_(beta3).add_(grad, alpha=((d / d0) * d)) + else: + s.mul_(beta3).add_(grad, alpha=((d / d0) * dlr)) + d_denom += s.abs().sum().item() + + ###### + + d_hat = d + + # if we have not done any progres, return + # if we have any gradients available, will have d_denom > 0 (unless \|g\|=0) + if d_denom == 0: + return loss + + if lr > 0.0: + if fsdp_in_use: + dist_tensor = torch.zeros(2).cuda() + dist_tensor[0] = d_numerator + dist_tensor[1] = d_denom + dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) + global_d_numerator = dist_tensor[0] + global_d_denom = dist_tensor[1] + else: + global_d_numerator = d_numerator + global_d_denom = d_denom + + d_hat = d_coef * global_d_numerator / global_d_denom + if d == group["d0"]: + d = max(d, d_hat) + d_max = max(d_max, d_hat) + d = min(d_max, d * growth_rate) + + for group in self.param_groups: + group["d_numerator"] = global_d_numerator + group["d_denom"] = global_d_denom + group["d"] = d + group["d_max"] = d_max + group["d_hat"] = d_hat + + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + + state = self.state[p] + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + state["step"] += 1 + + denom = 
exp_avg_sq.sqrt().add_(d * eps) + + # Apply weight decay (decoupled variant) + if decay != 0 and decouple: + p.data.add_(p.data, alpha=-decay * dlr) + + ### Take step + p.data.addcdiv_(exp_avg, denom, value=-dlr) + + group["k"] = k + 1 + + return loss \ No newline at end of file diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index c6c0928a458..a45bd12ea98 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1261,6 +1261,16 @@ def _add_regularization_args(parser): help='AdEMAMix warmup period for beta_3') group.add_argument('--ademamix-alpha-warmup', type=int, default=-1, help='AdEMAMix warmup period for aplha') + group.add_argument('--prodigy-beta3', type=float, default=None, + help='If set to None, uses the value of square root of beta2') + group.add_argument('--prodigy-decouple', type=bool, default=True, + help='Decoupled weight decay') + group.add_argument('--prodigy-use-bias-correction', type=bool, default=False, + help='Use bias correction') + group.add_argument('--prodigy-safeguard-warmup', type=bool, default=False, + help='Remove lr from the denominator of D estimate to avoid issues during warm-up stage') + group.add_argument('--prodigy-fsdp-in-use', type=bool, default=False, + help='If set, use FSDP') group.add_argument('--adam-eps', type=float, default=1e-08, help='Term added to the denominator to improve' 'numerical stability') @@ -1463,7 +1473,7 @@ def _add_training_args(parser): help='Enable bias only in the QKV linear layers', dest='add_qkv_bias') group.add_argument('--optimizer', type=str, default='adam', - choices=['adam', 'sgd', 'ademamix'], + choices=['adam', 'sgd', 'ademamix', 'prodigy'], help='Optimizer function') group.add_argument('--dataloader-type', type=str, default=None, choices=['single', 'cyclic', 'external'], From 116f37e3021cb9a93b927a4fd21064fb206aed4d Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 19 Feb 2025 11:22:06 +0400 Subject: [PATCH 2/9] mars is here, adopt todo --- megatron/core/optimizer/__init__.py | 38 +++++++++++++++++++- megatron/core/optimizer/distrib_optimizer.py | 12 +++++++ megatron/core/optimizer/optimizer_config.py | 31 ++++++++++++++++ megatron/training/arguments.py | 23 +++++++++++- megatron/training/training.py | 6 +++- 5 files changed, 107 insertions(+), 3 deletions(-) diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index cffaf27cc05..e713a5133f1 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -25,7 +25,7 @@ from .ademamix import AdEMAMix from .prodigy import Prodigy -from .mars import MARS +from .mars import MARS, exists from .adopt import ADOPT from megatron.core import mpu @@ -353,6 +353,42 @@ def init_state_fn(opt, config=None): opt.state[p]['p0'] = p.detach().clone() opt.state[p]['exp_avg'] = torch.zeros_like(p.data).detach() opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data).detach() + else: + opt.initialize_state(p) + + elif config.optimizer == 'mars': + kwargs = { + "params": param_groups, + "lr": config.mars_lr, + "betas": (config.mars_beta1, config.mars_beta2), + "weight_decay": config.weight_decay, + "amsgrad": config.mars_amsgrad, + "gamma": config.mars_vr_gamma, + "is_approx": config.mars_is_approx, + "mars_type": config.mars_type, + "optimize_1d": config.mars_optimize_1d, + "lr_1d": config.lr, + "betas_1d": (config.adam_beta1, config.adam_beta2), + "weight_decay_1d": config.mars_weight_decay_1d, + } + + optimizer = MARS(**kwargs) + + def init_state_fn(opt, config=None): 
+ for group in opt.param_groups: + for p in filter(lambda p: exists(p.grad), group["params"]): + amsgrad = group["amsgrad"] + if len(opt.state[p]) <= 1: + opt.state[p]["step"] = 0 + opt.state[p]["exp_avg"] = torch.zeros_like(p.data) + opt.state[p]["last_grad"] = torch.zeros_like(p.data) + opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data) + if amsgrad: + opt.state[p]["max_exp_avg_sq"] = torch.zeros_like(p.data) + if amsgrad and "max_exp_avg_sq" not in opt.state[p]: + opt.state[p]["max_exp_avg_sq"] = torch.zeros_like(p.data) + else: + opt.initialize_state(p) elif config.optimizer == 'sgd': diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index f935eb39dbd..d9ad66a5018 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -23,6 +23,7 @@ from .ademamix import AdEMAMix from .prodigy import Prodigy +from .mars import MARS from .. import tensor_parallel from ..config_logger import has_config_logger_enabled, log_config_to_disk @@ -505,6 +506,12 @@ def __init__( elif isinstance(optimizer, Prodigy): self.optimizer_name = 'prodigy' self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq", "s", "p0") + elif isinstance(optimizer, MARS): + self.optimizer_name = 'mars' + if config.mars_amsgrad: + self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq", "last_grad", "max_exp_avg_sq") + else: + self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq", "last_grad") else: raise Exception(f"Unrecognized optimizer {type(optimizer)}.") @@ -720,6 +727,11 @@ def load_state_dict(self, state_dict): tensors = {"exp_avg_slow": init_shard(), "exp_avg_sq": init_shard()} elif self.optimizer_name == 'prodigy': tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard(), "s": init_shard(), "p0": init_shard()} + elif self.optimizer_name == 'mars': + if len(self.optimizer_keys) == 5: + tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard(), "last_grad": init_shard(), "max_exp_avg_sq": init_shard()} + else: + tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard(), "last_grad": init_shard()} if self.config.use_precision_aware_optimizer: tensors["master_param"] = init_shard() state_dict_state.append((state_order, tensors)) diff --git a/megatron/core/optimizer/optimizer_config.py b/megatron/core/optimizer/optimizer_config.py index fe0c12395bb..6a795e5b3a2 100644 --- a/megatron/core/optimizer/optimizer_config.py +++ b/megatron/core/optimizer/optimizer_config.py @@ -138,6 +138,37 @@ class OptimizerConfig: than PyTorch's builtin version, the auto-detection won't work. """ + mars_beta1: float = 0.95 + """First coefficient for computing running averages of gradient and its square in Adam + optimizer. + """ + + mars_beta2: float = 0.99 + """Second coefficient for computing running averages of gradient and its square in Adam + optimizer. 
+ """ + + mars_type: str = 'mars-adamw' + """Which version of the MARS framework to use.""" + + mars_vr_gamma: float = 0.025 + """The gamma parameter for the variance reduction term in MARS.""" + + mars_is_approx: bool = True + """Whether to use the approximate version of the MARS optimizer.""" + + mars_lr: float = 0.003 + """The learning rate for the MARS optimizer.""" + + mars_amsgrad: bool = False + """Whether to use the AMSGrad variant of the MARS optimizer.""" + + mars_optimize_1d: bool = False + """If set to False, we optimize 1D parameters with AdamW.""" + + mars_weight_decay_1d: float = 0.1 + """The weight decay for 1D parameters in MARS.""" + ####################### # Distributed optimizer ####################### diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index a45bd12ea98..b4b80b62b09 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1271,6 +1271,27 @@ def _add_regularization_args(parser): help='Remove lr from the denominator of D estimate to avoid issues during warm-up stage') group.add_argument('--prodigy-fsdp-in-use', type=bool, default=False, help='If set, use FSDP') + group.add_argument('--mars-beta1', type=float, default=0.95, + help='First coefficient for computing running averages ' + 'of gradient and its square') + group.add_argument('--mars-beta2', type=float, default=0.99, + help='Second coefficient for computing running averages ' + 'of gradient and its square') + group.add_argument('--mars-type', type=str, default='mars-adamw', choices=['mars-adamw', 'mars-lion', 'mars-shampoo'], + help='Which version of the MARS framework to use') + group.add_argument('--mars-vr-gamma', type=float, default=0.025, + helpt='Variance Reduction scaling factor') + group.add_argument('--mars-is-approx', type=bool, default=True, + help='If set, use the approximate version of MARS') + group.add_argument('--mars-lr', type=float, default=0.003, + help='Learning rate for MARS') + group.add_argument('--mars-amsgrad', type=bool, default=False, + help='If set, use AMSGrad for MARS') + group.add_argument('--mars-optimize-1d', type=bool, default=False, + help='If set to False, we optimize 1D parameters with AdamW') + group.add_argument('--mars-weight-decay-1d', type=float, default=0.1, + help='Weight decay for 1D parameters') + group.add_argument() group.add_argument('--adam-eps', type=float, default=1e-08, help='Term added to the denominator to improve' 'numerical stability') @@ -1473,7 +1494,7 @@ def _add_training_args(parser): help='Enable bias only in the QKV linear layers', dest='add_qkv_bias') group.add_argument('--optimizer', type=str, default='adam', - choices=['adam', 'sgd', 'ademamix', 'prodigy'], + choices=['adam', 'sgd', 'ademamix', 'prodigy', 'mars', 'adopt'], help='Optimizer function') group.add_argument('--dataloader-type', type=str, default=None, choices=['single', 'cyclic', 'external'], diff --git a/megatron/training/training.py b/megatron/training/training.py index caa574494de..c9f183c6340 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -784,7 +784,11 @@ def train_step(forward_step_func, data_iterator, # Set grad to zero. for model_chunk in model: model_chunk.zero_grad_buffer() - optimizer.zero_grad() + if optimizer.__class__.__name__ == "MARS": + optimizer.zero_grad(set_to_none=True) + optimizer.update_last_grad() + else: + optimizer.zero_grad() # Forward pass. 
forward_backward_func = get_forward_backward_func() From 1f071c421d10e01569eb328ed5db498cbe4379ac Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 19 Feb 2025 11:43:20 +0400 Subject: [PATCH 3/9] adopt is here --- megatron/core/optimizer/__init__.py | 21 ++++++++++++++++++++ megatron/core/optimizer/distrib_optimizer.py | 6 ++++++ megatron/core/optimizer/optimizer_config.py | 6 ++++++ megatron/training/arguments.py | 6 +++++- 4 files changed, 38 insertions(+), 1 deletion(-) diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index e713a5133f1..0440e36b4ff 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -390,6 +390,27 @@ def init_state_fn(opt, config=None): else: opt.initialize_state(p) + elif config.optimizer == 'adopt': + kwargs = { + "params": param_groups, + "lr": config.lr, + "weight_decay": config.weight_decay, + "betas": (config.adam_beta1, config.adam_beta2), + "eps": config.adopt_eps, + "decouple": config.adopt_decouple, + } + + optimizer = ADOPT(**kwargs) + + def init_state_fn(opt, config=None): + for group in opt.param_groups: + for p in group['params']: + if len(opt.state[p]) == 0: + opt.state[p]['step'] = torch.tensor(0.0) + opt.state[p]['exp_avg'] = torch.zeros_like(p.data) + opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data) + else: + opt.initialize_state(p) elif config.optimizer == 'sgd': optimizer = SGD( diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index d9ad66a5018..521ece784f9 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -24,6 +24,7 @@ from .ademamix import AdEMAMix from .prodigy import Prodigy from .mars import MARS +from .adopt import ADOPT from .. 
import tensor_parallel from ..config_logger import has_config_logger_enabled, log_config_to_disk @@ -512,6 +513,9 @@ def __init__( self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq", "last_grad", "max_exp_avg_sq") else: self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq", "last_grad") + elif isinstance(optimizer, ADOPT): + self.optimizer_name = 'adopt' + self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq") else: raise Exception(f"Unrecognized optimizer {type(optimizer)}.") @@ -732,6 +736,8 @@ def load_state_dict(self, state_dict): tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard(), "last_grad": init_shard(), "max_exp_avg_sq": init_shard()} else: tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard(), "last_grad": init_shard()} + elif self.optimizer_name == 'adopt': + tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard()} if self.config.use_precision_aware_optimizer: tensors["master_param"] = init_shard() state_dict_state.append((state_order, tensors)) diff --git a/megatron/core/optimizer/optimizer_config.py b/megatron/core/optimizer/optimizer_config.py index 6a795e5b3a2..8adf2a06ddf 100644 --- a/megatron/core/optimizer/optimizer_config.py +++ b/megatron/core/optimizer/optimizer_config.py @@ -169,6 +169,12 @@ class OptimizerConfig: mars_weight_decay_1d: float = 0.1 """The weight decay for 1D parameters in MARS.""" + adopt_eps: float = 1e-6 + """Term added to the denominator to improve numerical stability in ADOPT optimizer.""" + + adopt_decouple: bool = True + """Use AdamW style decoupled weight decay.""" + ####################### # Distributed optimizer ####################### diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index b4b80b62b09..77bfe0c8722 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1291,7 +1291,11 @@ def _add_regularization_args(parser): help='If set to False, we optimize 1D parameters with AdamW') group.add_argument('--mars-weight-decay-1d', type=float, default=0.1, help='Weight decay for 1D parameters') - group.add_argument() + group.add_argument('--adopt-eps', type=float, default=1e-6, + help='Term added to the denominator to improve' + 'numerical stability') + group.add_argument('--adopt-decouple', type=bool, default=True, + help='Decoupled weight decay') group.add_argument('--adam-eps', type=float, default=1e-08, help='Term added to the denominator to improve' 'numerical stability') From 32a258bc0c3018abb4c24e95556487ca335a6de6 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 19 Feb 2025 18:19:26 +0400 Subject: [PATCH 4/9] --fix step adopt --- megatron/core/optimizer/adopt.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/megatron/core/optimizer/adopt.py b/megatron/core/optimizer/adopt.py index f5db8021661..d5b7edabf5f 100644 --- a/megatron/core/optimizer/adopt.py +++ b/megatron/core/optimizer/adopt.py @@ -54,7 +54,6 @@ def step(self, closure=None): grads = [] exp_avgs = [] exp_avg_sqs = [] - state_steps = [] beta1, beta2 = group["betas"] for p in group["params"]: @@ -70,7 +69,7 @@ def step(self, closure=None): state = self.state[p] if len(state) == 0: - state["step"] = torch.tensor(0.0) + state["step"] = 0 state["exp_avg"] = torch.zeros_like( p, memory_format=torch.preserve_format ) @@ -80,18 +79,17 @@ def step(self, closure=None): exp_avgs.append(state["exp_avg"]) exp_avg_sqs.append(state["exp_avg_sq"]) - state_steps.append(state["step"]) for i, param in enumerate(params_with_grad): grad = grads[i] exp_avg = exp_avgs[i] 
exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - step = int(step_t.item()) + state = self.state[param] + step = state["step"] if step == 0: exp_avg_sq.addcmul_(grad, grad) - step_t += 1 + state["step"] += 1 continue if group["weight_decay"] != 0: @@ -111,6 +109,6 @@ def step(self, closure=None): param.data.add_(exp_avg, alpha=-group["lr"]) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - step_t += 1 + state["step"] += 1 return loss \ No newline at end of file From c2e704b3c7d0a706c8542e39db83a87cd2b0b325 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 20 Feb 2025 16:39:40 +0400 Subject: [PATCH 5/9] --minor --- megatron/training/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 77bfe0c8722..7b382ae1a24 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1280,7 +1280,7 @@ def _add_regularization_args(parser): group.add_argument('--mars-type', type=str, default='mars-adamw', choices=['mars-adamw', 'mars-lion', 'mars-shampoo'], help='Which version of the MARS framework to use') group.add_argument('--mars-vr-gamma', type=float, default=0.025, - helpt='Variance Reduction scaling factor') + help='Variance Reduction scaling factor') group.add_argument('--mars-is-approx', type=bool, default=True, help='If set, use the approximate version of MARS') group.add_argument('--mars-lr', type=float, default=0.003, From a3c80dbf6f14a2a92e01e39c82f12f7c9b706ab8 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Fri, 21 Feb 2025 15:27:29 +0400 Subject: [PATCH 6/9] wandb-entity --- megatron/training/global_vars.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/training/global_vars.py b/megatron/training/global_vars.py index 70701341ec4..5fd4b4689f9 100644 --- a/megatron/training/global_vars.py +++ b/megatron/training/global_vars.py @@ -187,6 +187,7 @@ def _set_wandb_writer(args): 'dir': save_dir, 'name': args.wandb_exp_name, 'project': args.wandb_project, + 'entity': args.wandb_entity, 'config': vars(args)} os.makedirs(wandb_kwargs['dir'], exist_ok=True) wandb.init(**wandb_kwargs) From 919910604628739d8a1e83e176d8c1f4189ddc35 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Fri, 21 Feb 2025 15:32:31 +0400 Subject: [PATCH 7/9] wandb-entity --- megatron/training/arguments.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 0a29eecb95a..e465157d84e 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1220,6 +1220,8 @@ def _add_logging_args(parser): help='Enable world size logging to tensorboard.') group.add_argument('--wandb-project', type=str, default='', help='The wandb project name. Ignore wandb by default.') + group.add_argument('--wandb-entity', type=str, default='', + help='The wandb entity name. 
Ignore wandb by default.') group.add_argument('--wandb-exp-name', type=str, default='', help='The wandb experiment name.') group.add_argument('--wandb-save-dir', type=str, default='', From bdf50579329e8818dd2ceb03accf13fb67c1cd7b Mon Sep 17 00:00:00 2001 From: Andron00e Date: Mon, 24 Feb 2025 17:44:52 +0400 Subject: [PATCH 8/9] fixed adopt --- megatron/core/optimizer/__init__.py | 7 +-- megatron/core/optimizer/adopt.py | 89 ++++++++++++++++------------- 2 files changed, 52 insertions(+), 44 deletions(-) diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 0440e36b4ff..07c60ee3bfe 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -404,11 +404,10 @@ def init_state_fn(opt, config=None): def init_state_fn(opt, config=None): for group in opt.param_groups: - for p in group['params']: + for p in group["params"]: if len(opt.state[p]) == 0: - opt.state[p]['step'] = torch.tensor(0.0) - opt.state[p]['exp_avg'] = torch.zeros_like(p.data) - opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data) + opt.state[p]["exp_avg"] = torch.zeros_like(p.data) + opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data) else: opt.initialize_state(p) diff --git a/megatron/core/optimizer/adopt.py b/megatron/core/optimizer/adopt.py index d5b7edabf5f..7b477d83174 100644 --- a/megatron/core/optimizer/adopt.py +++ b/megatron/core/optimizer/adopt.py @@ -15,6 +15,27 @@ def exists(val): import torch +""" +Here is an original implementation of ADOPT. +Source: https://github.com/iShohei220/adopt +""" + +import torch + + +def exists(val): + return val is not None + + +from typing import Callable, Optional, Tuple + +import torch + + +def adopt_clip_fn(step: int) -> float: + return step ** 0.25 + + class ADOPT(torch.optim.Optimizer): def __init__( self, @@ -22,7 +43,7 @@ def __init__( lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.9999), eps: float = 1e-6, - clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25, + clip_lambda: Optional[Callable[[int], float]] = adopt_clip_fn, weight_decay: float = 0.0, decouple: bool = True, ): @@ -37,12 +58,18 @@ def __init__( if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") - self.clip_lambda = clip_lambda defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, decouple=decouple + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + decouple=decouple, + clip_lambda=clip_lambda, + step=0, ) super().__init__(params, defaults) + @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: @@ -50,65 +77,47 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] + group["step"] += 1 + step = group["step"] beta1, beta2 = group["betas"] + lr = group["lr"] for p in group["params"]: if p.grad is None: continue + grad = p.grad - if p.grad.is_sparse: + if grad.is_sparse: raise RuntimeError("ADOPT does not support sparse gradients") - params_with_grad.append(p) - grads.append(p.grad) - state = self.state[p] - if len(state) == 0: - state["step"] = 0 - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avgs.append(state["exp_avg"]) - exp_avg_sqs.append(state["exp_avg_sq"]) - - for i, param in enumerate(params_with_grad): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - 
state = self.state[param] - step = state["step"] - - if step == 0: + if len(state) ==0: + state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + + if step == 1: exp_avg_sq.addcmul_(grad, grad) - state["step"] += 1 continue if group["weight_decay"] != 0: if group["decouple"]: - param.data.mul_(1 - group["lr"] * group["weight_decay"]) + p.data.mul_(1 - lr * group["weight_decay"]) else: - grad = grad.add(param, alpha=group["weight_decay"]) + grad = grad.add(p, alpha=group["weight_decay"]) denom = torch.clamp(exp_avg_sq.sqrt(), group["eps"]) normed_grad = grad.div(denom) - if self.clip_lambda is not None: - clip = self.clip_lambda(step) + if group["clip_lambda"] is not None: + clip = group["clip_lambda"](step) normed_grad.clamp_(-clip, clip) exp_avg.lerp_(normed_grad, 1 - beta1) - param.data.add_(exp_avg, alpha=-group["lr"]) + p.data.add_(exp_avg, alpha=-lr) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - state["step"] += 1 - - return loss \ No newline at end of file + return loss From 36ed86282a8b245c86ce1f4463d25e6d55a00afa Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 26 Feb 2025 17:31:39 +0400 Subject: [PATCH 9/9] fixed adopt --- megatron/core/optimizer/adopt.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/megatron/core/optimizer/adopt.py b/megatron/core/optimizer/adopt.py index 7b477d83174..7b6c8ced179 100644 --- a/megatron/core/optimizer/adopt.py +++ b/megatron/core/optimizer/adopt.py @@ -15,23 +15,6 @@ def exists(val): import torch -""" -Here is an original implementation of ADOPT. -Source: https://github.com/iShohei220/adopt -""" - -import torch - - -def exists(val): - return val is not None - - -from typing import Callable, Optional, Tuple - -import torch - - def adopt_clip_fn(step: int) -> float: return step ** 0.25 @@ -77,8 +60,8 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - group["step"] += 1 - step = group["step"] + group['step'] += 1 + step = group['step'] beta1, beta2 = group["betas"] lr = group["lr"]
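
Below is a minimal smoke-test sketch for the three new optimizer classes; it is not part of the patches above. It assumes the modules land as megatron/core/optimizer/{adopt,mars,prodigy}.py exactly as added in this series and are importable in the current environment; the toy model, seed, loop length, and hyperparameters are illustrative assumptions, not a prescribed configuration. On the training side, the new optimizers are selected through the existing --optimizer flag (e.g. --optimizer prodigy with the --prodigy-* arguments, or --optimizer mars with the --mars-* arguments added to megatron/training/arguments.py).

# Standalone smoke test for ADOPT, MARS, and Prodigy as introduced in this series.
# Everything outside the three optimizer imports (toy model, loop length, values)
# is a hypothetical example used only to exercise a few update steps.
import torch

from megatron.core.optimizer.adopt import ADOPT
from megatron.core.optimizer.mars import MARS
from megatron.core.optimizer.prodigy import Prodigy


def run_steps(optimizer_cls, n_steps=5, **kwargs):
    """Run a few optimization steps on a toy regression problem."""
    torch.manual_seed(0)
    model = torch.nn.Linear(16, 4)
    opt = optimizer_cls(model.parameters(), **kwargs)
    loss = None
    for _ in range(n_steps):
        opt.zero_grad()
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()
        loss.backward()
        # With the default is_approx=True, MARS tracks last_grad inside step(),
        # so no update_previous_grad()/update_last_grad() calls are needed here.
        opt.step()
    return loss.item()


if __name__ == "__main__":
    # Keyword arguments mirror the constructor signatures shown in the diffs above.
    print("adopt  ", run_steps(ADOPT, lr=1e-3, betas=(0.9, 0.9999), eps=1e-6))
    print("mars   ", run_steps(MARS, lr=3e-3, betas=(0.95, 0.99), gamma=0.025))
    # Prodigy adapts its own step size; lr is left at 1.0 as its docstring advises.
    print("prodigy", run_steps(Prodigy, lr=1.0, betas=(0.9, 0.999), weight_decay=0.0))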