diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index ddb6cab3f0..811e486145 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -34,7 +34,7 @@
     unvariableify,
     FutureTensorProxy,
 )
-from thunder.core.compile_data import get_compile_data, get_compile_option
+from thunder.core.compile_data import get_compile_data, get_compile_option, get_compile_stats
 from thunder.core.langctxs import langctx, Languages
 from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass
 from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol
@@ -2971,7 +2971,7 @@ def unpacking_fn(saved_for_backward, cotangents):
 # )
 
 
-def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> ForwardBackwardTraces:
+def forward_and_backward_from_trace(trace: Trace, torch_autograd=False, requires_grad_mask=None) -> ForwardBackwardTraces:
     """Generates the forward and backward passes from a trace.
 
     This is a convenience function that combines the functionality of
@@ -3024,6 +3024,7 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa
         ...   t2 = prims.mul(cotangents, t1)  # t2: "cpu f32[]"
         ...   return (t2,))
     """
+    from thunder.transforms.torch_autograd import connect_to_torch_autograd
 
     # forward_trace, result, env = augmented_forward_pass_trace(trace, *trace.args, **trace.kwargs)
     # forward_trace.tags.add(TraceTag.AUGMENTED_FORWARD)
@@ -3114,6 +3115,18 @@ def backward_fn(saved_for_backward, cotangents):
         backward_fn, saved_for_backward, cotangents
     )
 
+    if torch_autograd and requires_grad_mask is not None:
+        # Update the backward trace to only compute gradients for the
+        # inputs that require gradients
+        assert backward_trace.bound_symbols[-1].sym.id == prims.PrimIDs.RETURN
+        filtered_grads = tuple(
+            (arg_grad if requires_grad else None)
+            for arg_grad, requires_grad in utils.safe_zip(backward_trace.bound_symbols[-1].args[0], requires_grad_mask)
+        )
+
+        # autograd.Function.backward expects a flat tuple of gradients
+        backward_trace.bound_symbols[-1] = replace(backward_trace.bound_symbols[-1], args=(filtered_grads,))
+
     # We are done with constructing the forward and backward passes at this
     # stage. The following is not strictly necessary, but it's good to filter
     # out the unused elements of the saved_for_backward and flatten it for more
@@ -3159,6 +3172,34 @@ def backward_fn(saved_for_backward, cotangents):
     if enable_saved_for_backward_recomputation:
         forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace)
 
+    if torch_autograd:
+        # Now let's update the forward trace to use the
+        # connect_to_torch_autograd that will store the backward trace and other
+        # information needed to connect to PyTorch Autograd
+        data_for_autograd, (saved_tensors, saved_other) = forward_trace.output
+
+        with tracectx(forward_trace):
+            forward_trace.scopes = [forward_trace.bound_symbols]
+            return_bsym = forward_trace.bound_symbols.pop()
+            assert return_bsym.sym.id == prims.PrimIDs.RETURN
+            new_flat_output = connect_to_torch_autograd(
+                backward=backward_trace,
+                return_none_instead_of_grads=get_compile_data().return_none_instead_of_grads,
+                saved_tensors=saved_tensors,
+                saved_other=saved_other,
+                flat_args=data_for_autograd["flat_args"],
+                flat_output=data_for_autograd["flat_tensor_output"],
+            )
+            old_to_new_output = {
+                variableify(old): new
+                for old, new in utils.safe_zip(data_for_autograd["flat_tensor_output"], new_flat_output)
+            }
+            new_output = tree_map(lambda out: old_to_new_output.get(variableify(out), out), data_for_autograd["output"])
+            prims.python_return({"output": new_output, "flat_args": data_for_autograd["flat_args"]})
+
+        get_compile_stats().last_backward_traces.append(backward_trace)
+        return forward_trace
+
     return ForwardBackwardTraces(forward_trace, backward_trace)
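Note on the `requires_grad_mask` filtering above: the idea is simply to null out the gradient slots of inputs that do not require grad, so the tuple handed to `autograd.Function.backward` stays aligned with the forward inputs. A minimal standalone sketch of that idea (plain Python, hypothetical `mask_grads` helper, no Thunder imports):

```python
# Illustration only: mirrors the filtered_grads logic in the diff above,
# but outside of Thunder's trace machinery.
def mask_grads(arg_grads, requires_grad_mask):
    # One gradient slot per forward input; drop grads for inputs that
    # do not require them by replacing the slot with None.
    assert len(arg_grads) == len(requires_grad_mask)
    return tuple(g if needs_grad else None for g, needs_grad in zip(arg_grads, requires_grad_mask))

print(mask_grads(("dx", "dy", "dz"), (True, False, True)))  # ('dx', None, 'dz')
```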
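For context on what the rewired forward trace plugs into, here is a rough plain-PyTorch analogue of the pattern: a custom `torch.autograd.Function` that saves tensors in forward and returns `None` for inputs that need no gradient. This is a sketch under those assumptions, not the actual `connect_to_torch_autograd` implementation:

```python
import torch

class MulConnected(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # Analogous to stashing saved_tensors for the backward trace.
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # Return None for inputs that do not require grad, mirroring the
        # requires_grad_mask filtering in the backward trace above.
        grad_a = grad_out * b if ctx.needs_input_grad[0] else None
        grad_b = grad_out * a if ctx.needs_input_grad[1] else None
        return grad_a, grad_b

x = torch.randn(3, requires_grad=True)
y = torch.randn(3)  # does not require grad
MulConnected.apply(x, y).sum().backward()
print(x.grad)  # populated; y.grad stays None
```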