@@ -351,7 +351,8 @@ def get_model_parallel_group(self):
         return None

     def state_dict(self):
-        """
+        """Return the state dict of this optimizer.
+
         The state dict contains all non-DP-rank-dependent (i.e., non-parameter-
         related) optimizer variables. The returned state dict can be stored in
         the standard model/RNG checkpoint file. The parameter and dependent
@@ -371,10 +372,10 @@ def state_dict(self):
         state_dict = {}

         # Optimizer state (do not store parameter state here).
-        state_dict['optimizer'] = {k: v for k, v in self.optimizer.state_dict().items() if k != "state"}
+        state_dict['optimizer'] = {k: v for k, v in self.optimizer.state_dict().items() if k != 'state'}

-        for param_group in state_dict["optimizer"]["param_groups"]:
-            del param_group["params"]
+        for param_group in state_dict['optimizer']['param_groups']:
+            del param_group['params']

         # Grad scaler state.
         if self.grad_scaler:
@@ -421,30 +422,28 @@ def load_state_dict(self, state_dict):
         state_dict_param_groups = [
             {
                 **group,
-                "params": list(inner_state_dict["param_groups"][idx]["params"]),
-            } for idx, group in enumerate(state_dict["optimizer"]["param_groups"])
+                'params': list(inner_state_dict['param_groups'][idx]['params']),
+            } for idx, group in enumerate(state_dict['optimizer']['param_groups'])
         ]

         # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below)
         # - Real data is overwritten during load_parameter_state().
         state_dict_state = []
         for gbuf_range_maps in self.model_gbuf_ranges:
             for gbuf_range_map in gbuf_range_maps.values():
-                for model_param, param_range_map in \
-                        gbuf_range_map["param_map"].items():
+                for model_param, param_range_map in gbuf_range_map['param_map'].items():

                     # Get parameter ordering information (see method docstring
                     # for details).
                     group_index, group_order = \
                         self.model_param_group_index_map[model_param]
-                    state_order = inner_state_dict["param_groups"] \
-                        [group_index]["params"][group_order]
+                    state_order = inner_state_dict['param_groups'][group_index]['params'][group_order]

                     # Allocate dummy tensors.
-                    numel = len(param_range_map["gbuf_world"])
+                    numel = len(param_range_map['gbuf_world'])
                     # MS-AMP: Allocate dummy tensors for exp_avg and exp_avg_sq and cast to ScalingTensor
                     if hasattr(self.optimizer, 'exp_avg_dtype') and self.optimizer.exp_avg_dtype != torch.float32:
-                        step = state_dict['optimizer']["param_groups"][group_index]["step"]
+                        step = state_dict['optimizer']['param_groups'][group_index]['step']
                         exp_avg_qtype = Dtypes.dtype_to_qtype[self.optimizer.exp_avg_dtype]
                         exp_avg_sq_qtype = Dtypes.dtype_to_qtype[self.optimizer.exp_avg_sq_dtype]
                         exp_avg = torch.empty((numel, ), dtype=torch.float32,
@@ -453,19 +452,19 @@ def load_state_dict(self, state_dict):
                                                  device=torch.cuda.current_device()).cast(exp_avg_sq_qtype)
                         state_dict_state.append(
                             (state_order, {
-                                "exp_avg": exp_avg,
-                                "exp_avg_sq": exp_avg_sq,
-                                "step": step
+                                'exp_avg': exp_avg,
+                                'exp_avg_sq': exp_avg_sq,
+                                'step': step
                             })
                         )
                     else:
-                        init_shard = lambda: torch.empty(
+                        init_shard = lambda: torch.empty(  # noqa: E731
                             (numel, ), dtype=torch.float32, device=torch.cuda.current_device()
                         )

                         state_dict_state.append((state_order, {
-                            "exp_avg": init_shard(),
-                            "exp_avg_sq": init_shard(),
+                            'exp_avg': init_shard(),
+                            'exp_avg_sq': init_shard(),
                         }))

         # Sort by state order (see method docstring for details).
@@ -474,8 +473,8 @@ def load_state_dict(self, state_dict):

         # Optimizer.
         self.optimizer.load_state_dict({
-            "state": state_dict_state,
-            "param_groups": state_dict_param_groups,
+            'state': state_dict_state,
+            'param_groups': state_dict_param_groups,
         })

         # Grad scaler.
@@ -528,29 +527,26 @@ def save_parameter_state(self, filename):
                 gbuf_world_numel = model._grad_buffers[dtype].numel_padded
                 gbuf_local_numel = int(gbuf_world_numel / data_parallel_world_size)
                 local_shards = {
-                    key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device="cpu")
-                    for key in ("param", "exp_avg", "exp_avg_sq")
+                    key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu')
+                    for key in ('param', 'exp_avg', 'exp_avg_sq')
                 }

                 # Build contiguous DP rank shards (for param + optim states).
-                for model_param, param_range_map in \
-                        gbuf_range_map["param_map"].items():
+                for model_param, param_range_map in gbuf_range_map['param_map'].items():

                     # Main param & optimizer states.
-                    group_index, group_order = \
-                        self.model_param_group_index_map[model_param]
-                    main_param = self.optimizer.param_groups \
-                        [group_index]["params"][group_order]
+                    group_index, group_order = self.model_param_group_index_map[model_param]
+                    main_param = self.optimizer.param_groups[group_index]['params'][group_order]
                     optim_state = self.optimizer.state[main_param]

                     tensors = {
-                        "param": main_param,
+                        'param': main_param,
                         **optim_state,
                     }

                     # Copy states into contiguous shard.
-                    gbuf_local_start = param_range_map["gbuf_local"].start
-                    gbuf_local_end = param_range_map["gbuf_local"].end
+                    gbuf_local_start = param_range_map['gbuf_local'].start
+                    gbuf_local_end = param_range_map['gbuf_local'].end
                     for key in local_shards:
                         # MS-AMP: Convert to float32 for ScalingTensor.
                         if isinstance(tensors[key], ScalingTensor):
@@ -567,7 +563,7 @@ def save_parameter_state(self, filename):
                     # Gather tensor list.
                     if data_parallel_rank == 0:
                         recv_tensors = [
-                            torch.empty((gbuf_local_numel, ), dtype=torch.float32, device="cpu")
+                            torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu')
                             for _ in range(data_parallel_world_size)
                         ]
                     else:
@@ -626,8 +622,8 @@ def load_parameter_state(self, filename):

                 # Contiguous local shards (received from DP rank 0).
                 local_shards = {
-                    key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device="cpu")
-                    for key in ("param", "exp_avg", "exp_avg_sq")
+                    key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu')
+                    for key in ('param', 'exp_avg', 'exp_avg_sq')
                 }

                 # Scatter local shards from DP rank 0.
@@ -651,24 +647,22 @@ def load_parameter_state(self, filename):
                 )

                 # Copy local contiguous shards to param/optim shards.
-                for model_param, param_range_map in \
-                        gbuf_range_map["param_map"].items():
+                for model_param, param_range_map in gbuf_range_map['param_map'].items():

                     # Main param & optimizer states.
                     group_index, group_order = \
                         self.model_param_group_index_map[model_param]
-                    main_param = self.optimizer.param_groups \
-                        [group_index]["params"][group_order]
+                    main_param = self.optimizer.param_groups[group_index]['params'][group_order]
                     optim_state = self.optimizer.state[main_param]

                     tensors = {
-                        "param": main_param,
+                        'param': main_param,
                         **optim_state,
                     }

                     # Copy states into contiguous shard.
-                    gbuf_local_start = param_range_map["gbuf_local"].start
-                    gbuf_local_end = param_range_map["gbuf_local"].end
+                    gbuf_local_start = param_range_map['gbuf_local'].start
+                    gbuf_local_end = param_range_map['gbuf_local'].end
                     for key in local_shards:
                         if isinstance(tensors[key], ScalingTensor):
                             tensors[key].copy_(
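Note on the state_dict() change: the docstring it adds describes keeping only the non-DP-rank-dependent optimizer variables in the regular checkpoint. The following is a minimal, self-contained sketch of that pattern using a plain torch.optim.Adam on CPU; it is illustrative only and is not the patched DistributedOptimizer, and the model and variable names here are hypothetical. It keeps every entry of the inner optimizer's state dict except the per-parameter 'state' and drops the 'params' index lists from each param group, so what remains is independent of any DP rank's parameter shard.

# Hypothetical standalone illustration -- not part of this patch.
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one optimizer step so there is per-parameter state to strip.
model(torch.randn(3, 4)).sum().backward()
opt.step()

# Same pattern as the patched state_dict(): drop parameter-dependent entries.
common = {k: v for k, v in opt.state_dict().items() if k != 'state'}
for group in common['param_groups']:
    del group['params']

print(common)  # per-group hyperparameters only (lr, betas, eps, ...)

The parameter shards and their exp_avg/exp_avg_sq state are handled separately by save_parameter_state()/load_parameter_state(), as the later hunks in this diff show.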