 import yaml
 from data import setup_data
 from ignite.engine import Events
+from ignite.handlers import PiecewiseLinear
 from ignite.metrics import Accuracy, Loss
 from ignite.utils import manual_seed
 from models import setup_model
@@ -29,6 +30,15 @@ def run(local_rank: int, config: Any):
     model = idist.auto_model(setup_model(config.model))
     optimizer = idist.auto_optim(optim.Adam(model.parameters(), lr=config.lr))
     loss_fn = nn.CrossEntropyLoss().to(device=device)
+    milestones_values = [
+        (0, 0.0),
+        (
+            len(dataloader_train),
+            config.lr,
+        ),
+        (config.max_epochs * len(dataloader_train), 0.0),
+    ]
+    lr_scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)

     # trainer and evaluator
     trainer = setup_trainer(
@@ -53,10 +63,17 @@ def run(local_rank: int, config: Any):
     (config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
     trainer.logger = evaluator.logger = logger

+    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
+
     # setup ignite handlers
     #::: if (it.save_training || it.save_evaluation) { :::#
     #::: if (it.save_training) { :::#
-    to_save_train = {"model": model, "optimizer": optimizer, "trainer": trainer}
+    to_save_train = {
+        "model": model,
+        "optimizer": optimizer,
+        "trainer": trainer,
+        "lr_scheduler": lr_scheduler,
+    }
     #::: } else { :::#
     to_save_train = None
     #::: } :::#