@@ -47,6 +47,18 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
     return min - torch.log1p(z), buffer


+def recompute_mean_var(input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool):
+    # For most norm decompositions, this is the only place they differ from the core
+    # version: we recompute the mean and variance so that they track gradients through input.
+
+    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
+    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
+    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
+    eps = eps.detach()
+    rstd = 1 / torch.sqrt(var + eps)
+    return mean, rstd
+
+
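# Example (not part of this patch): a quick sanity check of the eps-recovery trick
# above, on made-up shapes. The saved rstd folds eps into 1/sqrt(var + eps), so eps
# can be recovered as 1/rstd**2 - var; rebuilding rstd from the freshly computed var
# keeps the same value while making it differentiable with respect to input.
import torch

x = torch.randn(4, 8, requires_grad=True)
var0 = torch.var(x, dim=[-1], unbiased=False, keepdim=True)
rstd0 = (1 / torch.sqrt(var0 + 1e-5)).detach()  # what the forward pass would have saved
mean, rstd = recompute_mean_var(x, rstd0, [-1], keepdim=True)
assert torch.allclose(rstd, rstd0)  # same value...
assert rstd.requires_grad           # ...but now tracks gradients through x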
 @register_decomposition_for_jvp(aten.native_layer_norm_backward)
 def native_layer_norm_backward(
     grad_out: Tensor,
@@ -80,13 +92,7 @@ def native_layer_norm_backward(
             input.new_zeros(input_shape[axis:]),
         )

-    # this is exactly the same as the other decomposition except for here. We recompute the mean and variance
-    # so that they track gradients through input
-    mean_ = torch.mean(input, dim=inner_dim_indices, keepdim=True)
-    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=True)
-    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
-    eps = eps.detach()
-    rstd_ = 1 / torch.sqrt(var + eps)
+    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

     x_hat = (input - mean_) * rstd_
     if weight is not None:
@@ -128,3 +134,84 @@ def native_layer_norm_backward(
         d_bias = torch.zeros(())  # should be None but doesn't work with vjp

     return (d_input, d_weight, d_bias)
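# Example (not part of this patch): why the recomputation matters. Forward-over-
# reverse second-order AD runs jvp through this backward decomposition; if mean and
# rstd stayed detached, their tangent contribution would be silently dropped. Shapes
# are made up; this assumes functorch's top-level grad/jvp API.
import torch
from functorch import grad, jvp


def f(x):
    return torch.nn.functional.layer_norm(x, (3,)).sum()


x, v = torch.randn(2, 3), torch.randn(2, 3)
_, hvp = jvp(grad(f), (x,), (v,))  # Hessian-vector product; exercises this decomposition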
+
+
+def prod(x: List[int]):
+    r = 1
+    for i in x:
+        r *= i
+    return r
+
+
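# Note (not part of this patch): prod matches math.prod from the standard library
# (Python 3.8+), e.g. prod([2, 3, 4]) == 24; presumably it is spelled out here to
# keep the decomposition TorchScript-friendly with the List[int] annotation.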
+@register_decomposition(aten.native_batch_norm_backward)  # switch to @register_decomposition_for_jvp once this lands in core
+def native_batch_norm_backward(
+    grad_out: Tensor,
+    input: Tensor,
+    weight: Optional[Tensor],
+    running_mean: Optional[Tensor],
+    running_var: Optional[Tensor],
+    save_mean: Optional[Tensor],
+    save_invstd: Optional[Tensor],
+    train: bool,
+    eps: float,
+    output_mask: List[bool],
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+
+        reduction_dims = [0] + list(range(2, input.dim()))
+        assert invstd is not None  # for typing
+        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
+    else:
+        assert running_mean is not None and running_var is not None
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes: List[int] = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    # per-channel statistics of grad_out, broadcast back to the input shape
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+
+    if train:
+        # in training, mean/var depend on the batch, so subtract their contributions
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+    elif weight is not None:
+        grad_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
+    else:
+        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp
+
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    else:
+        grad_bias = torch.zeros_like(grad_output_sum)  # should be None but doesn't work with vjp
+
+    return (grad_input, grad_weight, grad_bias)
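# Example (not part of this patch): cross-checking the decomposition against the
# eager kernel on made-up shapes, training mode, all outputs requested. This assumes
# the decorated native_batch_norm_backward above is still directly callable.
import torch

x = torch.randn(2, 3, 4)
weight = torch.randn(3)
save_mean = x.mean(dim=[0, 2])
save_invstd = 1 / torch.sqrt(x.var(dim=[0, 2], unbiased=False) + 1e-5)
go = torch.randn_like(x)
args = (go, x, weight, None, None, save_mean, save_invstd, True, 1e-5, [True, True, True])
ref = torch.ops.aten.native_batch_norm_backward(*args)
out = native_batch_norm_backward(*args)
for r, o in zip(ref, out):
    assert torch.allclose(r, o, rtol=1e-4, atol=1e-5)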