
Commit

fix lint issue
tocean committed Feb 28, 2024
1 parent d82f953 commit bbff5b4
Showing 2 changed files with 45 additions and 43 deletions.
76 changes: 35 additions & 41 deletions msamp/megatron/optimizer/distrib_optimizer.py
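All edits to this file are stylistic: double-quoted string literals become single-quoted, backslash line continuations are joined onto one line, a docstring gets a proper one-line summary, and one lambda assignment gains an inline # noqa: E731 marker. For readers unfamiliar with E731, here is a small, self-contained illustration of the rule (the names and values are made up, not taken from this repository):

```python
import torch

numel = 8

# flake8's E731 flags assigning a lambda to a name; this commit keeps the lambda
# and silences the warning inline rather than rewriting it as a def.
init_shard = lambda: torch.empty((numel, ), dtype=torch.float32)  # noqa: E731

# The lint-preferred spelling would be an ordinary function:
def init_shard_fn():
    return torch.empty((numel, ), dtype=torch.float32)
```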
@@ -351,7 +351,8 @@ def get_model_parallel_group(self):
return None

def state_dict(self):
"""
"""Return the state dict of this optimizer.
The state dict contains all non-DP-rank-dependent (i.e., non-parameter-
related) optimizer variables. The returned state dict can be stored in
the standard model/RNG checkpoint file. The parameter and dependent
Expand All @@ -371,10 +372,10 @@ def state_dict(self):
state_dict = {}

# Optimizer state (do not store parameter state here).
- state_dict['optimizer'] = {k: v for k, v in self.optimizer.state_dict().items() if k != "state"}
+ state_dict['optimizer'] = {k: v for k, v in self.optimizer.state_dict().items() if k != 'state'}

for param_group in state_dict["optimizer"]["param_groups"]:
del param_group["params"]
for param_group in state_dict['optimizer']['param_groups']:
del param_group['params']

# Grad scaler state.
if self.grad_scaler:
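For orientation, the filter above keeps only the small, DP-rank-independent pieces: param_groups without their 'params' lists and, further down, grad-scaler state. A runnable illustration of the same filtering pattern on a plain torch optimizer (not MS-AMP code, just the generic idea):

```python
import torch

params = [torch.nn.Parameter(torch.randn(3))]
opt = torch.optim.Adam(params, lr=1e-3)

full = opt.state_dict()
slim = {k: v for k, v in full.items() if k != 'state'}   # drop per-parameter state, as above
for group in slim['param_groups']:
    del group['params']                                   # drop parameter ids, as above

print(slim)  # only hyperparameters remain; per-parameter state is checkpointed separately
```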
@@ -421,30 +422,28 @@ def load_state_dict(self, state_dict):
state_dict_param_groups = [
{
**group,
"params": list(inner_state_dict["param_groups"][idx]["params"]),
} for idx, group in enumerate(state_dict["optimizer"]["param_groups"])
'params': list(inner_state_dict['param_groups'][idx]['params']),
} for idx, group in enumerate(state_dict['optimizer']['param_groups'])
]

# Allocate 'dummy' data for optimizer state (i.e., torch.empty() below)
# - Real data is overwritten during load_parameter_state().
state_dict_state = []
for gbuf_range_maps in self.model_gbuf_ranges:
for gbuf_range_map in gbuf_range_maps.values():
- for model_param, param_range_map in \
- gbuf_range_map["param_map"].items():
+ for model_param, param_range_map in gbuf_range_map['param_map'].items():

# Get parameter ordering information (see method docstring
# for details).
group_index, group_order = \
self.model_param_group_index_map[model_param]
state_order = inner_state_dict["param_groups"] \
[group_index]["params"][group_order]
state_order = inner_state_dict['param_groups'][group_index]['params'][group_order]

# Allocate dummy tensors.
numel = len(param_range_map["gbuf_world"])
numel = len(param_range_map['gbuf_world'])
# MS-AMP: Allocate dummy tensors for exp_avg and exp_avg_sq and cast to ScalingTensor
if hasattr(self.optimizer, 'exp_avg_dtype') and self.optimizer.exp_avg_dtype != torch.float32:
step = state_dict['optimizer']["param_groups"][group_index]["step"]
step = state_dict['optimizer']['param_groups'][group_index]['step']
exp_avg_qtype = Dtypes.dtype_to_qtype[self.optimizer.exp_avg_dtype]
exp_avg_sq_qtype = Dtypes.dtype_to_qtype[self.optimizer.exp_avg_sq_dtype]
exp_avg = torch.empty((numel, ), dtype=torch.float32,
Expand All @@ -453,19 +452,19 @@ def load_state_dict(self, state_dict):
device=torch.cuda.current_device()).cast(exp_avg_sq_qtype)
state_dict_state.append(
(state_order, {
"exp_avg": exp_avg,
"exp_avg_sq": exp_avg_sq,
"step": step
'exp_avg': exp_avg,
'exp_avg_sq': exp_avg_sq,
'step': step
})
)
else:
- init_shard = lambda: torch.empty(
+ init_shard = lambda: torch.empty(  # noqa: E731
(numel, ), dtype=torch.float32, device=torch.cuda.current_device()
)

state_dict_state.append((state_order, {
"exp_avg": init_shard(),
"exp_avg_sq": init_shard(),
'exp_avg': init_shard(),
'exp_avg_sq': init_shard(),
}))

# Sort by state order (see method docstring for details).
Expand All @@ -474,8 +473,8 @@ def load_state_dict(self, state_dict):

# Optimizer.
self.optimizer.load_state_dict({
"state": state_dict_state,
"param_groups": state_dict_param_groups,
'state': state_dict_state,
'param_groups': state_dict_param_groups,
})

# Grad scaler.
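The branch above only has to produce correctly typed placeholders: the dummy exp_avg/exp_avg_sq tensors give the inner optimizer slots of the right size and quantized type, and their real values are overwritten later by load_parameter_state(). A condensed sketch of that idea; the helper name is hypothetical and the Dtypes import path is assumed from MS-AMP's layout:

```python
import torch
from msamp.common.dtype import Dtypes  # assumed import path for the Dtypes mapping used above

def make_dummy_optim_state(numel, exp_avg_dtype, exp_avg_sq_dtype, step):
    """Placeholder exp_avg/exp_avg_sq buffers; real data arrives via load_parameter_state()."""
    def shard(dtype):
        buf = torch.empty((numel, ), dtype=torch.float32, device=torch.cuda.current_device())
        if dtype != torch.float32:
            # MS-AMP extends torch.Tensor with .cast(); low-precision states become ScalingTensors.
            buf = buf.cast(Dtypes.dtype_to_qtype[dtype])
        return buf

    return {'exp_avg': shard(exp_avg_dtype), 'exp_avg_sq': shard(exp_avg_sq_dtype), 'step': step}
```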
@@ -528,29 +527,26 @@ def save_parameter_state(self, filename):
gbuf_world_numel = model._grad_buffers[dtype].numel_padded
gbuf_local_numel = int(gbuf_world_numel / data_parallel_world_size)
local_shards = {
- key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device="cpu")
- for key in ("param", "exp_avg", "exp_avg_sq")
+ key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu')
+ for key in ('param', 'exp_avg', 'exp_avg_sq')
}

# Build contiguous DP rank shards (for param + optim states).
- for model_param, param_range_map in \
- gbuf_range_map["param_map"].items():
+ for model_param, param_range_map in gbuf_range_map['param_map'].items():

# Main param & optimizer states.
- group_index, group_order = \
- self.model_param_group_index_map[model_param]
- main_param = self.optimizer.param_groups \
- [group_index]["params"][group_order]
+ group_index, group_order = self.model_param_group_index_map[model_param]
+ main_param = self.optimizer.param_groups[group_index]['params'][group_order]
optim_state = self.optimizer.state[main_param]

tensors = {
"param": main_param,
'param': main_param,
**optim_state,
}

# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
gbuf_local_start = param_range_map['gbuf_local'].start
gbuf_local_end = param_range_map['gbuf_local'].end
for key in local_shards:
# MS-AMP: Convert to float32 for ScalingTensor.
if isinstance(tensors[key], ScalingTensor):
Expand All @@ -567,7 +563,7 @@ def save_parameter_state(self, filename):
# Gather tensor list.
if data_parallel_rank == 0:
recv_tensors = [
- torch.empty((gbuf_local_numel, ), dtype=torch.float32, device="cpu")
+ torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu')
for _ in range(data_parallel_world_size)
]
else:
@@ -626,8 +622,8 @@ def load_parameter_state(self, filename):

# Contiguous local shards (received from DP rank 0).
local_shards = {
- key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device="cpu")
- for key in ("param", "exp_avg", "exp_avg_sq")
+ key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu')
+ for key in ('param', 'exp_avg', 'exp_avg_sq')
}

# Scatter local shards from DP rank 0.
Expand All @@ -651,24 +647,22 @@ def load_parameter_state(self, filename):
)

# Copy local contiguous shards to param/optim shards.
- for model_param, param_range_map in \
- gbuf_range_map["param_map"].items():
+ for model_param, param_range_map in gbuf_range_map['param_map'].items():

# Main param & optimizer states.
group_index, group_order = \
self.model_param_group_index_map[model_param]
- main_param = self.optimizer.param_groups \
- [group_index]["params"][group_order]
+ main_param = self.optimizer.param_groups[group_index]['params'][group_order]
optim_state = self.optimizer.state[main_param]

tensors = {
"param": main_param,
'param': main_param,
**optim_state,
}

# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
gbuf_local_start = param_range_map['gbuf_local'].start
gbuf_local_end = param_range_map['gbuf_local'].end
for key in local_shards:
if isinstance(tensors[key], ScalingTensor):
tensors[key].copy_(
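save_parameter_state() and load_parameter_state() above move per-parameter optimizer state through contiguous float32 CPU shards: each data-parallel rank fills its own shard (ScalingTensor states are converted to float32 first), rank 0 gathers the shards into world-sized buffers for the checkpoint file, and loading scatters each slice back. A simplified, self-contained sketch of that round trip over the default process group (not the exact Megatron/MS-AMP code, which uses a dedicated gloo data-parallel group and per-dtype gradient buffers):

```python
import torch
import torch.distributed as dist

def save_shards(local_shard, filename):
    """Rank 0 collects every rank's CPU shard and writes one contiguous buffer."""
    rank, world = dist.get_rank(), dist.get_world_size()
    recv = [torch.empty_like(local_shard) for _ in range(world)] if rank == 0 else None
    dist.gather(local_shard, recv, dst=0)
    if rank == 0:
        torch.save(torch.cat(recv), filename)

def load_shards(filename, local_numel):
    """Rank 0 reads the world buffer back and scatters each rank its own slice."""
    rank, world = dist.get_rank(), dist.get_world_size()
    local = torch.empty((local_numel, ), dtype=torch.float32)
    chunks = list(torch.load(filename).chunk(world)) if rank == 0 else None
    dist.scatter(local, chunks, src=0)
    return local
```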
12 changes: 10 additions & 2 deletions msamp/te/extension.py
@@ -122,8 +122,16 @@ def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None):
return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out)

@staticmethod
- def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
- """Cast tensor to dtype"""
+ def cast_if_needed(tensor, dtype):
+ """Cast tensor to dtype.
+
+ Args:
+ tensor (torch.Tensor or ScalingParameter): Input tensor.
+ dtype (torch.dtype): Output dtype.
+
+ Returns:
+ torch.Tensor: Output tensor.
+ """
with torch.enable_grad():
if isinstance(tensor, ScalingParameter):
new_tensor = tensor.to(dtype)
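The expanded docstring documents the override MS-AMP installs for Transformer Engine's cast_if_needed so it also accepts ScalingParameter inputs. A small usage sketch, assuming the surrounding class is TeExtensionOverrider as this file suggests (only the plain-tensor path is exercised here):

```python
import torch
from msamp.te.extension import TeExtensionOverrider  # class name as it appears in this file

x = torch.randn(4, 4, requires_grad=True)
y = TeExtensionOverrider.cast_if_needed(x, torch.bfloat16)
print(y.dtype)  # torch.bfloat16

# For a ScalingParameter, the override calls tensor.to(dtype) inside torch.enable_grad(),
# so the dequantized tensor stays on the autograd graph even when invoked from a
# no-grad region.
```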
