@@ -497,6 +497,61 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
     return (d_input, d_weight, d_bias)
 
 
+@register_decomposition(aten.native_batch_norm_backward)
+def native_batch_norm_backward(grad_out: Tensor, input: Tensor, weight: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_invstd: Optional[Tensor], train: bool, eps: float, output_mask: List[bool]) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+    else:
+        assert running_mean is not None and running_var is not None
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    grad_scale = None
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+    grad_input = None
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
+    grad_weight = None
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+
+    grad_bias = None
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    return (grad_input, grad_weight, grad_bias)
+
+
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
     return torch.clamp(self, min=min)
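
As a quick sanity check (not part of the diff), the new decomposition can be compared against the gradients autograd derives through the ATen batch-norm kernel. This is a minimal sketch, assuming the native_batch_norm_backward defined above is importable from the decompositions module it was added to; double precision keeps the comparison tolerance tight:

import torch

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8, dtype=torch.double, requires_grad=True)
weight = torch.randn(3, dtype=torch.double, requires_grad=True)
bias = torch.randn(3, dtype=torch.double, requires_grad=True)
running_mean = torch.zeros(3, dtype=torch.double)
running_var = torch.ones(3, dtype=torch.double)

# Forward through the ATen kernel; save_mean/save_invstd feed the backward.
out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
)
grad_out = torch.randn_like(out)
ref = torch.autograd.grad(out, (x, weight, bias), grad_out)

# Run the decomposition on the same saved statistics and compare.
got = native_batch_norm_backward(
    grad_out, x.detach(), weight.detach(), running_mean, running_var,
    save_mean, save_invstd, True, 1e-5, [True, True, True]
)
for r, g in zip(ref, got):
    torch.testing.assert_close(r, g)

Note that running_mean and running_var are passed even though train=True uses the saved batch statistics: the decomposition only falls back to the running statistics (recomputing invstd from running_var and eps) in eval mode.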