Commit aed7c7e
committed
Update on "[WIP] Compute forward grads for saved_mean and saved_var when input requires grad"
We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.
Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`
Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch
[ghstack-poisoned]1 file changed
+8
-7
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
25 | | - | |
26 | | - | |
27 | | - | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
28 | 28 | | |
29 | | - | |
30 | | - | |
| 29 | + | |
| 30 | + | |
31 | 31 | | |
32 | | - | |
33 | | - | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
34 | 35 | | |
35 | 36 | | |
36 | 37 | | |
| |||
0 commit comments