
Commit ff316c9

[fix] fix weight grad None error caused by weight ptr change
1 parent f0a8d78 commit ff316c9

File tree

3 files changed: +31 −17 lines


colossalai/pipeline/weight_grad_store.py

Lines changed: 17 additions & 5 deletions
@@ -20,11 +20,23 @@ def pop(cls, chunk=0):
         if cls.weight_grad_queue[chunk].qsize() > 0:
             stored_grads = cls.weight_grad_queue[chunk].get()
             for total_input, grad_output, weight, func in stored_grads:
-                if weight.grad is not None:
-                    func(total_input, grad_output, weight.grad)
-                # for first bwd; weight.grad is None, assign grad_weight to weight.grad
+                if isinstance(weight, tuple):
+                    # To be hooked into Gemini's '__torch_function__', a view op was applied to weight and bias.
+                    # The view changes the weight ptr, so the stored entry carries a (weight_cal, weight_origin) tuple:
+                    # weight_cal is used to compute dw; weight_origin is the parameter that actually gets updated.
+                    weight_cal, weight_origin = weight
+                    if weight_origin.grad is not None:
+                        func(total_input, grad_output, weight_origin)
+                    # on the first bwd weight.grad is None, so assign grad_weight to weight.grad
+                    else:
+                        grad_weight = func(total_input, grad_output)
+                        weight_origin.grad = grad_weight
                 else:
-                    grad_weight = func(total_input, grad_output)
-                    weight.grad = grad_weight
+                    if weight.grad is not None:
+                        func(total_input, grad_output, weight.grad)
+                    # on the first bwd weight.grad is None, so assign grad_weight to weight.grad
+                    else:
+                        grad_weight = func(total_input, grad_output)
+                        weight.grad = grad_weight
         else:
             raise Exception("Pop empty queue.")
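
Why the (weight_cal, weight_origin) tuple is needed at all: weight.view(weight.shape) returns a new tensor object, and a gradient assigned to that view never reaches the original parameter. A minimal standalone PyTorch sketch (illustration only, not ColossalAI code) that reproduces the symptom:

import torch

weight = torch.nn.Parameter(torch.randn(4, 4))

# Gemini's '__torch_function__' hook path views the weight. The view shares
# storage with the Parameter but is a brand-new Tensor object with its own
# .grad slot -- the "weight ptr change" named in the commit message.
weight_view = weight.view(weight.shape)
assert weight_view is not weight
assert weight_view.data_ptr() == weight.data_ptr()

# Writing the deferred dW onto the view leaves the Parameter untouched,
# so the optimizer later sees weight.grad == None.
weight_view.grad = torch.ones_like(weight)
print(weight.grad)  # None -- the symptom this commit fixes

# The fix keeps weight_origin next to the view and assigns the grad there.
weight.grad = torch.ones_like(weight)
print(weight.grad is not None)  # True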

colossalai/shardformer/layer/_operation.py

Lines changed: 12 additions & 9 deletions
@@ -96,6 +96,7 @@ def backward(ctx, grad_output):
         use_zbv = ctx.use_zbv
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
+        weight_origin = weight
         weight = weight.view(weight.shape)
         if bias is not None:
             bias = bias.view(bias.shape)
@@ -130,7 +131,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
@@ -141,7 +142,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
@@ -164,7 +165,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass,
                         wgrad_gemm_func=torch.matmul,
@@ -212,6 +213,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
             return wgrad_gemm_func(_input_.t(), _grad_output_)
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
+        weight_origin = weight
         weight = weight.view(weight.shape)
         if bias is not None:
             bias = bias.view(bias.shape)
@@ -232,7 +234,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
@@ -243,7 +245,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
@@ -266,7 +268,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass,
                         wgrad_gemm_func=torch.matmul,
@@ -1026,6 +1028,7 @@ def backward(ctx, grad_output):
         use_zbv = ctx.use_zbv
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
+        weight_origin = weight
         weight = weight.view(weight.shape)
         if use_bias:
             bias = bias.view(bias.shape)
@@ -1064,7 +1067,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
@@ -1075,7 +1078,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
@@ -1098,7 +1101,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass,
                         wgrad_gemm_func=torch.matmul,
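
For context, the func stored by WeightGradStore.put is a functools.partial over one of the execute_w_pass* helpers; pop then calls it with either two positional arguments (first backward, returning a fresh dW) or three (accumulating into an existing grad). A condensed, hypothetical sketch of that contract -- only the two-argument return path appears verbatim in the hunks above; the accumulation branch is an assumption for illustration:

import functools

import torch

def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
    if _weight_main_grad_ is None:
        # first backward: hand a fresh dW back to the caller (pop assigns it)
        return wgrad_gemm_func(_input_.t(), _grad_output_)
    # later backwards: accumulate in place (assumed branch)
    _weight_main_grad_.add_(wgrad_gemm_func(_input_.t(), _grad_output_))

func = functools.partial(execute_w_pass, wgrad_gemm_func=torch.matmul)

total_input = torch.randn(8, 16)
grad_output = torch.randn(8, 32)
dw = func(total_input, grad_output)  # (16, 32): becomes weight_origin.grad
func(total_input, grad_output, dw)   # subsequent steps accumulate into it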

tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py

Lines changed: 2 additions & 3 deletions
@@ -185,11 +185,10 @@ def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mo
 
     # check the input gradients & weight gradients
     assert_close(out.grad, gather_out.grad)
-    # TODO:linear_base.weight.grad is None; But not none in WeightGradStore
-    # assert_close(linear.weight.grad, linear_base.weight.grad)
+    assert_close(linear.weight.grad, linear_base.weight.grad)
 
 
-@parameterize("lazy_init", [False, True])
+@parameterize("lazy_init", [False])
 @parameterize("seq_parallel_mode", ["split_gather", None])
 def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
     check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
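
With the pointer fix in place, the previously skipped comparison can be re-enabled: the deferred dW computed through WeightGradStore now lands on the parameter itself, so both sides of the assertion are populated. A trivial sketch of the assertion semantics (hypothetical tensors, not the test fixtures):

import torch
from torch.testing import assert_close

# assert_close raises if the two sides differ in type (e.g. one is None,
# as before this commit) or diverge beyond dtype-based tolerances.
grad_a = torch.full((4, 4), 0.5)
grad_b = grad_a.clone()
assert_close(grad_a, grad_b)  # passes: same values within tolerance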
