From 71762193711a5afe6cd4877306a2583b9d79ba59 Mon Sep 17 00:00:00 2001
From: tocean
Date: Fri, 24 Nov 2023 11:56:21 +0000
Subject: [PATCH 1/7] fix bug for none grad in deepspeed

---
 msamp/deepspeed/runtime/engine.py | 38 ++++++++++++++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/msamp/deepspeed/runtime/engine.py b/msamp/deepspeed/runtime/engine.py
index c081513f..0209bb7b 100644
--- a/msamp/deepspeed/runtime/engine.py
+++ b/msamp/deepspeed/runtime/engine.py
@@ -4,14 +4,17 @@
 # SPDX-License-Identifier: Apache-2.0. DeepSpeed Team)
 """DeepSpeedEngine in MS-AMP."""
+import torch
 
 import deepspeed
 from deepspeed.runtime.engine import SparseTensor, ZERO_OPTIMIZATION, AMP, amp, \
     FP16, BFLOAT16, logger, DeepSpeedEngine, instrument_w_nvtx, log_dist, \
     see_memory_usage, DummyOptim, DeepSpeedZeroOptimizer, DeepSpeedZeRoOffload, \
     PipelineModule, ZeroStageEnum
+from deepspeed.moe.utils import is_moe_param
 
 from msamp import initialize as msamp_initialize
-from msamp.common.tensor import ScalingTensor, TensorDist
+from msamp.common.dtype import Dtypes
+from msamp.common.tensor import ScalingTensor, TensorDist, ScalingMeta
 from msamp.optim import LBOptimizer
 from msamp.deepspeed.runtime.fp8.fused_optimizer import FP8Optimizer
 from msamp.deepspeed.runtime.zero import utils  # noqa: F401
@@ -301,6 +304,38 @@ def _configure_zero_optimizer(self, optimizer):
 
         return optimizer
 
+    def _get_gradients_for_reduction(self):
+        non_expert_grads = []
+        expert_grads = {}
+        if self.has_moe_layers:
+            for key in self.expert_data_parallel_group.keys():
+                expert_grads[key] = []
+
+        for param_name, param in self.module.named_parameters():
+            if param.grad is None:
+                # In cases where there is an imbalance of empty grads across
+                # ranks we must create empty grads, this will ensure that every
+                # rank is reducing the same size. In some cases it may make
+                # sense in the future to support the ability to average not
+                # w.r.t. world size but with a different value.
+                if isinstance(param, ScalingTensor):
+                    meta = ScalingMeta(Dtypes.kfloat8_e4m3)
+                    param.grad = ScalingTensor(torch.zeros(param.size(), dtype=param.dtype, device=param.device), meta)
+                else:
+                    param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
+
+            grad_data = param.grad.data
+            if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
+                # Call param.grad without data to avoid problem with setting of updated grads
+                grad_data = SparseTensor(param.grad)
+
+            if is_moe_param(param):
+                expert_grads[param.group_name].append(grad_data)
+            else:
+                non_expert_grads.append(grad_data)
+
+        return non_expert_grads, expert_grads
+
     @instrument_w_nvtx
     def backward(  # noqa: C901
         self,
@@ -434,3 +469,4 @@ def msamp_enabled(self):
     def msamp_optlevel(self):
         """Return the opt level of MS-AMP."""
         return self._config.msamp_optlevel
+

From 18d9c176ada74700ba2fcf68d0d467caa25c10a9 Mon Sep 17 00:00:00 2001
From: tocean
Date: Mon, 27 Nov 2023 05:43:20 +0000
Subject: [PATCH 2/7] fix bug of empty param group for high precision parameters in deepspeed

---
 msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
index 25d65457..ebe0ebe2 100644
--- a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
+++ b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
@@ -45,7 +45,11 @@ def __init__(self, init_optimizer, *args, **kwargs):  # noqa: C901
                 else:
                     hp_params.append(p)
             self.fp8_param_groups.append(fp8_params)
-            pg['params'] = hp_params
+
+            if len(hp_params) == 0:
+                init_optimizer.param_groups.remove(pg)
+            else:
+                pg['params'] = hp_params
 
         assert len(self.fp8_param_groups) == len(init_optimizer.param_groups)
 

From 8211b5dc67e863f7b3e66d3b689b8399efb42454 Mon Sep 17 00:00:00 2001
From: tocean
Date: Wed, 29 Nov 2023 12:02:36 +0000
Subject: [PATCH 3/7] add ut

---
 msamp/deepspeed/runtime/engine.py                 | 3 +--
 msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py | 5 +----
 tests/deepspeed/test_engine.py                    | 4 ++++
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/msamp/deepspeed/runtime/engine.py b/msamp/deepspeed/runtime/engine.py
index 0209bb7b..11f4999c 100644
--- a/msamp/deepspeed/runtime/engine.py
+++ b/msamp/deepspeed/runtime/engine.py
@@ -319,7 +319,7 @@ def _get_gradients_for_reduction(self):
                 # sense in the future to support the ability to average not
                 # w.r.t. world size but with a different value.
                 if isinstance(param, ScalingTensor):
-                    meta = ScalingMeta(Dtypes.kfloat8_e4m3)
+                    meta = ScalingMeta(Dtypes.dtype_to_qtype[param.dtype])
                     param.grad = ScalingTensor(torch.zeros(param.size(), dtype=param.dtype, device=param.device), meta)
                 else:
                     param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
@@ -469,4 +469,3 @@ def msamp_enabled(self):
     def msamp_optlevel(self):
         """Return the opt level of MS-AMP."""
         return self._config.msamp_optlevel
-
diff --git a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
index ebe0ebe2..514344ea 100644
--- a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
+++ b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
@@ -46,10 +46,7 @@ def __init__(self, init_optimizer, *args, **kwargs):  # noqa: C901
                     hp_params.append(p)
             self.fp8_param_groups.append(fp8_params)
 
-            if len(hp_params) == 0:
-                init_optimizer.param_groups.remove(pg)
-            else:
-                pg['params'] = hp_params
+            pg['params'] = hp_params
 
         assert len(self.fp8_param_groups) == len(init_optimizer.param_groups)
 
diff --git a/tests/deepspeed/test_engine.py b/tests/deepspeed/test_engine.py
index 7c4706ca..193b3183 100644
--- a/tests/deepspeed/test_engine.py
+++ b/tests/deepspeed/test_engine.py
@@ -174,6 +174,10 @@ def test_backward(self):
         }
         model, _, _, _ = deepspeed.initialize(model=model, config=config)
 
+        for name, param in model.module.named_parameters():
+            if name.startswith('1.'):
+                param.requires_grad = False
+
         inputs = []
         num_inputs = 10
         for _ in range(num_inputs):

From 053cdc436d92d8dd8fb6699aba0f8da046c0976d Mon Sep 17 00:00:00 2001
From: tocean
Date: Wed, 29 Nov 2023 12:04:35 +0000
Subject: [PATCH 4/7] fix comments

---
 msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
index 514344ea..25d65457 100644
--- a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
+++ b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
@@ -45,7 +45,6 @@ def __init__(self, init_optimizer, *args, **kwargs):  # noqa: C901
                 else:
                     hp_params.append(p)
             self.fp8_param_groups.append(fp8_params)
-
             pg['params'] = hp_params
 
         assert len(self.fp8_param_groups) == len(init_optimizer.param_groups)

From a3830cc189d3e9ca4546c4b84e69d70f710d0767 Mon Sep 17 00:00:00 2001
From: tocean
Date: Wed, 29 Nov 2023 12:06:39 +0000
Subject: [PATCH 5/7] fix lint

---
 msamp/deepspeed/runtime/engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/msamp/deepspeed/runtime/engine.py b/msamp/deepspeed/runtime/engine.py
index 11f4999c..22ff9bfc 100644
--- a/msamp/deepspeed/runtime/engine.py
+++ b/msamp/deepspeed/runtime/engine.py
@@ -4,6 +4,7 @@
 # SPDX-License-Identifier: Apache-2.0. DeepSpeed Team)
 """DeepSpeedEngine in MS-AMP."""
+
 import torch
 
 import deepspeed
 from deepspeed.runtime.engine import SparseTensor, ZERO_OPTIMIZATION, AMP, amp, \

From 6ae338347daa97213ca999e5e811caf58746c14b Mon Sep 17 00:00:00 2001
From: tocean
Date: Thu, 30 Nov 2023 03:41:58 +0000
Subject: [PATCH 6/7] add is_meta in scalingtensor

---
 msamp/common/tensor/tensor.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/msamp/common/tensor/tensor.py b/msamp/common/tensor/tensor.py
index d355f0db..a173f025 100644
--- a/msamp/common/tensor/tensor.py
+++ b/msamp/common/tensor/tensor.py
@@ -417,6 +417,15 @@ def shape(self):
         """
         return self.value.shape
 
+    @property
+    def is_meta(self):
+        """Return is_meta property of tensor.
+
+        Return:
+            bool: the is_meta property of value tensor.
+        """
+        return self.value.is_meta
+
     @property
     def size(self):
         """Return size function of tensor.

From 70f51473e49a006bfcfb033547acd961be3aaec1 Mon Sep 17 00:00:00 2001
From: tocean
Date: Thu, 30 Nov 2023 03:55:11 +0000
Subject: [PATCH 7/7] fix lint issue

---
 msamp/deepspeed/runtime/engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/msamp/deepspeed/runtime/engine.py b/msamp/deepspeed/runtime/engine.py
index 22ff9bfc..67acdcdf 100644
--- a/msamp/deepspeed/runtime/engine.py
+++ b/msamp/deepspeed/runtime/engine.py
@@ -336,7 +336,7 @@ def _get_gradients_for_reduction(self):
                 non_expert_grads.append(grad_data)
 
         return non_expert_grads, expert_grads
-    
+
     @instrument_w_nvtx
     def backward(  # noqa: C901
         self,
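
Note (editorial, not part of the patch series): the sketch below illustrates the scenario these patches address, namely parameters that are frozen so their grad stays None when DeepSpeed reduces gradients. It mirrors the unit test extended in PATCH 3/7; the model shape, the DeepSpeed/MS-AMP config values, and the msamp.deepspeed import used to activate the MS-AMP engine are illustrative assumptions rather than code taken from the patches.

import torch
import torch.nn as nn
import deepspeed
import msamp.deepspeed  # noqa: F401  (assumption: importing this package installs the MS-AMP engine)

# Two-layer Sequential so the second layer's parameters are named '1.weight' and
# '1.bias', matching the name.startswith('1.') check in the test above.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)).cuda()

# Illustrative config; the 'msamp' keys correspond to the msamp_enabled and
# msamp_optlevel properties defined in msamp/deepspeed/runtime/engine.py.
config = {
    'train_batch_size': 1,
    'optimizer': {'type': 'Adam', 'params': {'lr': 1e-3}},
    'fp16': {'enabled': True},
    'zero_optimization': {'stage': 1},
    'msamp': {'enabled': True, 'opt_level': 'O3'},
}
model, _, _, _ = deepspeed.initialize(model=model, config=config)

# Freeze the second layer, as in the test added in PATCH 3/7: its grads stay None,
# which is the case _get_gradients_for_reduction from PATCH 1/7 handles by creating
# zero-valued grads (ScalingTensor or torch.Tensor) before reduction.
for name, param in model.module.named_parameters():
    if name.startswith('1.'):
        param.requires_grad = False

x = torch.randn(4, 4, device='cuda', dtype=torch.half)
loss = model(x).sum()
model.backward(loss)
model.step()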