Commit bbff5b4

fix lint issue
1 parent d82f953 · commit bbff5b4

2 files changed: 45 additions & 43 deletions


msamp/megatron/optimizer/distrib_optimizer.py

Lines changed: 35 additions & 41 deletions
@@ -351,7 +351,8 @@ def get_model_parallel_group(self):
         return None
 
     def state_dict(self):
-        """
+        """Return the state dict of this optimizer.
+
         The state dict contains all non-DP-rank-dependent (i.e., non-parameter-
         related) optimizer variables. The returned state dict can be stored in
         the standard model/RNG checkpoint file. The parameter and dependent
@@ -371,10 +372,10 @@ def state_dict(self):
         state_dict = {}
 
         # Optimizer state (do not store parameter state here).
-        state_dict['optimizer'] = {k: v for k, v in self.optimizer.state_dict().items() if k != "state"}
+        state_dict['optimizer'] = {k: v for k, v in self.optimizer.state_dict().items() if k != 'state'}
 
-        for param_group in state_dict["optimizer"]["param_groups"]:
-            del param_group["params"]
+        for param_group in state_dict['optimizer']['param_groups']:
+            del param_group['params']
 
         # Grad scaler state.
         if self.grad_scaler:
@@ -421,30 +422,28 @@ def load_state_dict(self, state_dict):
         state_dict_param_groups = [
             {
                 **group,
-                "params": list(inner_state_dict["param_groups"][idx]["params"]),
-            } for idx, group in enumerate(state_dict["optimizer"]["param_groups"])
+                'params': list(inner_state_dict['param_groups'][idx]['params']),
+            } for idx, group in enumerate(state_dict['optimizer']['param_groups'])
         ]
 
         # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below)
         # - Real data is overwritten during load_parameter_state().
         state_dict_state = []
         for gbuf_range_maps in self.model_gbuf_ranges:
             for gbuf_range_map in gbuf_range_maps.values():
-                for model_param, param_range_map in \
-                        gbuf_range_map["param_map"].items():
+                for model_param, param_range_map in gbuf_range_map['param_map'].items():
 
                     # Get parameter ordering information (see method docstring
                     # for details).
                     group_index, group_order = \
                         self.model_param_group_index_map[model_param]
-                    state_order = inner_state_dict["param_groups"] \
-                        [group_index]["params"][group_order]
+                    state_order = inner_state_dict['param_groups'][group_index]['params'][group_order]
 
                     # Allocate dummy tensors.
-                    numel = len(param_range_map["gbuf_world"])
+                    numel = len(param_range_map['gbuf_world'])
                     # MS-AMP: Allocate dummy tensors for exp_avg and exp_avg_sq and cast to ScalingTensor
                     if hasattr(self.optimizer, 'exp_avg_dtype') and self.optimizer.exp_avg_dtype != torch.float32:
-                        step = state_dict['optimizer']["param_groups"][group_index]["step"]
+                        step = state_dict['optimizer']['param_groups'][group_index]['step']
                         exp_avg_qtype = Dtypes.dtype_to_qtype[self.optimizer.exp_avg_dtype]
                         exp_avg_sq_qtype = Dtypes.dtype_to_qtype[self.optimizer.exp_avg_sq_dtype]
                         exp_avg = torch.empty((numel, ), dtype=torch.float32,
@@ -453,19 +452,19 @@ def load_state_dict(self, state_dict):
                                               device=torch.cuda.current_device()).cast(exp_avg_sq_qtype)
                         state_dict_state.append(
                             (state_order, {
-                                "exp_avg": exp_avg,
-                                "exp_avg_sq": exp_avg_sq,
-                                "step": step
+                                'exp_avg': exp_avg,
+                                'exp_avg_sq': exp_avg_sq,
+                                'step': step
                             })
                         )
                     else:
-                        init_shard = lambda: torch.empty(
+                        init_shard = lambda: torch.empty(  # noqa: E731
                             (numel, ), dtype=torch.float32, device=torch.cuda.current_device()
                         )
 
                         state_dict_state.append((state_order, {
-                            "exp_avg": init_shard(),
-                            "exp_avg_sq": init_shard(),
+                            'exp_avg': init_shard(),
+                            'exp_avg_sq': init_shard(),
                         }))
 
         # Sort by state order (see method docstring for details).
@@ -474,8 +473,8 @@ def load_state_dict(self, state_dict):
 
         # Optimizer.
         self.optimizer.load_state_dict({
-            "state": state_dict_state,
-            "param_groups": state_dict_param_groups,
+            'state': state_dict_state,
+            'param_groups': state_dict_param_groups,
         })
 
         # Grad scaler.
@@ -528,29 +527,26 @@ def save_parameter_state(self, filename):
                 gbuf_world_numel = model._grad_buffers[dtype].numel_padded
                 gbuf_local_numel = int(gbuf_world_numel / data_parallel_world_size)
                 local_shards = {
-                    key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device="cpu")
-                    for key in ("param", "exp_avg", "exp_avg_sq")
+                    key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu')
+                    for key in ('param', 'exp_avg', 'exp_avg_sq')
                 }
 
                 # Build contiguous DP rank shards (for param + optim states).
-                for model_param, param_range_map in \
-                        gbuf_range_map["param_map"].items():
+                for model_param, param_range_map in gbuf_range_map['param_map'].items():
 
                     # Main param & optimizer states.
-                    group_index, group_order = \
-                        self.model_param_group_index_map[model_param]
-                    main_param = self.optimizer.param_groups \
-                        [group_index]["params"][group_order]
+                    group_index, group_order = self.model_param_group_index_map[model_param]
+                    main_param = self.optimizer.param_groups[group_index]['params'][group_order]
                     optim_state = self.optimizer.state[main_param]
 
                     tensors = {
-                        "param": main_param,
+                        'param': main_param,
                         **optim_state,
                     }
 
                     # Copy states into contiguous shard.
-                    gbuf_local_start = param_range_map["gbuf_local"].start
-                    gbuf_local_end = param_range_map["gbuf_local"].end
+                    gbuf_local_start = param_range_map['gbuf_local'].start
+                    gbuf_local_end = param_range_map['gbuf_local'].end
                     for key in local_shards:
                         # MS-AMP: Convert to float32 for ScalingTensor.
                         if isinstance(tensors[key], ScalingTensor):
@@ -567,7 +563,7 @@ def save_parameter_state(self, filename):
                     # Gather tensor list.
                     if data_parallel_rank == 0:
                         recv_tensors = [
-                            torch.empty((gbuf_local_numel, ), dtype=torch.float32, device="cpu")
+                            torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu')
                             for _ in range(data_parallel_world_size)
                         ]
                     else:
@@ -626,8 +622,8 @@ def load_parameter_state(self, filename):
 
                 # Contiguous local shards (received from DP rank 0).
                 local_shards = {
-                    key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device="cpu")
-                    for key in ("param", "exp_avg", "exp_avg_sq")
+                    key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu')
+                    for key in ('param', 'exp_avg', 'exp_avg_sq')
                 }
 
                 # Scatter local shards from DP rank 0.
@@ -651,24 +647,22 @@ def load_parameter_state(self, filename):
                 )
 
                 # Copy local contiguous shards to param/optim shards.
-                for model_param, param_range_map in \
-                        gbuf_range_map["param_map"].items():
+                for model_param, param_range_map in gbuf_range_map['param_map'].items():
 
                     # Main param & optimizer states.
                     group_index, group_order = \
                         self.model_param_group_index_map[model_param]
-                    main_param = self.optimizer.param_groups \
-                        [group_index]["params"][group_order]
+                    main_param = self.optimizer.param_groups[group_index]['params'][group_order]
                     optim_state = self.optimizer.state[main_param]
 
                     tensors = {
-                        "param": main_param,
+                        'param': main_param,
                        **optim_state,
                     }
 
                     # Copy states into contiguous shard.
-                    gbuf_local_start = param_range_map["gbuf_local"].start
-                    gbuf_local_end = param_range_map["gbuf_local"].end
+                    gbuf_local_start = param_range_map['gbuf_local'].start
+                    gbuf_local_end = param_range_map['gbuf_local'].end
                     for key in local_shards:
                         if isinstance(tensors[key], ScalingTensor):
                             tensors[key].copy_(
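
Most of the changes above are mechanical: double-quoted strings become single-quoted and backslash-continued statements are joined onto one line. The only spot that needed more than reformatting is the init_shard lambda, which flake8 flags as E731 ("do not assign a lambda expression, use a def"); the commit keeps the lambda and appends a noqa comment instead of rewriting it. For reference, below is a minimal, self-contained sketch of the def-based form that flake8 would accept without a noqa comment. The numel value is a placeholder and the allocation happens on CPU rather than torch.cuda.current_device() so the snippet runs anywhere; it is an illustration, not code from the repository.

import torch

numel = 16  # hypothetical shard size, for illustration only


def init_shard():
    """Allocate a dummy float32 shard (CPU here; the real code uses torch.cuda.current_device())."""
    return torch.empty((numel, ), dtype=torch.float32)


# Same usage pattern as in load_state_dict(): one dummy buffer per optimizer state key.
state = {'exp_avg': init_shard(), 'exp_avg_sq': init_shard()}
assert state['exp_avg'].shape == (numel, )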

msamp/te/extension.py

Lines changed: 10 additions & 2 deletions
@@ -122,8 +122,16 @@ def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None):
         return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out)
 
     @staticmethod
-    def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
-        """Cast tensor to dtype"""
+    def cast_if_needed(tensor, dtype):
+        """Cast tensor to dtype.
+
+        Args:
+            tensor (torch.Tensor or ScalingParameter): Input tensor.
+            dtype (torch.dtype): Output dtype.
+
+        Returns:
+            torch.Tensor: Output tensor.
+        """
         with torch.enable_grad():
             if isinstance(tensor, ScalingParameter):
                 new_tensor = tensor.to(dtype)
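
In extension.py the change drops the type annotations (the function also accepts MS-AMP's ScalingParameter, not only torch.Tensor) and expands the one-line docstring into the Google-style Args/Returns form used elsewhere in MS-AMP, which is presumably what the docstring lint expects. The sketch below only illustrates the documented signature for the plain-tensor path; it is not the MS-AMP implementation, which wraps the cast in torch.enable_grad() and special-cases ScalingParameter as shown in the diff, and the function name here is hypothetical.

import torch


def cast_if_needed_sketch(tensor, dtype):
    """Cast tensor to dtype.

    Args:
        tensor (torch.Tensor): Input tensor (the real function also accepts ScalingParameter).
        dtype (torch.dtype): Output dtype.

    Returns:
        torch.Tensor: Output tensor.
    """
    # Return the input unchanged when it is already in the requested dtype.
    return tensor if tensor.dtype == dtype else tensor.to(dtype)


x = torch.ones(4, dtype=torch.float32)
assert cast_if_needed_sketch(x, torch.bfloat16).dtype == torch.bfloat16
assert cast_if_needed_sketch(x, torch.float32) is x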
