
Commit

Move augmented forward trace post processing of output to forward_and_backward_from_trace
IvanYashchuk committed Oct 18, 2024
1 parent f82300c commit 03349ac
Showing 1 changed file with 43 additions and 2 deletions.
45 changes: 43 additions & 2 deletions thunder/core/transforms.py
@@ -34,7 +34,7 @@
     unvariableify,
     FutureTensorProxy,
 )
-from thunder.core.compile_data import get_compile_data, get_compile_option
+from thunder.core.compile_data import get_compile_data, get_compile_option, get_compile_stats
 from thunder.core.langctxs import langctx, Languages
 from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass
 from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol
@@ -2971,7 +2971,7 @@ def unpacking_fn(saved_for_backward, cotangents):
 # )
 
 
-def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> ForwardBackwardTraces:
+def forward_and_backward_from_trace(trace: Trace, torch_autograd=False, requires_grad_mask=None) -> ForwardBackwardTraces:
     """Generates the forward and backward passes from a trace.
 
     This is a convenience function that combines the functionality of
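
The new requires_grad_mask parameter is a per-argument sequence of booleans aligned with the flattened inputs of the trace; it is only consulted on the torch_autograd path handled further down in this diff. A minimal caller-side sketch of how such a mask could be built, assuming TensorProxy-like inputs that expose .requires_grad (this helper code is illustrative, not part of the commit):

    # Hypothetical sketch: one boolean per flattened input, True where a gradient
    # should be produced; non-tensor inputs fall back to False.
    flat_args, _ = tree_flatten((trace.args, trace.kwargs))
    requires_grad_mask = tuple(getattr(arg, "requires_grad", False) for arg in flat_args)
    # Passed alongside torch_autograd=True; see the last hunk for the return convention.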
@@ -3024,6 +3024,7 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> ForwardBackwardTraces:
     ... t2 = prims.mul(cotangents, t1) # t2: "cpu f32[]"
     ... return (t2,))
     """
+    from thunder.transforms.torch_autograd import connect_to_torch_autograd
 
     # forward_trace, result, env = augmented_forward_pass_trace(trace, *trace.args, **trace.kwargs)
     # forward_trace.tags.add(TraceTag.AUGMENTED_FORWARD)
@@ -3114,6 +3115,18 @@ def backward_fn(saved_for_backward, cotangents):
         backward_fn, saved_for_backward, cotangents
     )
 
+    if torch_autograd and requires_grad_mask is not None:
+        # Update the backward trace to only compute gradients for the
+        # inputs that require gradients
+        assert backward_trace.bound_symbols[-1].sym.id == prims.PrimIDs.RETURN
+        filtered_grads = tuple(
+            (arg_grad if requires_grad else None)
+            for arg_grad, requires_grad in utils.safe_zip(backward_trace.bound_symbols[-1].args[0], requires_grad_mask)
+        )
+
+        # autograd.Function.backward expects a flat tuple of gradients
+        backward_trace.bound_symbols[-1] = replace(backward_trace.bound_symbols[-1], args=(filtered_grads,))
+
     # We are done with constructing the forward and backward passes at this
     # stage. The following is not strictly necessary, but it's good to filter
     # out the unused elements of the saved_for_backward and flatten it for more
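
The filtering added above is plain element-wise masking of the gradients handed back to autograd.Function.backward: positions whose input does not require a gradient are replaced with None so the returned tuple stays aligned with the flat inputs. A small standalone illustration of that behavior (the built-in zip stands in for utils.safe_zip):

    grads = ("grad_x", "grad_w", "grad_b")      # stand-ins for the gradient proxies
    requires_grad_mask = (True, False, True)    # the second input needs no gradient
    filtered = tuple(g if m else None for g, m in zip(grads, requires_grad_mask))
    assert filtered == ("grad_x", None, "grad_b")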
@@ -3159,6 +3172,34 @@ def backward_fn(saved_for_backward, cotangents):
     if enable_saved_for_backward_recomputation:
         forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace)
 
+    if torch_autograd:
+        # Now let's update the forward trace to use the
+        # connect_to_torch_autograd that will store the backward trace and other
+        # information needed to connect to PyTorch Autograd
+        data_for_autograd, (saved_tensors, saved_other) = forward_trace.output
+
+        with tracectx(forward_trace):
+            forward_trace.scopes = [forward_trace.bound_symbols]
+            return_bsym = forward_trace.bound_symbols.pop()
+            assert return_bsym.sym.id == prims.PrimIDs.RETURN
+            new_flat_output = connect_to_torch_autograd(
+                backward=backward_trace,
+                return_none_instead_of_grads=get_compile_data().return_none_instead_of_grads,
+                saved_tensors=saved_tensors,
+                saved_other=saved_other,
+                flat_args=data_for_autograd["flat_args"],
+                flat_output=data_for_autograd["flat_tensor_output"],
+            )
+            old_to_new_output = {
+                variableify(old): new
+                for old, new in utils.safe_zip(data_for_autograd["flat_tensor_output"], new_flat_output)
+            }
+            new_output = tree_map(lambda out: old_to_new_output.get(variableify(out), out), data_for_autograd["output"])
+            prims.python_return({"output": new_output, "flat_args": data_for_autograd["flat_args"]})
+
+        get_compile_stats().last_backward_traces.append(backward_trace)
+        return forward_trace
+
     return ForwardBackwardTraces(forward_trace, backward_trace)
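
After this change the function has two return shapes: by default it still returns a ForwardBackwardTraces pair, while on the torch_autograd path it returns only the rewritten forward trace and records the backward trace on the compile stats. A hedged sketch of the calling convention implied by the diff, assuming trace and requires_grad_mask are available as above and use_torch_autograd is a hypothetical flag:

    if use_torch_autograd:
        # The forward trace now ends in connect_to_torch_autograd; the backward
        # trace is fetched from the compile stats rather than the return value.
        forward_trace = forward_and_backward_from_trace(
            trace, torch_autograd=True, requires_grad_mask=requires_grad_mask
        )
        backward_trace = get_compile_stats().last_backward_traces[-1]
    else:
        # Unchanged default path.
        forward_trace, backward_trace = forward_and_backward_from_trace(trace)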

