6 changes: 5 additions & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
@@ -402,7 +402,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 use_bias=ctx.use_bias if grad_biases[0] is None else None,
                 bias=biases,
                 use_split_accumulator=wgrad_gemm_use_split_accumulator,
-                accumulate=accumulate_wgrad_into_param_main_grad,
+                accumulate=(
+                    accumulate_wgrad_into_param_main_grad
+                    if not hasattr(weights[0], "__fsdp_param__")
+                    else False
+                ),
             )
             # WGRAD
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
6 changes: 5 additions & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -842,7 +842,11 @@ def backward(
                     main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.grad_weight_quantizer,
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not hasattr(weight, "__fsdp_param__")
+                    else False
+                ),
                 "layout": "NT",
                 "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
12 changes: 10 additions & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -925,7 +925,11 @@ def backward(
                     else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not hasattr(fc1_weight, "__fsdp_param__")
+                    else False
+                ),
                 "layout": "NT",
                 "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
@@ -1163,7 +1167,11 @@ def fc2_wgrad_gemm(
                     else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.fc1_grad_weight_quantizer,
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not hasattr(fc2_weight, "__fsdp_param__")
+                    else False
+                ),
                 "layout": "NT",
                 "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
6 changes: 5 additions & 1 deletion transformer_engine/pytorch/module/linear.py
@@ -833,7 +833,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                     main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.grad_weight_quantizer,
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not hasattr(weight, "__fsdp_param__")
+                    else False
+                ),
                 "layout": "NT",
                 "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
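Across these four modules the change is identical: the wgrad GEMM's accumulate flag keeps the existing accumulate_wgrad_into_param_main_grad value unless the weight carries the __fsdp_param__ marker, in which case it is forced to False. A minimal sketch of that selection logic follows; the helper name resolve_wgrad_accumulate is ours for illustration, not part of TE.

# Illustrative sketch, not part of the PR: how the "accumulate" flag is chosen.
import torch


def resolve_wgrad_accumulate(weight, accumulate_wgrad_into_param_main_grad: bool) -> bool:
    """Return the accumulate flag passed to the wgrad GEMM.

    Parameters marked with ``__fsdp_param__`` never accumulate into main_grad;
    everything else keeps the caller's setting.
    """
    if hasattr(weight, "__fsdp_param__"):
        return False
    return accumulate_wgrad_into_param_main_grad


# A plain parameter keeps the caller's choice; an FSDP-marked one does not.
w = torch.nn.Parameter(torch.zeros(4, 4))
assert resolve_wgrad_accumulate(w, True) is True
w.__fsdp_param__ = True  # attribute set by the FSDP integration in the real code
assert resolve_wgrad_accumulate(w, True) is False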
1 change: 1 addition & 0 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -997,6 +997,7 @@ def op_backward(
         weight_param = self.weight
         if hasattr(weight_param, "__fsdp_param__"):
             weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = False
         if not hasattr(weight_param, "main_grad"):
             raise RuntimeError(
                 "BasicLinear op is configured with "
@@ -59,6 +59,7 @@ def fuser_backward(
         weight_param = linear_op.weight
         if hasattr(weight_param, "__fsdp_param__"):
             weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = False
         if not hasattr(weight_param, "main_grad"):
             raise RuntimeError(
                 "BasicLinear op is configured with "
@@ -60,6 +60,7 @@ def fuser_backward(
         weight_param = linear_op.weight
         if hasattr(weight_param, "__fsdp_param__"):
             weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = False
         if not hasattr(weight_param, "main_grad"):
             raise RuntimeError(
                 "BasicLinear op is configured with "
@@ -523,6 +523,7 @@ def fuser_backward(
         weight_param = linear_op.weight
         if hasattr(weight_param, "__fsdp_param__"):
             weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = False
         if not hasattr(weight_param, "main_grad"):
             raise RuntimeError(
                 "BasicLinear op is configured with "
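The ops-path changes above (BasicLinear's op_backward and the three fuser_backward implementations) all add the same single line: once main_grad has been refreshed through get_main_grad(), accumulation into it is switched off. Below is a self-contained sketch of that control flow; prepare_main_grad and the patched get_main_grad lambda are illustrative stand-ins, not TE API.

# Illustrative sketch, not part of the PR: the backward-op handling of FSDP params.
import types

import torch


def prepare_main_grad(weight_param, accumulate_into_main_grad: bool):
    """Mirror the pattern added in the backward ops: refresh main_grad for
    FSDP-marked params and stop accumulating into it."""
    if hasattr(weight_param, "__fsdp_param__"):
        weight_param.main_grad = weight_param.get_main_grad()
        accumulate_into_main_grad = False
    if not hasattr(weight_param, "main_grad"):
        raise RuntimeError("weight is configured for main_grad accumulation but has no main_grad")
    return weight_param.main_grad, accumulate_into_main_grad


# Usage with a stand-in FSDP-marked parameter (attributes patched on for the demo).
w = torch.nn.Parameter(torch.zeros(8, 8))
w.__fsdp_param__ = True
w.get_main_grad = types.MethodType(
    lambda self: torch.zeros_like(self, dtype=torch.float32), w
)
main_grad, accumulate = prepare_main_grad(w, accumulate_into_main_grad=True)
assert accumulate is False and main_grad.dtype == torch.float32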