11import torch
22import torch .nn as nn
3- from torch import Tensor
3+ from torch import Tensor , is_grad_enabled
44from functorch import make_fx
55from torch .fx import immutable_collections
66import torch .utils ._pytree as pytree
77import torch .utils .dlpack
88from torch .nn .utils import _stateless
99from functorch ._C import CompileCache
1010from .decompositions import register_decomposition
11- from .partitioners import default_partition
11+ from .partitioners import default_partition , _get_saved_values , _extract_fwd_bwd_modules , _extract_fwd_bwd_modules_db
1212from .named_members_polyfill import _named_parameters , _named_buffers
1313from typing import Callable , List , Dict , Any , Tuple , Optional
1414from functools import wraps
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
5454
5555def create_joint_forward_backward (fn ):
5656 def joint_forward_backward (
57- primals : List [Any ], tangents : List [Any ]
57+ primals : List [Any ], cotangents : List [Any ]
5858 ) -> Tuple [List [Any ], List [Any ]]:
5959 # Call the forward pass
6060 outs = fn (* primals )
@@ -68,21 +68,21 @@ def joint_forward_backward(
6868 grad_primals .append (p )
6969
7070 # Get the outputs that need gradients
71- assert len (tangents ) == len (outs )
71+ assert len (cotangents ) == len (outs )
7272 needed_outs = []
73- needed_tangents = []
74- for out , tangent in zip (outs , tangents ):
73+ needed_cotangents = []
74+ for out , cotangent in zip (outs , cotangents ):
7575 if isinstance (out , Tensor ) and out .requires_grad :
7676 needed_outs .append (out )
77- needed_tangents .append (tangent )
77+ needed_cotangents .append (cotangent )
7878 backward_out = []
7979 # Call the backwards pass
8080 if grad_primals :
8181 backward_out = torch .autograd .grad (
8282 needed_outs ,
8383 grad_primals ,
84- grad_outputs = needed_tangents ,
85- allow_unused = True ,
84+ grad_outputs = needed_cotangents ,
85+ allow_unused = True
8686 )
8787 backward_out_iter = iter (backward_out )
8888 return outs , [
@@ -138,14 +138,18 @@ def create_aot_autograd_function(
138138 joint_forward_backward = create_joint_forward_backward (flat_fn )
139139
140140 compiled_fw = None
141- compiled_bw = None
141+ bw_modules = []
142+ fw_module = None
142143 num_outs = None
144+ saved_value_names = None
145+ aot_decompositions = {** aot_autograd_decompositions , ** decompositions }
143146
144147 class CompiledFunction (torch .autograd .Function ):
145148 @staticmethod
146149 @disable_torchdynamo
147150 def forward (ctx , * flat_tensor_args ):
148- nonlocal compiled_fw , compiled_bw , num_outs
151+ ctx .set_materialize_grads (False )
152+ nonlocal compiled_fw , num_outs , saved_value_names , fw_module
149153 if compiled_fw is None :
150154 with torch .set_grad_enabled (grad_state ):
151155 out = flat_fn (* flat_tensor_args )
@@ -159,34 +163,101 @@ def forward(ctx, *flat_tensor_args):
159163 num_outs = 1
160164
161165 joint_inputs = (flat_tensor_args , out )
162- aot_decompositions = { ** aot_autograd_decompositions , ** decompositions }
166+ # Need it because autograd.Function disables grad in forward
163167 with torch .set_grad_enabled (grad_state ):
164168 fx_g = make_fx (joint_forward_backward , aot_decompositions )(
165169 * joint_inputs
166170 )
167- fw_module , bw_module = partition_fn (fx_g , joint_inputs )
168- # print(fw_module.code, bw_module.code)
169-
171+ # This means the forward and backward graphs are created based on the input fn
172+ # However we need to take in grad_out for the saved intermediates as well.
173+ fw_module , bw_module , saved_value_nodes = partition_fn (fx_g , joint_inputs )
174+ saved_value_names = [node .name for node in saved_value_nodes ]
170175 compiled_fw = fw_compiler (fw_module , flat_tensor_args )
171176 fw_outs = normalize_as_list (compiled_fw (* flat_tensor_args ))
172-
173- bw_args = fw_outs [num_outs :] + fw_outs [0 :num_outs ]
174- compiled_bw = bw_compiler (bw_module , bw_args )
175177 else :
176178 fw_outs = normalize_as_list (compiled_fw (* flat_tensor_args ))
177- ctx .save_for_backward (* fw_outs [num_outs :])
178- return tuple (fw_outs [0 :num_outs ])
179+
180+ # print(fw_module.code)
181+ ctx .num_intermediate = len (fw_outs [num_outs :])
182+ ctx .num_inputs = len (flat_tensor_args )
183+ to_be_saved = fw_outs [num_outs :] + list (flat_tensor_args ) + fw_outs [0 :num_outs ]
184+ ctx .save_for_backward (* to_be_saved )
185+ return tuple (fw_outs )
179186
180187 @staticmethod
181188 @disable_torchdynamo
182- def backward (ctx , * flat_args ):
183- contiguous_args = [t .contiguous () for t in flat_args ]
184- # contiguous_args = [t for t in flat_args]
185- out = normalize_as_list (compiled_bw (* ctx .saved_tensors , * contiguous_args ))
186- return tuple (out )
187-
188- return CompiledFunction
189-
def backward(ctx, *flat_grad_outs):
    """Compute input gradients by re-tracing the joint graph for this call.

    The joint forward/backward function is traced again with the saved
    inputs and the gradients that were actually supplied, split into
    forward and backward modules using the saved-value names recorded by
    ``forward``, and the backward module is then compiled through
    ``aot_function`` and executed.

    Args:
        ctx: autograd context populated by ``forward`` (``num_intermediate``,
            ``num_inputs`` and the saved tensors).
        *flat_grad_outs: one gradient per value returned by ``forward``;
            entries may be ``None`` because ``forward`` calls
            ``ctx.set_materialize_grads(False)``.

    Returns:
        Tuple with the results of the compiled backward module.

    Raises:
        NotImplementedError: if called while grad mode is enabled (for
            example under ``create_graph=True``); that path is not
            implemented yet.
    """
    if torch.is_grad_enabled():
        # Previously this fell through with the traced modules never bound
        # and failed with an UnboundLocalError further down; fail clearly.
        # TODO: support higher-order gradients through the compiled backward.
        raise NotImplementedError(
            "AOT autograd backward was called with grad mode enabled "
            "(e.g. create_graph=True); double backward is not supported yet"
        )

    # forward saved: intermediates, then the flat inputs, then the outputs.
    saved = ctx.saved_tensors
    input_end = ctx.num_intermediate + ctx.num_inputs
    intermediates = saved[:ctx.num_intermediate]
    inputs = saved[ctx.num_intermediate:input_end]

    # Gradients are not materialized, so missing ones arrive as None.
    input_flat_grad_outs = [grad for grad in flat_grad_outs if grad is not None]

    with torch.set_grad_enabled(grad_state):
        fx_g_b = make_fx(joint_forward_backward, aot_decompositions)(
            inputs, input_flat_grad_outs
        )
    saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
    assert len(saved_value_nodes) <= len(saved_value_names)
    _, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(
        fx_g_b, saved_value_nodes
    )

    if len(saved_values_new) != len(saved_value_names):
        # Forward saves more intermediates than needed
        assert len(saved_values_new) < len(saved_value_names)
        # Both sequences are walked in the same order, so a single forward
        # scan over the saved names locates each node that is still needed.
        new_intermediates = []
        j = 0
        for node in saved_values_new:
            while node.name != saved_value_names[j]:
                j += 1
            new_intermediates.append(intermediates[j])
            j += 1
        intermediates = new_intermediates

    # This is needed because aot function caching uses function id right now:
    # reuse an earlier module object whose generated code is identical.
    bw_module_fn = next(
        (elem for elem in bw_modules if elem.code == bw_module_b.code), None
    )
    if bw_module_fn is None:
        bw_modules.append(bw_module_b)
        bw_module_fn = bw_module_b

    f = aot_function(
        bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions
    )
    out = f(*intermediates, *input_flat_grad_outs)
    return tuple(normalize_as_list(out))
256+
def return_fn(*args, **kwargs):
    """Run ``CompiledFunction`` and return only its first ``num_outs`` results.

    ``CompiledFunction.forward`` returns every forward-graph output (the
    real outputs followed by the saved intermediates), so the trailing
    entries are sliced off before handing the result back to the caller.
    """
    all_outs = CompiledFunction.apply(*args, **kwargs)
    return all_outs[:num_outs]
260+ return return_fn
190261
class _CompileCache(CompileCache):
    """Subclass of ``CompileCache`` that adds no behavior of its own."""
@@ -275,7 +346,7 @@ def rearrange(tensor_args, static_args, static_argnums):
275346 return args
276347
277348
278- KNOWN_TYPES = [torch .Tensor , int , str , float , bool ]
# Types that each flattened output is checked against with ``isinstance``.
# ``type(None)`` is listed rather than the ``None`` object itself: ``None`` is
# not a class, so it cannot be passed to ``isinstance``, and special-casing it
# in the checker would make every output count as a known type.
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, type(None)]
279350
280351
281352def aot_function (
@@ -411,7 +482,9 @@ def returned_function(*args, **kwargs):
411482 hasher_type ,
412483 * flat_args_for_cache ,
413484 )
414-
485+ # print("fn_id: ", fn_id)
486+ # print("size: ", compile_cache.size())
487+ # print("num_tensor_args: ", num_tensor_args)
415488 # Compile the function and save it in the cache
416489 if cached_res is None :
417490 # Save the args_spec for flat_tensor_args to unflatten while tracing
@@ -436,7 +509,7 @@ def flat_fn(*flat_tensor_args):
436509 for i in flat_out :
437510 is_known_type = False
438511 for j in KNOWN_TYPES :
439- if isinstance (i , j ):
512+ if j is None or isinstance (i , j ):
440513 is_known_type = True
441514 break
442515 if not is_known_type :
@@ -458,7 +531,7 @@ def flat_fn(*flat_tensor_args):
458531 partition_fn ,
459532 decompositions ,
460533 grad_state = torch .is_grad_enabled (),
461- ). apply
534+ )
462535 cached_res = (compiled_fn , out_spec )
463536
464537 # Save the compiled_fn in the cache
@@ -598,7 +671,7 @@ def aot_function_simplified(
598671 partition_fn ,
599672 decompositions ,
600673 grad_state = torch .is_grad_enabled (),
601- ). apply
674+ )
602675
603676 return compiled_fn
604677
@@ -620,4 +693,4 @@ def forward(self, *args, **kwargs):
620693
621694
622695compiled_function = aot_function
623- compiled_module = aot_module
696+ compiled_module = aot_module
0 commit comments