diff --git a/msamp/common/tensor/tensor.py b/msamp/common/tensor/tensor.py
index d355f0db..a173f025 100644
--- a/msamp/common/tensor/tensor.py
+++ b/msamp/common/tensor/tensor.py
@@ -417,6 +417,15 @@ def shape(self):
         """
         return self.value.shape
 
+    @property
+    def is_meta(self):
+        """Return is_meta property of tensor.
+
+        Return:
+            bool: the is_meta property of value tensor.
+        """
+        return self.value.is_meta
+
     @property
     def size(self):
         """Return size function of tensor.
diff --git a/msamp/deepspeed/runtime/engine.py b/msamp/deepspeed/runtime/engine.py
index c081513f..67acdcdf 100644
--- a/msamp/deepspeed/runtime/engine.py
+++ b/msamp/deepspeed/runtime/engine.py
@@ -4,14 +4,18 @@
 # SPDX-License-Identifier: Apache-2.0. DeepSpeed Team)
 
 """DeepSpeedEngine in MS-AMP."""
+
+import torch
 import deepspeed
 from deepspeed.runtime.engine import SparseTensor, ZERO_OPTIMIZATION, AMP, amp, \
     FP16, BFLOAT16, logger, DeepSpeedEngine, instrument_w_nvtx, log_dist, \
     see_memory_usage, DummyOptim, DeepSpeedZeroOptimizer, DeepSpeedZeRoOffload, \
     PipelineModule, ZeroStageEnum
+from deepspeed.moe.utils import is_moe_param
 
 from msamp import initialize as msamp_initialize
-from msamp.common.tensor import ScalingTensor, TensorDist
+from msamp.common.dtype import Dtypes
+from msamp.common.tensor import ScalingTensor, TensorDist, ScalingMeta
 from msamp.optim import LBOptimizer
 from msamp.deepspeed.runtime.fp8.fused_optimizer import FP8Optimizer
 from msamp.deepspeed.runtime.zero import utils    # noqa: F401
@@ -301,6 +305,38 @@ def _configure_zero_optimizer(self, optimizer):
 
         return optimizer
 
+    def _get_gradients_for_reduction(self):
+        non_expert_grads = []
+        expert_grads = {}
+        if self.has_moe_layers:
+            for key in self.expert_data_parallel_group.keys():
+                expert_grads[key] = []
+
+        for param_name, param in self.module.named_parameters():
+            if param.grad is None:
+                # In cases where there is an imbalance of empty grads across
+                # ranks we must create empty grads, this will ensure that every
+                # rank is reducing the same size. In some cases it may make
+                # sense in the future to support the ability to average not
+                # w.r.t. world size but with a different value.
+                if isinstance(param, ScalingTensor):
+                    meta = ScalingMeta(Dtypes.dtype_to_qtype[param.dtype])
+                    param.grad = ScalingTensor(torch.zeros(param.size(), dtype=param.dtype, device=param.device), meta)
+                else:
+                    param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
+
+            grad_data = param.grad.data
+            if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
+                # Call param.grad without data to avoid problem with setting of updated grads
+                grad_data = SparseTensor(param.grad)
+
+            if is_moe_param(param):
+                expert_grads[param.group_name].append(grad_data)
+            else:
+                non_expert_grads.append(grad_data)
+
+        return non_expert_grads, expert_grads
+
     @instrument_w_nvtx
     def backward(  # noqa: C901
         self,
diff --git a/tests/deepspeed/test_engine.py b/tests/deepspeed/test_engine.py
index 7c4706ca..193b3183 100644
--- a/tests/deepspeed/test_engine.py
+++ b/tests/deepspeed/test_engine.py
@@ -174,6 +174,10 @@ def test_backward(self):
         }
         model, _, _, _ = deepspeed.initialize(model=model, config=config)
 
+        for name, param in model.module.named_parameters():
+            if name.startswith('1.'):
+                param.requires_grad = False
+
         inputs = []
         num_inputs = 10
         for _ in range(num_inputs):
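
As context for the padding branch added in _get_gradients_for_reduction and for the test change above, here is a minimal standalone sketch in plain PyTorch (no MS-AMP or DeepSpeed; the two-layer model and names are illustrative only, not from this patch): parameters frozen with requires_grad = False keep grad as None after backward, and zero-filled gradients are substituted so the set of tensors handed to the reducer has the same size on every rank.

import torch

# Illustrative two-layer model; the test freezes every parameter whose name starts with '1.'.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
for name, param in model.named_parameters():
    if name.startswith('1.'):
        param.requires_grad = False

model(torch.randn(2, 4)).sum().backward()

grads = []
for name, param in model.named_parameters():
    if param.grad is None:
        # Frozen parameters received no gradient; pad with zeros so the gradient
        # layout matches across ranks before reduction.
        param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
    grads.append(param.grad.data)

print([tuple(g.shape) for g in grads])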