
[Bugfix] Crash when a parameter has no grad or ScalingParameter has no is_meta property #135

Merged · 7 commits · Nov 30, 2023
9 changes: 9 additions & 0 deletions msamp/common/tensor/tensor.py
@@ -417,6 +417,15 @@ def shape(self):
"""
return self.value.shape

@property
def is_meta(self):
"""Return is_meta property of tensor.

Return:
bool: the is_meta property of the value tensor.
"""
return self.value.is_meta

@property
def size(self):
"""Return size function of tensor.
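The new is_meta property simply delegates to the wrapped value tensor, so code that probes is_meta behaves the same on a plain torch.Tensor and on a ScalingTensor. A minimal sketch of the delegation pattern, using an illustrative WrappedTensor class rather than MS-AMP's actual ScalingTensor constructor:

import torch


class WrappedTensor:
    """Illustrative stand-in for a tensor wrapper such as ScalingTensor."""

    def __init__(self, value: torch.Tensor):
        self.value = value

    @property
    def is_meta(self):
        # Delegate to the underlying tensor so callers need no special-casing.
        return self.value.is_meta


print(WrappedTensor(torch.zeros(2, 2)).is_meta)                 # False
print(WrappedTensor(torch.zeros(2, 2, device='meta')).is_meta)  # True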
38 changes: 37 additions & 1 deletion msamp/deepspeed/runtime/engine.py
@@ -4,14 +4,18 @@
# SPDX-License-Identifier: Apache-2.0. DeepSpeed Team)

"""DeepSpeedEngine in MS-AMP."""

import torch
import deepspeed
from deepspeed.runtime.engine import SparseTensor, ZERO_OPTIMIZATION, AMP, amp, \
FP16, BFLOAT16, logger, DeepSpeedEngine, instrument_w_nvtx, log_dist, \
see_memory_usage, DummyOptim, DeepSpeedZeroOptimizer, DeepSpeedZeRoOffload, \
PipelineModule, ZeroStageEnum
from deepspeed.moe.utils import is_moe_param

from msamp import initialize as msamp_initialize
from msamp.common.tensor import ScalingTensor, TensorDist
from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingTensor, TensorDist, ScalingMeta
from msamp.optim import LBOptimizer
from msamp.deepspeed.runtime.fp8.fused_optimizer import FP8Optimizer
from msamp.deepspeed.runtime.zero import utils # noqa: F401
@@ -301,6 +305,38 @@ def _configure_zero_optimizer(self, optimizer):

return optimizer

def _get_gradients_for_reduction(self):
non_expert_grads = []
expert_grads = {}
if self.has_moe_layers:
for key in self.expert_data_parallel_group.keys():
expert_grads[key] = []

for param_name, param in self.module.named_parameters():
if param.grad is None:
# In cases where there is an imbalance of empty grads across
# ranks we must create empty grads, this will ensure that every
# rank is reducing the same size. In some cases it may make
# sense in the future to support the ability to average not
# w.r.t. world size but with a different value.
if isinstance(param, ScalingTensor):
meta = ScalingMeta(Dtypes.dtype_to_qtype[param.dtype])
param.grad = ScalingTensor(torch.zeros(param.size(), dtype=param.dtype, device=param.device), meta)
else:
param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)

grad_data = param.grad.data
if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
# Call param.grad without data to avoid problem with setting of updated grads
grad_data = SparseTensor(param.grad)

if is_moe_param(param):
expert_grads[param.group_name].append(grad_data)
else:
non_expert_grads.append(grad_data)

return non_expert_grads, expert_grads

@instrument_w_nvtx
def backward( # noqa: C901
self,
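The overridden _get_gradients_for_reduction fills in a zero gradient for every parameter whose grad is None (for example, a frozen or unused parameter), so that all ranks reduce tensors of the same size; for ScalingTensor parameters the zero gradient is wrapped with a fresh ScalingMeta. Below is a simplified sketch of that zero-grad filling idea for ordinary torch parameters only, with the ScalingTensor, sparse, and MoE branches omitted (the function name is illustrative, not part of MS-AMP or DeepSpeed):

import torch
import torch.nn as nn


def fill_missing_grads(module: nn.Module):
    """Give every parameter a gradient so all ranks reduce tensors of equal size."""
    for param in module.parameters():
        if param.grad is None:
            # Frozen or unused parameters get an all-zero grad of matching shape/dtype/device.
            param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[1].weight.requires_grad = False        # a frozen weight never receives a grad
model(torch.randn(1, 4)).sum().backward()
fill_missing_grads(model)                    # every parameter now has a grad to reduce
assert all(p.grad is not None for p in model.parameters())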
4 changes: 4 additions & 0 deletions tests/deepspeed/test_engine.py
@@ -174,6 +174,10 @@ def test_backward(self):
}
model, _, _, _ = deepspeed.initialize(model=model, config=config)

for name, param in model.module.named_parameters():
if name.startswith('1.'):
param.requires_grad = False

inputs = []
num_inputs = 10
for _ in range(num_inputs):
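The test freezes every parameter whose name starts with '1.' so that those parameters end up with param.grad is None after backward, which is exactly the case the new zero-grad path handles. In an nn.Sequential, parameter names are prefixed with the submodule index, so the '1.' prefix selects the second submodule; a small sketch of that naming (the model here is illustrative, not the one used in the test):

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for name, param in model.named_parameters():
    print(name)                       # 0.weight, 0.bias, 1.weight, 1.bias
    if name.startswith('1.'):
        param.requires_grad = False   # freeze the second Linear so it never gets a grad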