Commit
linear sched with wsd sqrt
Jackmin801 committed Oct 9, 2024
1 parent 7b5daf7 commit eb104f2
Showing 1 changed file with 38 additions and 0 deletions.
src/zeroband/lr_scheduler.py
import math
from functools import partial

from torch.optim.lr_scheduler import LambdaLR

def _get_linear_schedule_with_wsd_sqrt_lr_lambda(current_step: int, *, num_warmup_steps: int, num_stable_steps: int, num_training_steps: int):
    # Warmup: scale the lr linearly from 0 up to its initial value.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    # Stable: hold the lr at its initial value. Note that num_stable_steps is an
    # absolute step index (counted from step 0, not from the end of warmup).
    elif current_step < num_stable_steps:
        return 1.0
    # Decay: anneal to 0 following 1 - sqrt(progress); max(1, ...) guards against
    # division by zero when num_stable_steps == num_training_steps.
    else:
        return max(0.0, 1 - math.sqrt(float(current_step - num_stable_steps) / float(max(1, num_training_steps - num_stable_steps))))

def get_linear_schedule_with_wsd_sqrt(optimizer, num_warmup_steps: int, num_stable_steps: int, num_training_steps: int, last_epoch: int = -1):
    """
    Create a warmup-stable-decay (WSD) schedule: the learning rate increases linearly from 0 to the
    initial lr set in the optimizer during the warmup phase, holds constant during the stable phase,
    then decays to 0 following a 1 - sqrt(progress) curve over the remaining steps.
    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_stable_steps (`int`):
            The step at which the stable phase ends and the sqrt decay begins.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    lr_lambda = partial(
        _get_linear_schedule_with_wsd_sqrt_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_stable_steps=num_stable_steps,
        num_training_steps=num_training_steps,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)
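
A minimal usage sketch (not part of this commit), assuming a standard torch optimizer; the model, lr, and step counts below are illustrative:

import torch
from zeroband.lr_scheduler import get_linear_schedule_with_wsd_sqrt

# Hypothetical toy setup, purely for illustration.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

scheduler = get_linear_schedule_with_wsd_sqrt(
    optimizer,
    num_warmup_steps=100,     # lr ramps from 0 to 3e-4 over steps 0..99
    num_stable_steps=800,     # lr holds at 3e-4 until step 800
    num_training_steps=1000,  # sqrt decay to 0 over steps 800..999
)

for step in range(1000):
    optimizer.step()  # loss.backward() would precede this in real training
    scheduler.step()  # advance the schedule once per optimizer step

For example, at step 900 the multiplier is 1 - sqrt((900 - 800) / (1000 - 800)) ≈ 0.29, so the effective lr is about 8.8e-5.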
