
Commit

deduplicate saved_for_backward
t-vi committed Dec 3, 2024
1 parent 7869f83 commit 24d6a8d
Showing 3 changed files with 38 additions and 3 deletions.
1 change: 0 additions & 1 deletion thunder/core/trace_interpreter.py
@@ -213,7 +213,6 @@ def __init__(self, trace, *args, **kwargs):
        self.trace = trace
        self.new_trace = from_trace(self.trace)
        self.have_processed_args = False
-        print(self.trace)

    def read(self, x: VariableInterface | Any) -> Any:
        if isinstance(x, VariableInterface):
21 changes: 20 additions & 1 deletion thunder/core/vjp_utils.py
@@ -1,5 +1,5 @@
import inspect
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
from functools import wraps
from inspect import Parameter, Signature
from itertools import chain
@@ -229,3 +229,22 @@ def get_saved_for_backward_tensors(trace: TraceCtx) -> tuple[TensorProxy]:
        lambda: "All saved tensors must be TensorProxy or None",
    )
    return tuple(saved_tensors)


+def set_saved_for_backward_tensors(trace: TraceCtx, saved_tensors: Sequence[TensorProxy]):
+    """
+    Given a trace, set the tensors that are saved for backward in the trace.
+
+    Args:
+        trace: The trace to set saved tensors for.
+        saved_tensors: proxies for the tensors to save.
+    """
+    utils.check(
+        all(isinstance(t, TensorProxy) or t is None for t in saved_tensors),
+        lambda: "All saved tensors must be TensorProxy or None",
+    )
+    ret_node = trace.bound_symbols.pop(-1)
+    assert ret_node.sym == prims.python_return
+    output = ret_node.args
+    output = (output[0], (tuple(saved_tensors), *output[1][1:]), *output[2:])
+    trace.bound_symbols.append(ret_node.from_bsym(args=output))
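
As an aside (not part of the diff): the helper relies on the final python_return bound symbol carrying args of the shape (fw_output, (saved_tensors, *other_saved), *rest), and it only swaps the first entry of that inner tuple. A minimal sketch of the same rewrite on plain tuples follows; the names are hypothetical, chosen just for illustration.

# Minimal sketch (hypothetical names, plain tuples instead of bound symbols).
# ret_args mirrors the python_return args rewritten above:
#   (fw_output, (saved_tensors, *other_saved), *rest)
def replace_saved_tensors(ret_args, new_saved):
    fw_output, saved, *rest = ret_args
    # only the first entry of the "saved" tuple is replaced
    return (fw_output, (tuple(new_saved), *saved[1:]), *rest)

ret_args = ("out", (("t0", "t1", "t0"), "other_saved"), "flat_args")
print(replace_saved_tensors(ret_args, ("t0", "t1")))
# ('out', (('t0', 't1'), 'other_saved'), 'flat_args')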
19 changes: 18 additions & 1 deletion thunder/executors/torch_autograd.py
@@ -10,7 +10,7 @@
from thunder.core.symbol import BoundSymbol
from thunder.core.trace import TraceCtx, from_trace, set_tracectx, reset_tracectx
from thunder.core.transform_common import replace_redundant_inputs
-from thunder.core.vjp_utils import get_saved_for_backward_tensors
+from thunder.core.vjp_utils import get_saved_for_backward_tensors, set_saved_for_backward_tensors

if TYPE_CHECKING:
    from thunder.core.trace import VariableInterface
@@ -240,6 +240,23 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
        skip_output=False,
        skip_subsymbols=False,
    )

+    # Remove duplicates:
+    # The fusion pass of nvFuser (and possibly of other executors) applied to the
+    # forward trace runs CSE, which can leave duplicate tensors in the saved-for-backward
+    # outputs. This causes trouble because the backward then sees duplicates in its
+    # unpacking. The fusion passes are unaware of the backward trace, so they cannot
+    # handle this themselves; we clean it up here.
+    seen = set()
+    new_fw_out = []
+    new_bw_inp = []
+    for p_fw, p_bw in zip(get_saved_for_backward_tensors(fw_extrace), new_bsyms[4].output, strict=True):
+        if p_fw.name not in seen:
+            seen.add(p_fw.name)
+            new_fw_out.append(p_fw)
+            new_bw_inp.append(p_bw)
+    new_bsyms[4] = new_bsyms[4].from_bsym(output=tuple(new_bw_inp))
+    set_saved_for_backward_tensors(fw_extrace, new_fw_out)

    bw_trace.bound_symbols = new_bsyms

    if getattr(compile_data.fn, "use_fsdp", False):
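
For illustration (not part of the commit): the dedup walks the forward's saved-for-backward proxies and the backward's unpack outputs in lockstep, keyed on the proxy name, so the two sides stay aligned when a duplicate is dropped. A self-contained sketch with a stand-in proxy class:

from dataclasses import dataclass


@dataclass(frozen=True)
class FakeProxy:
    # stand-in for TensorProxy; only the name matters for deduplication
    name: str


fw_saved = [FakeProxy("t0"), FakeProxy("t1"), FakeProxy("t0")]  # duplicate "t0", e.g. from CSE
bw_unpacked = ["u0", "u1", "u2"]  # backward-side counterparts, paired by position

seen, new_fw, new_bw = set(), [], []
for p_fw, p_bw in zip(fw_saved, bw_unpacked, strict=True):
    if p_fw.name not in seen:  # keep only the first occurrence of each name
        seen.add(p_fw.name)
        new_fw.append(p_fw)
        new_bw.append(p_bw)

assert [p.name for p in new_fw] == ["t0", "t1"]
assert new_bw == ["u0", "u1"]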
