From 24d6a8d149d92fcb0d7280e2e61c06b9c60454b0 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Tue, 3 Dec 2024 10:46:36 +0100
Subject: [PATCH] deduplicate saved_for_backward

---
 thunder/core/trace_interpreter.py   |  1 -
 thunder/core/vjp_utils.py           | 21 ++++++++++++++++++++-
 thunder/executors/torch_autograd.py | 19 ++++++++++++++++++-
 3 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py
index 37694ade6b..b661eca3db 100644
--- a/thunder/core/trace_interpreter.py
+++ b/thunder/core/trace_interpreter.py
@@ -213,7 +213,6 @@ def __init__(self, trace, *args, **kwargs):
         self.trace = trace
         self.new_trace = from_trace(self.trace)
         self.have_processed_args = False
-        print(self.trace)
 
     def read(self, x: VariableInterface | Any) -> Any:
         if isinstance(x, VariableInterface):
diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py
index 0921bf7c6e..b6f3712e5e 100644
--- a/thunder/core/vjp_utils.py
+++ b/thunder/core/vjp_utils.py
@@ -1,5 +1,5 @@
 import inspect
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from functools import wraps
 from inspect import Parameter, Signature
 from itertools import chain
@@ -229,3 +229,22 @@ def get_saved_for_backward_tensors(trace: TraceCtx) -> tuple[TensorProxy]:
         lambda: "All saved tensors must be TensorProxy or None",
     )
     return tuple(saved_tensors)
+
+
+def set_saved_for_backward_tensors(trace: TraceCtx, saved_tensors: Sequence[TensorProxy]):
+    """
+    Given a trace, set (in place) the tensors that are saved for backward in the trace.
+
+    Args:
+        trace: The trace to set the saved tensors for.
+        saved_tensors: Proxies of the tensors to save.
+    """
+    utils.check(
+        all(isinstance(t, TensorProxy) or t is None for t in saved_tensors),
+        lambda: "All saved tensors must be TensorProxy or None",
+    )
+    ret_node = trace.bound_symbols.pop(-1)
+    assert ret_node.sym == prims.python_return
+    output = ret_node.args
+    output = (output[0], (tuple(saved_tensors), *output[1][1:]), *output[2:])
+    trace.bound_symbols.append(ret_node.from_bsym(args=output))
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 5374b23afe..ce9497125b 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -10,7 +10,7 @@
 from thunder.core.symbol import BoundSymbol
 from thunder.core.trace import TraceCtx, from_trace, set_tracectx, reset_tracectx
 from thunder.core.transform_common import replace_redundant_inputs
-from thunder.core.vjp_utils import get_saved_for_backward_tensors
+from thunder.core.vjp_utils import get_saved_for_backward_tensors, set_saved_for_backward_tensors
 
 if TYPE_CHECKING:
     from thunder.core.trace import VariableInterface
@@ -240,6 +240,23 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
             skip_output=False,
             skip_subsymbols=False,
         )
+
+    # remove duplicates
+    # The NVFuser fusion pass (and possibly other fusion passes) applied on the forward trace
+    # includes a CSE pass that may leave duplicate tensors saved for backward. This causes
+    # trouble because we then see duplicates in the backward unpacking. The fusion passes are
+    # unaware of the backward trace, so they cannot deduplicate themselves; we clean this up here.
+    seen = set()
+    new_fw_out = []
+    new_bw_inp = []
+    for p_fw, p_bw in zip(get_saved_for_backward_tensors(fw_extrace), new_bsyms[4].output, strict=True):
+        if p_fw.name not in seen:
+            seen.add(p_fw.name)
+            new_fw_out.append(p_fw)
+            new_bw_inp.append(p_bw)
+    new_bsyms[4] = new_bsyms[4].from_bsym(output=tuple(new_bw_inp))
+    set_saved_for_backward_tensors(fw_extrace, new_fw_out)
+
     bw_trace.bound_symbols = new_bsyms
 
     if getattr(compile_data.fn, "use_fsdp", False):
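
As a standalone sketch of the deduplication step added in torch_autograd.py above: the forward
trace's saved-for-backward outputs and the backward trace's unpack inputs are walked in lockstep,
and each name is kept only the first time it is seen, so both sides stay aligned. FakeProxy and
the sample tensor names below are purely illustrative stand-ins, not Thunder APIs.

    from collections import namedtuple

    # Illustrative stand-in for a TensorProxy; only the name matters for deduplication.
    FakeProxy = namedtuple("FakeProxy", ["name"])

    def dedupe_saved_for_backward(fw_saved, bw_unpacked):
        """Drop duplicate saved tensors (by name) while keeping the forward outputs
        and the backward unpack inputs aligned position by position."""
        seen = set()
        new_fw, new_bw = [], []
        for p_fw, p_bw in zip(fw_saved, bw_unpacked, strict=True):
            if p_fw.name not in seen:
                seen.add(p_fw.name)
                new_fw.append(p_fw)
                new_bw.append(p_bw)
        return new_fw, new_bw

    # Example: "t1" appears twice, e.g. after a CSE pass ran on the forward trace.
    fw = [FakeProxy("t0"), FakeProxy("t1"), FakeProxy("t1"), FakeProxy("t2")]
    bw = [FakeProxy("t0"), FakeProxy("t1"), FakeProxy("t1"), FakeProxy("t2")]
    new_fw, new_bw = dedupe_saved_for_backward(fw, bw)
    assert [p.name for p in new_fw] == ["t0", "t1", "t2"]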