@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
5454
5555def create_joint_forward_backward (fn ):
5656 def joint_forward_backward (
57- primals : List [Any ], tangents : List [Any ]
57+ primals : List [Any ], cotangents : List [Any ]
5858 ) -> Tuple [List [Any ], List [Any ]]:
5959 # Call the forward pass
6060 outs = fn (* primals )
@@ -68,20 +68,20 @@ def joint_forward_backward(
6868 grad_primals .append (p )
6969
7070 # Get the outputs that need gradients
71- assert len (tangents ) == len (outs )
71+ assert len (cotangents ) == len (outs )
7272 needed_outs = []
73- needed_tangents = []
74- for out , tangent in zip (outs , tangents ):
73+ needed_cotangents = []
74+ for out , cotangent in zip (outs , cotangents ):
7575 if isinstance (out , Tensor ) and out .requires_grad :
7676 needed_outs .append (out )
77- needed_tangents .append (tangent )
77+ needed_cotangents .append (cotangent )
7878 backward_out = []
7979 # Call the backwards pass
8080 if grad_primals :
8181 backward_out = torch .autograd .grad (
8282 needed_outs ,
8383 grad_primals ,
84- grad_outputs = needed_tangents ,
84+ grad_outputs = needed_cotangents ,
8585 allow_unused = True ,
8686 )
8787 backward_out_iter = iter (backward_out )
@@ -140,12 +140,14 @@ def create_aot_autograd_function(
140140 compiled_fw = None
141141 compiled_bw = None
142142 num_outs = None
143-
143+ joint_inputs = None
144+ fw_outs = None
145+ aot_decompositions = {** aot_autograd_decompositions , ** decompositions }
144146 class CompiledFunction (torch .autograd .Function ):
145147 @staticmethod
146148 @disable_torchdynamo
147149 def forward (ctx , * flat_tensor_args ):
148- nonlocal compiled_fw , compiled_bw , num_outs
150+ nonlocal compiled_fw , num_outs , joint_inputs , fw_outs
149151 if compiled_fw is None :
150152 with torch .set_grad_enabled (grad_state ):
151153 out = flat_fn (* flat_tensor_args )
@@ -159,29 +161,34 @@ def forward(ctx, *flat_tensor_args):
159161 num_outs = 1
160162
161163 joint_inputs = (flat_tensor_args , out )
162- aot_decompositions = { ** aot_autograd_decompositions , ** decompositions }
164+ # Need it because autograd.Function disables grad in forward
163165 with torch .set_grad_enabled (grad_state ):
164166 fx_g = make_fx (joint_forward_backward , aot_decompositions )(
165167 * joint_inputs
166168 )
167169 fw_module , bw_module = partition_fn (fx_g , joint_inputs )
168- # print(fw_module.code, bw_module.code)
169170
170171 compiled_fw = fw_compiler (fw_module , flat_tensor_args )
171172 fw_outs = normalize_as_list (compiled_fw (* flat_tensor_args ))
172-
173- bw_args = fw_outs [num_outs :] + fw_outs [0 :num_outs ]
174- compiled_bw = bw_compiler (bw_module , bw_args )
173+ if partition_fn is default_partition :
174+ nonlocal compiled_bw
175+ bw_args = fw_outs [num_outs :] + fw_outs [0 :num_outs ]
176+ compiled_bw = bw_compiler (bw_module , bw_args )
175177 else :
176178 fw_outs = normalize_as_list (compiled_fw (* flat_tensor_args ))
177179 ctx .save_for_backward (* fw_outs [num_outs :])
178180 return tuple (fw_outs [0 :num_outs ])
179181
180182 @staticmethod
181183 @disable_torchdynamo
182- def backward (ctx , * flat_args ):
183- contiguous_args = [t .contiguous () for t in flat_args ]
184- # contiguous_args = [t for t in flat_args]
184+ def backward (ctx , * flat_grad_outs ):
185+ nonlocal compiled_bw
186+ contiguous_args = [t .contiguous () for t in flat_grad_outs ]
187+ if compiled_bw is None :
188+ with torch .set_grad_enabled (grad_state ):
189+ fx_g = make_fx (joint_forward_backward , aot_decompositions )(joint_inputs [0 ], contiguous_args )
190+ fw_module , bw_module = partition_fn (fx_g , joint_inputs )
191+ compiled_bw = bw_compiler (bw_module , fw_outs [num_outs :] + contiguous_args )
185192 out = normalize_as_list (compiled_bw (* ctx .saved_tensors , * contiguous_args ))
186193 return tuple (out )
187194
0 commit comments