Commit 2467de2

azahed98 and Arsh Zahed authored
Add lr scheduler, weight decay and max_grad_norm (#214)
* Add weight decay and max_grad_norm
* Change to min_lr_ratio
* Update max_grad_norm and weight_decay defaults, supported vals
* Add hints for disabling max_grad_norm
* add back learning rate hint
* add max_grad_norm and weight_decay to FinetuneRequest
* Remove percentage from min_lr_ratio description
* Fix hints and typing
* Update version to 1.3.5
* Fix more typing
* Make min_lr_ratio optional

---------

Co-authored-by: Arsh Zahed <arshzahed@Arshs-MacBook-Pro.local>
1 parent 1eb7779 commit 2467de2

File tree: 5 files changed (+95, −4 lines)


pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.3.4"
+version = "1.3.5"
 authors = [
     "Together AI <support@together.ai>"
 ]

src/together/cli/api/finetune.py

Lines changed: 24 additions & 0 deletions

@@ -65,12 +65,30 @@ def fine_tuning(ctx: click.Context) -> None:
 )
 @click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
 @click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
+@click.option(
+    "--min-lr-ratio",
+    type=float,
+    default=0.0,
+    help="The ratio of the final learning rate to the peak learning rate",
+)
 @click.option(
     "--warmup-ratio",
     type=float,
     default=0.0,
     help="Warmup ratio for learning rate scheduler.",
 )
+@click.option(
+    "--max-grad-norm",
+    type=float,
+    default=1.0,
+    help="Max gradient norm to be used for gradient clipping. Set to 0 to disable.",
+)
+@click.option(
+    "--weight-decay",
+    type=float,
+    default=0.0,
+    help="Weight decay",
+)
 @click.option(
     "--lora/--no-lora",
     type=bool,
@@ -115,7 +133,10 @@ def create(
     n_checkpoints: int,
     batch_size: int | Literal["max"],
     learning_rate: float,
+    min_lr_ratio: float,
     warmup_ratio: float,
+    max_grad_norm: float,
+    weight_decay: float,
     lora: bool,
     lora_r: int,
     lora_dropout: float,
@@ -138,7 +159,10 @@ def create(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        min_lr_ratio=min_lr_ratio,
         warmup_ratio=warmup_ratio,
+        max_grad_norm=max_grad_norm,
+        weight_decay=weight_decay,
         lora=lora,
         lora_r=lora_r,
         lora_dropout=lora_dropout,
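
With these options in place, the new knobs can be passed straight on the command line. A minimal sketch (not part of this commit): the training-file ID and model name are placeholders, and --training-file/--model come from the existing create command.

    together fine-tuning create \
      --training-file file-xxxxxxxx \
      --model your-base-model \
      --learning-rate 1e-5 \
      --min-lr-ratio 0.1 \
      --warmup-ratio 0.03 \
      --max-grad-norm 0 \
      --weight-decay 0.01

Per the option's help text, passing --max-grad-norm 0 disables gradient clipping.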

src/together/resources/finetune.py

Lines changed: 45 additions & 3 deletions

@@ -20,6 +20,8 @@
     TogetherClient,
     TogetherRequest,
     TrainingType,
+    FinetuneLRScheduler,
+    FinetuneLinearLRSchedulerArgs,
 )
 from together.types.finetune import DownloadCheckpointType
 from together.utils import log_warn_once, normalize_key
@@ -35,7 +37,10 @@ def createFinetuneRequest(
     n_checkpoints: int | None = 1,
     batch_size: int | Literal["max"] = "max",
     learning_rate: float | None = 0.00001,
-    warmup_ratio: float | None = 0.0,
+    min_lr_ratio: float = 0.0,
+    warmup_ratio: float = 0.0,
+    max_grad_norm: float = 1.0,
+    weight_decay: float = 0.0,
     lora: bool = False,
     lora_r: int | None = None,
     lora_dropout: float | None = 0,
@@ -83,6 +88,20 @@ def createFinetuneRequest(
     if warmup_ratio > 1 or warmup_ratio < 0:
         raise ValueError("Warmup ratio should be between 0 and 1")
 
+    if min_lr_ratio is not None and (min_lr_ratio > 1 or min_lr_ratio < 0):
+        raise ValueError("Min learning rate ratio should be between 0 and 1")
+
+    if max_grad_norm < 0:
+        raise ValueError("Max gradient norm should be non-negative")
+
+    if weight_decay is not None and (weight_decay < 0):
+        raise ValueError("Weight decay should be non-negative")
+
+    lrScheduler = FinetuneLRScheduler(
+        lr_scheduler_type="linear",
+        lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
+    )
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -92,7 +111,10 @@ def createFinetuneRequest(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        lr_scheduler=lrScheduler,
         warmup_ratio=warmup_ratio,
+        max_grad_norm=max_grad_norm,
+        weight_decay=weight_decay,
         training_type=training_type,
         suffix=suffix,
         wandb_key=wandb_api_key,
@@ -117,7 +139,10 @@ def create(
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
-        warmup_ratio: float | None = 0.0,
+        min_lr_ratio: float = 0.0,
+        warmup_ratio: float = 0.0,
+        max_grad_norm: float = 1.0,
+        weight_decay: float = 0.0,
         lora: bool = False,
         lora_r: int | None = None,
         lora_dropout: float | None = 0,
@@ -143,7 +168,11 @@ def create(
             batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
+                the learning rate scheduler. Defaults to 0.0.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+            max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
+            weight_decay (float, optional): Weight decay. Defaults to 0.0.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
             lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
@@ -185,7 +214,10 @@ def create(
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            min_lr_ratio=min_lr_ratio,
             warmup_ratio=warmup_ratio,
+            max_grad_norm=max_grad_norm,
+            weight_decay=weight_decay,
             lora=lora,
             lora_r=lora_r,
             lora_dropout=lora_dropout,
@@ -436,7 +468,10 @@ async def create(
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
-        warmup_ratio: float | None = 0.0,
+        min_lr_ratio: float = 0.0,
+        warmup_ratio: float = 0.0,
+        max_grad_norm: float = 1.0,
+        weight_decay: float = 0.0,
         lora: bool = False,
         lora_r: int | None = None,
         lora_dropout: float | None = 0,
@@ -462,7 +497,11 @@ async def create(
            batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
            learning_rate (float, optional): Learning rate multiplier to use for training
                Defaults to 0.00001.
+           min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
+               the learning rate scheduler. Defaults to 0.0.
            warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+           max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
+           weight_decay (float, optional): Weight decay. Defaults to 0.0.
            lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
            lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
            lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
@@ -504,7 +543,10 @@ async def create(
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            min_lr_ratio=min_lr_ratio,
             warmup_ratio=warmup_ratio,
+            max_grad_norm=max_grad_norm,
+            weight_decay=weight_decay,
             lora=lora,
             lora_r=lora_r,
             lora_dropout=lora_dropout,
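
The same parameters are exposed through the Python client's create call shown above. A usage sketch, assuming TOGETHER_API_KEY is set in the environment; the training-file ID and model name are placeholders:

    from together import Together

    client = Together()  # picks up TOGETHER_API_KEY from the environment

    job = client.fine_tuning.create(
        training_file="file-xxxxxxxx",   # placeholder uploaded-file ID
        model="your-base-model",         # placeholder base model name
        learning_rate=1e-5,
        min_lr_ratio=0.1,    # final LR of the linear scheduler = 10% of the peak LR
        warmup_ratio=0.03,
        max_grad_norm=0.0,   # 0 disables gradient clipping
        weight_decay=0.01,
    )
    print(job.id)

Internally, createFinetuneRequest validates the three new values and folds min_lr_ratio into a FinetuneLRScheduler with lr_scheduler_type="linear", as the hunks above show.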

src/together/types/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -30,6 +30,8 @@
     LoRATrainingType,
     TrainingType,
     FinetuneTrainingLimits,
+    FinetuneLRScheduler,
+    FinetuneLinearLRSchedulerArgs,
 )
 from together.types.images import (
     ImageRequest,
@@ -57,6 +59,8 @@
     "FinetuneList",
     "FinetuneListEvents",
     "FinetuneDownloadResult",
+    "FinetuneLRScheduler",
+    "FinetuneLinearLRSchedulerArgs",
     "FileRequest",
     "FileResponse",
     "FileList",

src/together/types/finetune.py

Lines changed: 21 additions & 0 deletions

@@ -150,8 +150,14 @@ class FinetuneRequest(BaseModel):
     n_epochs: int
     # training learning rate
     learning_rate: float
+    # learning rate scheduler type and args
+    lr_scheduler: FinetuneLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float
+    # max gradient norm
+    max_grad_norm: float
+    # weight decay
+    weight_decay: float
     # number of checkpoints to save
     n_checkpoints: int | None = None
     # number of evaluation loops to run
@@ -193,8 +199,14 @@ class FinetuneResponse(BaseModel):
     batch_size: int | None = None
     # training learning rate
     learning_rate: float | None = None
+    # learning rate scheduler type and args
+    lr_scheduler: FinetuneLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float | None = None
+    # max gradient norm
+    max_grad_norm: float | None = None
+    # weight decay
+    weight_decay: float | None = None
     # number of steps between evals
     eval_steps: int | None = None
     # training type
@@ -287,3 +299,12 @@ class FinetuneTrainingLimits(BaseModel):
     min_learning_rate: float
     full_training: FinetuneFullTrainingLimits | None = None
     lora_training: FinetuneLoraTrainingLimits | None = None
+
+
+class FinetuneLRScheduler(BaseModel):
+    lr_scheduler_type: str
+    lr_scheduler_args: FinetuneLinearLRSchedulerArgs | None = None
+
+
+class FinetuneLinearLRSchedulerArgs(BaseModel):
+    min_lr_ratio: float | None = 0.0
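
For completeness, a small sketch of the new models on their own (assuming the SDK's pydantic v2 dependency, so model_dump is available); the 0.1 ratio is an arbitrary example value:

    from together.types import FinetuneLRScheduler, FinetuneLinearLRSchedulerArgs

    # Linear schedule decaying from the peak learning rate down to 10% of it.
    scheduler = FinetuneLRScheduler(
        lr_scheduler_type="linear",
        lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=0.1),
    )

    print(scheduler.model_dump())
    # {'lr_scheduler_type': 'linear', 'lr_scheduler_args': {'min_lr_ratio': 0.1}}

This is the object that FinetuneRequest now carries in its lr_scheduler field.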
