84 changes: 84 additions & 0 deletions megatron/core/optimizer/__init__.py
@@ -24,6 +24,9 @@
from torch.optim import AdamW as Adam, SGD

from .ademamix import AdEMAMix
from .prodigy import Prodigy
from .mars import MARS, exists
from .adopt import ADOPT

from megatron.core import mpu

@@ -326,6 +329,87 @@ def init_state_fn(opt, config=None):
else:
opt.initialize_state(p)

elif config.optimizer == 'prodigy':
kwargs = {
"params": param_groups,
"lr": config.lr,
"weight_decay": config.weight_decay,
"betas": (config.adam_beta1, config.adam_beta2),
"beta3": config.prodigy_beta3,
"decouple": config.prodigy_decouple,
"use_bias_correction": config.prodigy_use_bias_correction,
"safeguard_warmup": config.prodigy_safeguard_warmup,
"fsdp_in_use": config.prodigy_fsdp_in_use,
}

optimizer = Prodigy(**kwargs)

def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group['params']:
if 'step' not in opt.state[p]:
opt.state[p]['step'] = 0
opt.state[p]['s'] = torch.zeros_like(p.data).detach()
opt.state[p]['p0'] = p.detach().clone()
opt.state[p]['exp_avg'] = torch.zeros_like(p.data).detach()
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data).detach()
else:
opt.initialize_state(p)

elif config.optimizer == 'mars':
kwargs = {
"params": param_groups,
"lr": config.mars_lr,
"betas": (config.mars_beta1, config.mars_beta2),
"weight_decay": config.weight_decay,
"amsgrad": config.mars_amsgrad,
"gamma": config.mars_vr_gamma,
"is_approx": config.mars_is_approx,
"mars_type": config.mars_type,
"optimize_1d": config.mars_optimize_1d,
"lr_1d": config.lr,
"betas_1d": (config.adam_beta1, config.adam_beta2),
"weight_decay_1d": config.mars_weight_decay_1d,
}

optimizer = MARS(**kwargs)

def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in filter(lambda p: exists(p.grad), group["params"]):
amsgrad = group["amsgrad"]
if len(opt.state[p]) <= 1:
opt.state[p]["step"] = 0
opt.state[p]["exp_avg"] = torch.zeros_like(p.data)
opt.state[p]["last_grad"] = torch.zeros_like(p.data)
opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data)
if amsgrad:
opt.state[p]["max_exp_avg_sq"] = torch.zeros_like(p.data)
else:
opt.initialize_state(p)

elif config.optimizer == 'adopt':
kwargs = {
"params": param_groups,
"lr": config.lr,
"weight_decay": config.weight_decay,
"betas": (config.adam_beta1, config.adam_beta2),
"eps": config.adopt_eps,
"decouple": config.adopt_decouple,
}

optimizer = ADOPT(**kwargs)

def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group["params"]:
if len(opt.state[p]) == 0:
opt.state[p]["exp_avg"] = torch.zeros_like(p.data)
opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data)
else:
opt.initialize_state(p)

elif config.optimizer == 'sgd':
optimizer = SGD(
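For reviewers: the three new branches above read a handful of config fields. Below is a minimal sketch of the fields the 'adopt' branch expects; the SimpleNamespace stand-in and the example values are illustrative only and are not part of this PR (the real object is Megatron-Core's OptimizerConfig).

    from types import SimpleNamespace

    # Illustrative stand-in for the optimizer config; field names are taken from
    # the 'adopt' branch above, values are placeholders.
    adopt_config = SimpleNamespace(
        optimizer='adopt',
        lr=1e-3,
        weight_decay=0.1,
        adam_beta1=0.9,
        adam_beta2=0.9999,
        adopt_eps=1e-6,
        adopt_decouple=True,
    )

The 'prodigy' and 'mars' branches follow the same pattern with their prodigy_* and mars_* fields; MARS additionally reuses lr, adam_beta1, and adam_beta2 for its 1-D parameter handling via lr_1d and betas_1d.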
106 changes: 106 additions & 0 deletions megatron/core/optimizer/adopt.py
@@ -0,0 +1,106 @@
"""
ADOPT optimizer, adapted from the original implementation.
Source: https://github.com/iShohei220/adopt
"""

from typing import Callable, Optional, Tuple

import torch


def exists(val):
    return val is not None


def adopt_clip_fn(step: int) -> float:
    # Clip bound grows as step**0.25: early normalized updates are tightly
    # bounded and the bound relaxes as training proceeds (the default schedule
    # in the referenced ADOPT implementation).
    return step ** 0.25


class ADOPT(torch.optim.Optimizer):
def __init__(
self,
params,
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6,
clip_lambda: Optional[Callable[[int], float]] = adopt_clip_fn,
weight_decay: float = 0.0,
decouple: bool = True,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")

defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
decouple=decouple,
clip_lambda=clip_lambda,
step=0,
)
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
group['step'] += 1
step = group['step']
beta1, beta2 = group["betas"]
lr = group["lr"]

for p in group["params"]:
if p.grad is None:
continue
grad = p.grad

if grad.is_sparse:
raise RuntimeError("ADOPT does not support sparse gradients")

state = self.state[p]

                if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]

if step == 1:
exp_avg_sq.addcmul_(grad, grad)
continue

if group["weight_decay"] != 0:
if group["decouple"]:
p.data.mul_(1 - lr * group["weight_decay"])
else:
grad = grad.add(p, alpha=group["weight_decay"])

denom = torch.clamp(exp_avg_sq.sqrt(), group["eps"])
normed_grad = grad.div(denom)

if group["clip_lambda"] is not None:
clip = group["clip_lambda"](step)
normed_grad.clamp_(-clip, clip)

exp_avg.lerp_(normed_grad, 1 - beta1)
p.data.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

return loss
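A quick usage sketch for the new ADOPT class (not part of the PR; assumes the megatron package is importable). Note that step() normalizes the gradient by the second-moment estimate accumulated up to the previous step and only afterwards folds the current gradient into exp_avg_sq, which is the decorrelation that distinguishes ADOPT from Adam; the first step only initializes exp_avg_sq and performs no parameter update.

    import torch
    from megatron.core.optimizer.adopt import ADOPT  # path follows this PR's file layout

    # Tiny smoke test: drive a linear layer's output toward zero.
    model = torch.nn.Linear(4, 2)
    opt = ADOPT(model.parameters(), lr=1e-3, betas=(0.9, 0.9999), weight_decay=0.0)

    for _ in range(5):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()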
26 changes: 25 additions & 1 deletion megatron/core/optimizer/distrib_optimizer.py
Expand Up @@ -22,6 +22,9 @@
HAVE_APEX_OR_TE = False

from .ademamix import AdEMAMix
from .prodigy import Prodigy
from .mars import MARS
from .adopt import ADOPT

from .. import tensor_parallel
from ..config_logger import has_config_logger_enabled, log_config_to_disk
@@ -501,8 +504,20 @@ def __init__(
self.optimizer_keys = ("param", "exp_avg_slow", "exp_avg_fast", "exp_avg_sq")
else:
self.optimizer_keys = ("param", "exp_avg_slow", "exp_avg_sq")
elif isinstance(optimizer, Prodigy):
self.optimizer_name = 'prodigy'
self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq", "s", "p0")
elif isinstance(optimizer, MARS):
self.optimizer_name = 'mars'
if config.mars_amsgrad:
self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq", "last_grad", "max_exp_avg_sq")
else:
self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq", "last_grad")
elif isinstance(optimizer, ADOPT):
self.optimizer_name = 'adopt'
self.optimizer_keys = ("param", "exp_avg", "exp_avg_sq")
else:
raise Exception(f"Unrecognized optimizer {type(optimizer)}, only Adam and AdEMAMix are supported for now.")
raise Exception(f"Unrecognized optimizer {type(optimizer)}.")

# when freezing sub-models we have no real optimizer
# but still need a stub DistributedOptimizer class
@@ -714,6 +729,15 @@ def load_state_dict(self, state_dict):
tensors = {"exp_avg_slow": init_shard(), "exp_avg_fast": init_shard(), "exp_avg_sq": init_shard()}
else: # beta1 == 0
tensors = {"exp_avg_slow": init_shard(), "exp_avg_sq": init_shard()}
elif self.optimizer_name == 'prodigy':
tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard(), "s": init_shard(), "p0": init_shard()}
elif self.optimizer_name == 'mars':
if len(self.optimizer_keys) == 5:
tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard(), "last_grad": init_shard(), "max_exp_avg_sq": init_shard()}
else:
tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard(), "last_grad": init_shard()}
elif self.optimizer_name == 'adopt':
tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard()}
if self.config.use_precision_aware_optimizer:
tensors["master_param"] = init_shard()
state_dict_state.append((state_order, tensors))
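For reference, the per-parameter state tensors that the distributed optimizer now shards for each newly supported optimizer, summarized from the __init__ and load_state_dict branches above (a reviewer-side summary, not code introduced by this PR):

    # State keys as listed in the branches above; "param" is included for every optimizer.
    NEW_OPTIMIZER_STATE_KEYS = {
        'prodigy': ("param", "exp_avg", "exp_avg_sq", "s", "p0"),
        'mars': ("param", "exp_avg", "exp_avg_sq", "last_grad"),  # plus "max_exp_avg_sq" when mars_amsgrad is set
        'adopt': ("param", "exp_avg", "exp_avg_sq"),
    }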