@@ -96,6 +96,7 @@ def backward(ctx, grad_output):
         use_zbv = ctx.use_zbv
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
+        weight_origin = weight
         weight = weight.view(weight.shape)
         if bias is not None:
             bias = bias.view(bias.shape)
@@ -130,7 +131,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
@@ -141,7 +142,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
@@ -164,7 +165,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass,
                         wgrad_gemm_func=torch.matmul,
@@ -212,6 +213,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
             return wgrad_gemm_func(_input_.t(), _grad_output_)
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
+        weight_origin = weight
         weight = weight.view(weight.shape)
         if bias is not None:
             bias = bias.view(bias.shape)
@@ -232,7 +234,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
@@ -243,7 +245,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
@@ -266,7 +268,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass,
                         wgrad_gemm_func=torch.matmul,
@@ -1026,6 +1028,7 @@ def backward(ctx, grad_output):
         use_zbv = ctx.use_zbv
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
+        weight_origin = weight
         weight = weight.view(weight.shape)
         if use_bias:
             bias = bias.view(bias.shape)
@@ -1064,7 +1067,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
@@ -1075,7 +1078,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
@@ -1098,7 +1101,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass,
                         wgrad_gemm_func=torch.matmul,