diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index b3adfb7dbf..ec05f684b8 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -402,7 +402,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 use_bias=ctx.use_bias if grad_biases[0] is None else None,
                 bias=biases,
                 use_split_accumulator=wgrad_gemm_use_split_accumulator,
-                accumulate=accumulate_wgrad_into_param_main_grad,
+                accumulate=(
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(weights[0], "overwrite_main_grad", False)
+                    else False
+                ),
             )
             # WGRAD
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
@@ -519,7 +523,9 @@ class GroupedLinear(TransformerEngineBaseModule):
                              the weight gradient. When enabled, it is assumed that the weights
                              have an additional `main_grad` attribute (used instead of the
                              regular `grad`) which is a pre-allocated buffer of the correct
-                             size to accumulate gradients in.
+                             size to accumulate gradients in. If, in addition, the weight
+                             tensor has the attribute `overwrite_main_grad` set to `True`,
+                             `main_grad` is overwritten instead of accumulated into.
     return_bias : bool, default = `False`
                  when set to `True`, this module will not apply the additive bias itself,
                  but instead return the bias value during the forward pass together with the
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index e1c0eab2dc..0559ae7cec 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -849,7 +849,11 @@ def backward(
                     main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.grad_weight_quantizer,
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(weight, "overwrite_main_grad", False)
+                    else False
+                ),
                 "layout": "NT",
                 "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
@@ -1125,7 +1129,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
                              the weight gradient. When enabled, it is assumed that the weights
                              have an additional `main_grad` attribute (used instead of the
                              regular `grad`) which is a pre-allocated buffer of the correct
-                             size to accumulate gradients in.
+                             size to accumulate gradients in. If, in addition, the weight
+                             tensor has the attribute `overwrite_main_grad` set to `True`,
+                             `main_grad` is overwritten instead of accumulated into.
     return_bias : bool, default = `False`
                  when set to `True`, this module will not apply the additive bias itself,
                  but instead return the bias value during the forward pass together with the
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index d680a9f8f6..8ef19d0520 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -948,7 +948,11 @@ def backward(
                     else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(fc1_weight, "overwrite_main_grad", False)
+                    else False
+                ),
                 "layout": "NT",
                 "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
@@ -1189,7 +1193,11 @@ def fc2_wgrad_gemm(
                     else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.fc1_grad_weight_quantizer,
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(fc2_weight, "overwrite_main_grad", False)
+                    else False
+                ),
                 "layout": "NT",
                 "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
@@ -1484,7 +1492,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
                              the weight gradient. When enabled, it is assumed that the weights
                              have an additional `main_grad` attribute (used instead of the
                              regular `grad`) which is a pre-allocated buffer of the correct
-                             size to accumulate gradients in.
+                             size to accumulate gradients in. If, in addition, the weight
+                             tensor has the attribute `overwrite_main_grad` set to `True`,
+                             `main_grad` is overwritten instead of accumulated into.
     return_bias : bool, default = `False`
                  when set to `True`, this module will not apply the additive bias for FC2,
                  but instead return the bias value during the forward pass together with the
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 02872439a3..67124c1570 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -843,7 +843,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                     main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.grad_weight_quantizer,
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(weight, "overwrite_main_grad", False)
+                    else False
+                ),
                 "layout": "NT",
                 "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
@@ -1061,7 +1065,9 @@ class Linear(TransformerEngineBaseModule):
                              the weight gradient. When enabled, it is assumed that the weights
                              have an additional `main_grad` attribute (used instead of the
                              regular `grad`) which is a pre-allocated buffer of the correct
-                             size to accumulate gradients in.
+                             size to accumulate gradients in. If, in addition, the weight
+                             tensor has the attribute `overwrite_main_grad` set to `True`,
+                             `main_grad` is overwritten instead of accumulated into.
     return_bias : bool, default = `False`
                  when set to `True`, this module will not apply the additive bias itself,
                  but instead return the bias value during the forward pass together with the
diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py
index cb2119296f..b15d840d63 100644
--- a/transformer_engine/pytorch/ops/basic/basic_linear.py
+++ b/transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -80,7 +80,9 @@ class BasicLinear(BasicOperation):
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
         meaningful. This is primarily intented to integrate with
-        Megatron-LM.
+        Megatron-LM. If, in addition, the weight tensor has the
+        attribute `overwrite_main_grad` set to `True`, `main_grad` is
+        overwritten instead of accumulated into.
     userbuffers_options, dict, optional
         Options for overlapping tensor-parallel communication with
         compute using Userbuffers. This feature is highly
@@ -1019,6 +1021,7 @@ def op_backward(
             weight_param = self.weight
             if hasattr(weight_param, "__fsdp_param__"):
                 weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
             if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py
index 845ba262a0..a86745a686 100644
--- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py
+++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py
@@ -59,6 +59,7 @@ def fuser_backward(
             weight_param = linear_op.weight
             if hasattr(weight_param, "__fsdp_param__"):
                 weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
             if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py
index a9595d5167..832e51de83 100644
--- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py
+++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py
@@ -60,6 +60,7 @@ def fuser_backward(
             weight_param = linear_op.weight
             if hasattr(weight_param, "__fsdp_param__"):
                 weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
             if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
index 1ecdba6253..d95b2298fe 100644
--- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
+++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
@@ -523,6 +523,7 @@ def fuser_backward(
             weight_param = linear_op.weight
             if hasattr(weight_param, "__fsdp_param__"):
                 weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
             if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
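
The hunks above all follow one pattern: before launching the wgrad GEMM, the backward pass checks getattr(weight, "overwrite_main_grad", False) and, when it is set, forces accumulate=False so the result overwrites `main_grad` rather than being added to it. The snippet below is a minimal usage sketch of that behaviour, assuming a CUDA device and a caller-managed fp32 `main_grad` buffer in the style of Megatron-LM; the module choice (te.Linear), shapes, and dtypes are illustrative and not part of this patch.

    import torch
    import transformer_engine.pytorch as te

    # Linear layer whose fused wgrad GEMM targets weight.main_grad directly.
    linear = te.Linear(
        256,
        256,
        bias=False,
        params_dtype=torch.bfloat16,
        device="cuda",
        fuse_wgrad_accumulation=True,  # wgrad GEMM writes into weight.main_grad
    )

    # Caller-managed gradient buffer, required when fuse_wgrad_accumulation=True.
    linear.weight.main_grad = torch.zeros_like(linear.weight, dtype=torch.float32)

    # New behaviour from this patch: with the attribute set, the wgrad GEMM runs
    # with accumulate=False, so main_grad holds only the latest gradient.
    linear.weight.overwrite_main_grad = True

    x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    y = linear(x)
    y.sum().backward()  # weight.main_grad is overwritten, not accumulated into

The same attribute is honoured by GroupedLinear, LayerNormLinear, LayerNormMLP, and the fusible-ops backward paths touched in this diff.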