@@ -53,6 +53,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 
 def create_joint_forward_backward(fn):
+    # tangents are just grad_outs/cotangents (wrong naming)
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
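For context, joint_forward_backward above is the function that later gets traced with make_fx: it takes the forward inputs (primals) together with the incoming grad_outputs (named tangents, hence the comment about the misnomer) and returns both the forward outputs and the input gradients. A rough standalone sketch of that shape, not the functorch implementation (sketch_joint_forward_backward and joint are made-up names, and it assumes floating-point tensor primals):

```python
from typing import Any, List, Tuple

import torch


def sketch_joint_forward_backward(fn):
    # Hypothetical helper in the spirit of joint_forward_backward: the
    # "tangents" here are really the grad_outputs (cotangents) autograd
    # would pass to backward, not forward-mode tangents.
    def joint(primals: List[Any], tangents: List[Any]) -> Tuple[List[Any], List[Any]]:
        primals = [p.detach().requires_grad_(True) for p in primals]
        outs = fn(*primals)
        outs = list(outs) if isinstance(outs, (list, tuple)) else [outs]
        # VJP: seed the backward pass with the incoming cotangents.
        grads = torch.autograd.grad(outs, primals, grad_outputs=tangents)
        return outs, list(grads)

    return joint


joint = sketch_joint_forward_backward(torch.sin)
x = torch.randn(3)
outs, grads = joint([x], [torch.ones(3)])
assert torch.allclose(grads[0], torch.cos(x))
```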
@@ -140,12 +141,14 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
-
+    joint_inputs = None
+    fw_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs, joint_inputs, fw_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,29 +162,34 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
             ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
+        def backward(ctx, *flat_grad_outs):
+            nonlocal compiled_bw
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                with torch.set_grad_enabled(grad_state):
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+                compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
             out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
             return tuple(out)
 
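The substantive change is in when the backward gets compiled: with default_partition it is still compiled eagerly inside forward(), but for any other partitioner the joint graph is re-traced and the backward compiled lazily on the first backward() call, once the real (contiguous) grad_outputs exist. A minimal standalone sketch of that compile-on-first-backward pattern, with made-up names (LazySin, fake_bw_compiler) and an identity "compiler" standing in for bw_compiler:

```python
import torch


def fake_bw_compiler(bw_fn, example_args):
    # Stand-in for bw_compiler: a real compiler would lower bw_fn against
    # the example args; here we just return it unchanged.
    return bw_fn


class LazySin(torch.autograd.Function):
    # sin() whose backward is "compiled" only on its first invocation,
    # mirroring the `if compiled_bw is None:` branch added in backward().
    compiled_bw = None

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad_out = grad_out.contiguous()
        if LazySin.compiled_bw is None:
            # Deferred until a real grad_out exists, so the backward can be
            # specialized on the actual cotangent it will receive.
            LazySin.compiled_bw = fake_bw_compiler(
                lambda g, saved: g * torch.cos(saved), (grad_out, x)
            )
        return LazySin.compiled_bw(grad_out, x)


x = torch.randn(4, requires_grad=True)
LazySin.apply(x).sum().backward()
assert torch.allclose(x.grad, torch.cos(x.detach()))
```

Deferring the compile this way lets the backward be built against the cotangents it will actually receive, which is presumably why the non-default partitioners need the re-trace inside backward() rather than reusing the eager path in forward().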