From 754d86e594dc5a95efcd6bb0d0ddf6335730c80f Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 17 Dec 2024 10:11:47 +0100 Subject: [PATCH 01/17] update checkpointing support for jit --- thunder/core/jit_ext.py | 52 ++++++++++++++++++++++++++++- thunder/core/rematerialization.py | 3 ++ thunder/core/symbol.py | 9 ++++- thunder/core/transforms.py | 39 +++++++++++++++++----- thunder/executors/torch_autograd.py | 2 +- thunder/executors/torch_compile.py | 2 +- thunder/torch/__init__.py | 2 -- 7 files changed, 95 insertions(+), 14 deletions(-) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 726353983a..18de748200 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -895,11 +895,18 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype): return res +ProxyTag.register_tag("RECOMPUTE_IN_BACKWARD") + + @register_general_jit_lookaside(torch.utils.checkpoint.checkpoint) def _general_jit_torch_checkpoint_lookaside( function: Callable, *args, - **kwargs: Any, + context_fn: None | Callable[..., Any] = None, + debug: None | bool = None, + determinism_check: None | str = None, + preserve_rng_state: None | bool = None, + use_reentrant: bool = False, ): """ This function does preprocessing of the `function` argument before @@ -917,6 +924,48 @@ def _general_jit_torch_checkpoint_lookaside( The result of calling `thunder.torch.checkpoint` with the preprocessed `function` and its arguments. """ + + if unwrap(use_reentrant): + return do_raise( + "torch.checkpoint: use_reentrant=True is not supported in Thunder", + ) + # NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments + # Let's raise a warning if any of these arguments are passed + if unwrap(context_fn) is not None: + warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored") + if unwrap(debug) is not None: + warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored") + if unwrap(determinism_check) is not None: + warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored") + if unwrap(preserve_rng_state) is not None: + warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored") + + jit_ctx: JitCtx = get_jit_ctx() + jit_ctx.computation_trace.push_scope([]) + + input_output_proxy_names = set() + + def add_input_output_proxy_name(p): + if isinstance(p, Proxy): + input_output_proxy_names.add(p.name) + + tree_map(add_input_output_proxy_name, [unwrap(a) for a in args]) + + res = _interpret_call(function, *args) + if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + return res + + tree_map(add_input_output_proxy_name, unwrap(res)) + + new_bsyms = jit_ctx.computation_trace.pop_scope() + jit_ctx.computation_trace.bound_symbols.extend(new_bsyms) + + for bsym in new_bsyms: + for o in bsym.flat_proxy_outs: + if o.name not in input_output_proxy_names: + o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD) + + return res from thunder.torch import checkpoint # It should be possible to call the general_thunder_jit here to handle the @@ -926,6 +975,7 @@ def _general_jit_torch_checkpoint_lookaside( def thunder_function(*args, **kwargs): return unwrap(function)(*args, **kwargs) + XXX wrapped_thunder_function = wrap_const(thunder_function) return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs) diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index f24ee1fad2..ee9b8f4464 100644 
--- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -361,6 +361,7 @@ def add_edges(var): if not required_producer_vars: # If there are no required producer variables, we need to make sure that # the source node is added to the graph. + print("#### hello") add_edge("source", "source", capacity=float("inf")) for var in required_producer_vars: @@ -374,6 +375,8 @@ def add_edges(var): g = nx.DiGraph() g.add_edges_from(edges) + print("#####", dict(g.nodes), g.edges) + try: _, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink") except Exception: diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 2906241e11..74c488a479 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -16,7 +16,7 @@ import thunder.core.baseutils as baseutils import thunder.core.codeutils as codeutils from thunder.core.codeutils import Printable, Positions -from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface +from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface, TagBase from thunder.core.utils import FrozenDict, make_hashable from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map import thunder.core.dtypes as dtypes @@ -351,6 +351,10 @@ def tag_tensorproxy_output_as_detached(proxy): return result +class BoundSymbolTag(TagBase): + pass + + # A symbol, arguments (and kwarguments), output, and sub-symbols # args is a sequence of the arguments # kwargs is a dict of the kwargs @@ -377,6 +381,8 @@ class BoundSymbol(BoundSymbolInterface): source_filename: str | None = None source_positions: Positions | None = None + bsym_tags: set[BoundSymbolTag] = field(default_factory=set) + _call_ctx: None | dict[str, Any] = None _import_ctx: dict = field(default_factory=dict) @@ -412,6 +418,7 @@ def from_bsym(self, **kwargs) -> BoundSymbol: "_import_ctx": self._import_ctx, "_object_ctx": self._object_ctx, "_executor": self._executor, + "bsym_tags": self.bsym_tags.copy(), } self_kwargs.update(kwargs) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index e28821a8b5..05075761cd 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -38,7 +38,7 @@ from thunder.core.compile_data import get_compile_data, get_compile_option from thunder.core.langctxs import langctx, Languages from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass -from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol +from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol, BoundSymbolTag from thunder.core.trace import TraceCtx as Trace from thunder.core.trace import VariableInterface as Variable from thunder.core.trace import ( @@ -92,6 +92,7 @@ TraceTag.register_tag("AUGMENTED_FORWARD") +BoundSymbolTag.register_tag("RECOMPUTE_IN_BACKWARD") # TODO This should be a partial of thunder.trace, but that would cause a circular import @@ -3183,11 +3184,19 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr ) if remat_policy: - rematerializable = remat_policy(all_rematerializable) + rematerializable = remat_policy(fwd_trace, bwd_trace, all_rematerializable) else: - rematerializable = all_rematerializable - - producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args) + rematerializable = { + p + for p in all_rematerializable + if thunder.core.proxies.ProxyTag.RECOMPUTE_IN_BACKWARD in thunder.core.proxies.unvariableify(p).tags + } + + producers = 
find_producer_symbols( + fwd_trace, + tuple(unvariableify(i) for i in rematerializable), + tuple(fwd_trace.args) + tuple(unvariableify(i) for i in all_rematerializable - rematerializable), + ) required_fw_args = fwd_trace_args & old_saved_for_bwd recomputed_tensors_from_producers = set() @@ -3225,14 +3234,28 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr assert bwd_trace.bound_symbols[5].sym.id == prims.PrimIDs.UNPACK_SEQUENCE assert bwd_trace.bound_symbols[5].args[0].name == "C1" + proxy_names_to_producers = {} + for bsym in producers: + for p in bsym.flat_proxy_outs: + if variableify(p) in recomputed_tensors_from_producers: + proxy_names_to_producers[p.name] = bsym + + def insert_producer_for_proxy(pname): + producer_bsym = proxy_names_to_producers.get(pname) + if producer_bsym is not None: + for p in producer_bsym.flat_proxy_args: + insert_producer_for_proxy(p.name) + for o in producer_bsym.flat_proxy_outs: + proxy_names_to_producers.pop(o.name, None) + new_bwd_trace.bound_symbols.append(producer_bsym) + for idx, bsym in enumerate(bwd_trace.bound_symbols): if idx == 4: new_unpack = prims.unpack_sequence.bind(*unpack_args, output=new_saved_for_backward) new_bwd_trace.bound_symbols.append(new_unpack) - elif idx == 6: - new_bwd_trace.bound_symbols.extend(producers) - new_bwd_trace.bound_symbols.append(bsym) else: + for p in bsym.flat_proxy_args: + insert_producer_for_proxy(p.name) new_bwd_trace.bound_symbols.append(bsym) new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]] diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 454582c92e..e0cfbd3adc 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -349,7 +349,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat ) bw_traces.append(bw_extrace) - fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) + # fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) fw_traces.append(fw_extrace) bw_traces.append(bw_extrace) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 7cb0fac0e0..54742e25b7 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -184,7 +184,7 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx: fusedtrace.bound_symbols = fused_bsyms - fusedtrace = rematerialize(fusedtrace) + # fusedtrace = rematerialize(fusedtrace) fusedtrace = dce(fusedtrace) fusedtrace = update_fusion_call_ctx(fusedtrace) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 38d0353c1f..018ca076b8 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -63,7 +63,6 @@ # NOTE torch is a requirement import torch -import torch.utils.checkpoint import torch._higher_order_ops.wrap import warnings @@ -5325,7 +5324,6 @@ def _unwrap_if_dead(tensor): @torchsymbol( - torch.utils.checkpoint.checkpoint, torch.ops.higher_order.tag_activation_checkpoint, id="activation_checkpoint", ) From 6601bc9906ed91fc30ce7388d588c572ac2d0dc4 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 17 Dec 2024 10:57:11 +0100 Subject: [PATCH 02/17] don't pass traces to policy --- thunder/core/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 05075761cd..06836444a6 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ 
-3184,7 +3184,7 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr ) if remat_policy: - rematerializable = remat_policy(fwd_trace, bwd_trace, all_rematerializable) + rematerializable = remat_policy(all_rematerializable) else: rematerializable = { p From 1c1a4357e8e5b87e66352c17f7683740af8be6e8 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 17 Dec 2024 12:05:43 +0100 Subject: [PATCH 03/17] test updates --- thunder/core/transforms.py | 5 ++++ thunder/tests/test_grad.py | 47 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 06836444a6..ae302fc702 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -3154,6 +3154,8 @@ def backward_fn(saved_for_backward, cotangents): enable_saved_for_backward_recomputation: None | bool = get_compile_option( "enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation." ) + if enable_saved_for_backward_recomputation is None: + enable_saved_for_backward_recomputation = True if enable_saved_for_backward_recomputation: forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace) @@ -3192,6 +3194,9 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr if thunder.core.proxies.ProxyTag.RECOMPUTE_IN_BACKWARD in thunder.core.proxies.unvariableify(p).tags } + if not rematerializable: + return fwd_trace, bwd_trace + producers = find_producer_symbols( fwd_trace, tuple(unvariableify(i) for i in rematerializable), diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 73e3453e9a..358d9593b0 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1754,8 +1754,8 @@ def f(x, y): # The intermediate values are recomputed during backward pass. assert len(out.grad_fn.next_functions[0][0].saved_tensors) == 2 # We detach the saved tensors (which returns a new Python tensor backed by same storage) - assert out.grad_fn.next_functions[0][0].saved_tensors[0].data_ptr() == x.data_ptr() - assert out.grad_fn.next_functions[0][0].saved_tensors[1].data_ptr() == y.data_ptr() + # the order seems to be non-deterministic sometimes + assert {t.data_ptr() for t in out.grad_fn.next_functions[0][0].saved_tensors} == {x.data_ptr(), y.data_ptr()} g = torch.ones_like(out) out.backward(g) @@ -1768,6 +1768,49 @@ def f(x, y): torch.testing.assert_close(y.grad, y_ref.grad) +@requiresCUDA +def test_checkpoint_max_memory(): + import torch.utils.checkpoint + + class Checkpoint(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args): + return torch.utils.checkpoint.checkpoint(self.module, *args, use_reentrant=False) + + with torch.device("cuda:0"): + m = torch.nn.Sequential( + torch.nn.Linear(1024, 16), + torch.nn.ReLU(), + *[ + Checkpoint( + torch.nn.Sequential( + torch.nn.Linear(16, 2048), + torch.nn.Linear(2048, 16), + torch.nn.ReLU(), + ) + ) + for _ in range(10) + ], + torch.nn.Linear(16, 1024), + ) + inps = torch.randn(512, 1024, requires_grad=True) + + jm = thunder.jit(m, executors=()) # no rematerialization + mem_base = torch.cuda.memory_allocated() + torch.cuda.reset_accumulated_memory_stats() + res = jm(inps) + res.sum().backward() + mem_max = torch.cuda.max_memory_allocated() + # the rematerialization pass moved all(?) recomputation to the front, + # making the peak mem about 46MB. 
+ # With checkpointing as coded in the model and recomputation where the + # values are used, we get about 12MB, so we put the barrier at 16MB + assert mem_max - mem_base < 16 * 2**20 + + def test_inconsistent_output_length_grad_transform(): from thunder.extend import OperatorExecutor from thunder.core.proxies import AnyProxy, TensorProxy From b693e10d3154bb4c77a9e6c3477af1134a16dd42 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 17 Dec 2024 12:22:07 +0100 Subject: [PATCH 04/17] switches --- thunder/executors/nvfuserex_impl.py | 5 ++++- thunder/executors/torch_autograd.py | 7 ++++++- thunder/executors/torch_compile.py | 6 +++++- thunder/tests/test_nvfuser_remat.py | 2 +- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 02fa5081da..aef1c3e317 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -925,7 +925,10 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx: # Some of the operations might be better placed with its consumers (for # example residual connection in transformer block). This pass moves # them to the consumer. - if self._use_rematerialization: + use_rematerialization: None | bool = get_compile_option( + "use_rematerialization", "use rematerialization of parameters" + ) + if use_rematerialization and self._use_rematerialization: fusedtrace = rematerialize(fusedtrace) fusedtrace = remove_redundant_casts(fusedtrace) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index e0cfbd3adc..0c20f4866f 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -4,6 +4,7 @@ import torch import thunder.core.utils as utils +from thunder.core.compile_data import get_compile_option from thunder.core.prims import PrimIDs from thunder.core.proxies import TensorProxy, variableify from thunder.core.pytree import tree_flatten @@ -349,7 +350,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat ) bw_traces.append(bw_extrace) - # fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) + use_rematerialization: None | bool = get_compile_option( + "use_rematerialization", "use rematerialization of parameters" + ) + if use_rematerialization: + fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) fw_traces.append(fw_extrace) bw_traces.append(bw_extrace) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 54742e25b7..c94ba3be70 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -184,7 +184,11 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx: fusedtrace.bound_symbols = fused_bsyms - # fusedtrace = rematerialize(fusedtrace) + use_rematerialization: None | bool = get_compile_option( + "use_rematerialization", "use rematerialization of parameters" + ) + if use_rematerialization: + fusedtrace = rematerialize(fusedtrace) fusedtrace = dce(fusedtrace) fusedtrace = update_fusion_call_ctx(fusedtrace) diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index f1218f91dd..eee222706b 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -515,7 +515,7 @@ def forward(self, x): # At the time of writing, linear and matmul are not fused into nvFuser # regions by default therefore, we should enable them separately - jmodel = thunder.jit(model, 
nv_enable_linear=True, nv_enable_matmul=True) + jmodel = thunder.jit(model, nv_enable_linear=True, nv_enable_matmul=True, use_rematerialization=True) jmodel(inp) def assert_subsymbol_count(trace: TraceCtx, /, num_linears: int, num_matmuls: int, num_fusions: int): From 8f335755c069d9db16ed34460f6642a6a61a0370 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 17 Dec 2024 12:33:51 +0100 Subject: [PATCH 05/17] switches --- thunder/core/rematerialization.py | 3 --- thunder/tests/test_nvfuser_remat.py | 26 +++++++++++++++----------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index ee9b8f4464..f24ee1fad2 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -361,7 +361,6 @@ def add_edges(var): if not required_producer_vars: # If there are no required producer variables, we need to make sure that # the source node is added to the graph. - print("#### hello") add_edge("source", "source", capacity=float("inf")) for var in required_producer_vars: @@ -375,8 +374,6 @@ def add_edges(var): g = nx.DiGraph() g.add_edges_from(edges) - print("#####", dict(g.nodes), g.edges) - try: _, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink") except Exception: diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index eee222706b..b4eb739964 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -67,7 +67,7 @@ def test_find_producer_symbols(executor, device, _): # We will try to find a subgraph for rematerializing __c and __d t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -117,7 +117,7 @@ def test_find_producer_symbols(executor, device, _): def test_apply_rematerialization_producer(executor, device, _): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -151,7 +151,7 @@ def test_apply_rematerialization_producer(executor, device, _): def test_apply_rematerialization_consumer(executor, device, _): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -217,7 +217,7 @@ def foo(t0): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(foo, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -257,7 +257,7 @@ def func(t0): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = 
thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -369,7 +369,9 @@ def func( from thunder.executors.torch_compile import torch_compile_cat_ex try: - compiled_func = thunder.jit(func, executors=(torch_compile_cat_ex, thunder.nvfuser_executor)) + compiled_func = thunder.jit( + func, executors=(torch_compile_cat_ex, thunder.nvfuser_executor), use_rematerialization=True + ) _ = compiled_func( t0, t1, @@ -391,7 +393,7 @@ def func( def test_find_cut(executor, device, _): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) intial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(intial_trace.python_callable()) + compiled_func = thunder.jit(intial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -418,7 +420,7 @@ def test_find_cut_dropout(executor, device, _): with patch("thunder.core.rematerialization.replace_uniform", new=replace_uniform_mock): intial_trace = thunder.trace()(func_with_dropout, t0) - compiled_func = thunder.jit(intial_trace.python_callable()) + compiled_func = thunder.jit(intial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -459,10 +461,12 @@ def func(t0): # Result with rematerialization and without rematerialization should match initial_trace = thunder.trace()(func, t0) - result_with_remat = thunder.jit(initial_trace.python_callable())(t0) + result_with_remat = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)(t0) assert not isinstance(result_with_remat, Exception) - result_without_remat = disable_rematerialization_in_nvfuser_fusion(thunder.jit(initial_trace.python_callable()))(t0) + result_without_remat = disable_rematerialization_in_nvfuser_fusion( + thunder.jit(initial_trace.python_callable(), use_rematerialization=True) + )(t0) torch.testing.assert_close(result_with_remat, result_without_remat) @@ -474,7 +478,7 @@ def test_rematerialization_name_collision(): def forward(x): return x.softmax(dim=1, dtype=torch.float) - jforward = thunder.jit(forward) + jforward = thunder.jit(forward, use_rematerialization=True) x = torch.randn([32768, 8], dtype=torch.bfloat16, device="cuda", requires_grad=True) From 3b4b4a8e44767eacfcce1a6bf066170874fdd67b Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 17 Dec 2024 13:39:09 +0100 Subject: [PATCH 06/17] dce --- thunder/core/jit_ext.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 18de748200..be478a778c 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -966,18 +966,6 @@ def add_input_output_proxy_name(p): o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD) return res - from thunder.torch import checkpoint - - # It should be possible to call the general_thunder_jit here to handle the - # conversion from torch to thunder but it doesn't work now - # See https://github.com/Lightning-AI/lightning-thunder/issues/1126 - # TODO: Convert the function to a Thunder function - def thunder_function(*args, **kwargs): - return unwrap(function)(*args, **kwargs) - - XXX - wrapped_thunder_function = wrap_const(thunder_function) - return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs) # Adds proxy methods From 9a5e14a31947b675eeb9422f84e20382e5e145ab Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 
17 Dec 2024 19:49:40 +0100 Subject: [PATCH 07/17] recompute intermediates from decomposed symbols --- thunder/core/trace_interpreter.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py index b38cd01ac7..24e196bb0a 100644 --- a/thunder/core/trace_interpreter.py +++ b/thunder/core/trace_interpreter.py @@ -6,7 +6,7 @@ from thunder.core.trace import VariableInterface, from_trace, tracectx from thunder.core.baseutils import ProxyInterface, TensorProxyInterface from thunder.core.utils import safe_map_flat, sequencify -from thunder.core.proxies import variableify +from thunder.core.proxies import variableify, ProxyTag from thunder.core.transform_common import VJPDual @@ -183,6 +183,9 @@ def do_swap(v): for new_bsym in new_bsyms: # TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym? + for o in new_bsym.flat_proxy_outs: + if variableify(o) not in swap_map: + o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD) new_trace.bound_symbols.append( new_bsym.from_bsym_swap_proxies(swap_map).from_bsym( source_filename=bsym.source_filename, source_positions=bsym.source_positions From 1e1bd8345eae08fee142888a9169bb1980e98f2e Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 17 Dec 2024 20:09:19 +0100 Subject: [PATCH 08/17] add/fix tests --- thunder/tests/test_grad.py | 25 +++++++++++++++++++++++++ thunder/tests/test_jit_general.py | 2 +- thunder/tests/test_transforms.py | 2 +- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 358d9593b0..28021be8cf 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -27,6 +27,7 @@ run_snippet, assert_closer, IN_CI, + NVFUSER_AVAILABLE, requiresCUDA, version_between, ) @@ -1929,3 +1930,27 @@ def func(x): torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual_gr, expected_gr) + +@pytest.mark.parametrize("device", ("cuda", "cpu")) +def test_backward_recomputation_decomposed_ops(device): + def fn(a): + return torch.nn.functional.gelu(a) + + jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=False) + jfn2 = thunder.jit(fn, enable_saved_for_backward_recomputation=True) + a = torch.randn(2, 2, device=device, requires_grad=True) + res = jfn(a) + res2 = jfn2(a) + assert len(res.grad_fn.saved_tensors) == 3 # should be decomposed + assert len(res2.grad_fn.saved_tensors) == 1 + + if NVFUSER_AVAILABLE and device == "cuda": + # check everything is fused + assert {bsym.sym.name for bsym in thunder.last_backward_traces(jfn2)[-1].bound_symbols} == { + "nvFusion0", + "clear_mutable_collection", + "python_return", + "python_del", + "unpack_sequence", + "unpack_trivial", + } diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index e1be6d86b8..80e4301398 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -1433,7 +1433,7 @@ def foo(a): a = torch.reshape(a, [a.numel()]) return a.relu() - jfoo = thunder.jit(foo, cache="symbolic values") + jfoo = thunder.jit(foo, cache="symbolic values", enable_saved_for_backward_recomputation=False) a = torch.randn(2, 3, 8, requires_grad=True, device="cpu") diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py index 2418eeae6a..45f118e7af 100644 --- a/thunder/tests/test_transforms.py +++ b/thunder/tests/test_transforms.py @@ -375,7 +375,7 @@ def forward(self, x): model = MyModel() # Do not recompute anything 
- jmodel = thunder.jit(model) + jmodel = thunder.jit(model, enable_saved_for_backward_recomputation=False) jmodel(a) fwd_trace = None From c14662c65422dbe97fbd72dfd278fbcbb9f6d295 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 17 Dec 2024 22:39:54 +0100 Subject: [PATCH 09/17] fix proxy accounting --- thunder/core/transforms.py | 24 +++++++++++++++--------- thunder/tests/test_examine_memory.py | 8 ++++---- thunder/tests/test_grad.py | 8 ++++++-- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index ae302fc702..70db8ea65c 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -3221,28 +3221,34 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1])) new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None) + assert bwd_trace.bound_symbols[2].sym.id == prims.PrimIDs.UNPACK_SEQUENCE + assert bwd_trace.bound_symbols[2].args[0].name == "saved_for_backward" + assert bwd_trace.bound_symbols[4].sym.id == prims.PrimIDs.UNPACK_SEQUENCE + assert bwd_trace.bound_symbols[4].args[0].name == "C0" + assert bwd_trace.bound_symbols[5].sym.id == prims.PrimIDs.UNPACK_SEQUENCE + assert bwd_trace.bound_symbols[5].args[0].name == "C1" + + p_saved_for_backward = bwd_trace.bound_symbols[2].args[0] + p_c0 = bwd_trace.bound_symbols[4].args[0] + new_bwd_trace = from_trace(bwd_trace) # In cases where C0 name is carried from previous trace it must be removed # as the proxy needs to register with that specific name to follow the backward # trace standard signature. - new_bwd_trace.names.discard("C0") - with tracectx(new_bwd_trace): - unpack_args = (CollectionProxy(new_saved_for_backward, name="C0"), len(new_saved_for_backward)) + p_saved_for_backward.coll = (new_saved_for_backward, fwd_trace.output[1][1]) + p_c0.coll = new_saved_for_backward # Here we make sure that the signature of the backward trace is the same as the one we expect. # This part of the trace is the unpacking of the tuple passed from the forward trace, # more specifically, C0 unpacks into the saved for backward tensors and C1 into the cotangents # used to compute the vector-Jacobian product. 
- assert bwd_trace.bound_symbols[4].sym.id == prims.PrimIDs.UNPACK_SEQUENCE - assert bwd_trace.bound_symbols[4].args[0].name == "C0" - assert bwd_trace.bound_symbols[5].sym.id == prims.PrimIDs.UNPACK_SEQUENCE - assert bwd_trace.bound_symbols[5].args[0].name == "C1" proxy_names_to_producers = {} for bsym in producers: + input_names = {p.name for p in bsym.flat_proxy_args} for p in bsym.flat_proxy_outs: - if variableify(p) in recomputed_tensors_from_producers: + if p.name not in input_names and variableify(p) in recomputed_tensors_from_producers: proxy_names_to_producers[p.name] = bsym def insert_producer_for_proxy(pname): @@ -3256,7 +3262,7 @@ def insert_producer_for_proxy(pname): for idx, bsym in enumerate(bwd_trace.bound_symbols): if idx == 4: - new_unpack = prims.unpack_sequence.bind(*unpack_args, output=new_saved_for_backward) + new_unpack = prims.unpack_sequence.bind(p_c0, len(new_saved_for_backward), output=new_saved_for_backward) new_bwd_trace.bound_symbols.append(new_unpack) else: for p in bsym.flat_proxy_args: diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py index 51278003de..f270ecc199 100644 --- a/thunder/tests/test_examine_memory.py +++ b/thunder/tests/test_examine_memory.py @@ -113,7 +113,7 @@ def test_nanogpt_block(): # Actual memory usage may vary depending on hardware and cuBLAS settings. # We are checking the estimated memory against a fixed value for consistency. - assert max_mem_fw[0] == 381754368 - assert sum(max_mem_fw[1].values()) == 375462912 - assert max_mem_bw[0] == 437292032 - assert sum(max_mem_bw[1].values()) == 40934400 + assert max_mem_fw[0] == 262183936 + assert sum(max_mem_fw[1].values()) == 135306240 + assert max_mem_bw[0] == 472259584 + assert sum(max_mem_bw[1].values()) == 157341696 diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 28021be8cf..0db8a55513 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1180,12 +1180,15 @@ def func(a, b, *, c): a = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True) b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True) c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True) - initial_trace = trace(inline_trace=False)(func, a, b, c=c) + jfn = thunder.jit(func) + cd, inps, _ = thunder.compile_data(jfn).get_computation_and_inputs(a, b, c=c) + initial_trace = cd.computation_traces[0] wrapped_trace = wrap_return_value_together_with_arguments(initial_trace) + fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace) fw = executor.make_callable(fw_trace) bw = executor.make_callable(bw_trace) - fw_out, saved_for_backward = fw(a, b, c=c) + fw_out, saved_for_backward = fw(*inps) initial_trace = trace()(value_and_grad(func), a, b, c=c) expected_vjp_func = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True) @@ -1195,6 +1198,7 @@ def func(a, b, *, c): output_grads = tree_map(lambda x: torch.ones_like(x), fw_out["output"]) bw_out = bw(saved_for_backward, output_grads) + expected_grads = (*expected_grads[:-1], expected_grads[-1]["c"]) torch.testing.assert_close(bw_out, expected_grads) From 0a175b14d83755c1ddc9669582394f14a1950b67 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 18 Dec 2024 09:47:19 +0100 Subject: [PATCH 10/17] filter proxies in remat, improve tests --- thunder/core/rematerialization.py | 4 +++- thunder/tests/test_grad.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git 
a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index f24ee1fad2..1a5241f237 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -192,7 +192,9 @@ def apply_rematerialization_for_consumer( _, leaves = bsym_list_to_dag(list(new_subsymbols)) new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP) proxy_order = order_proxies(new_subsymbols) - new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name])) + new_consumer_args = tuple( + sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name]) + ) new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols) return new_consumer diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 0db8a55513..2eaf67ffb2 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1812,8 +1812,8 @@ def forward(self, *args): # the rematerialization pass moved all(?) recomputation to the front, # making the peak mem about 46MB. # With checkpointing as coded in the model and recomputation where the - # values are used, we get about 12MB, so we put the barrier at 16MB - assert mem_max - mem_base < 16 * 2**20 + # values are used, we get about 12-20MB, so we put the barrier at 24MB + assert mem_max - mem_base < 24 * 2**20 def test_inconsistent_output_length_grad_transform(): @@ -1937,6 +1937,9 @@ def func(x): @pytest.mark.parametrize("device", ("cuda", "cpu")) def test_backward_recomputation_decomposed_ops(device): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + def fn(a): return torch.nn.functional.gelu(a) From 28723f13c5b4f9ab80f29dc2e0d013ea4c7bd271 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 18 Dec 2024 10:18:12 +0100 Subject: [PATCH 11/17] remat and recomp are exclusive --- thunder/tests/test_nvfuser_remat.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index b4eb739964..d0121a07e4 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -309,7 +309,7 @@ def func(t0): return t3, t4 t0 = make_tensor(2, 2, dtype=torch.float32, device=device) - compiled_func = executor.make_callable(func) + compiled_func = executor.make_callable(func, enable_saved_for_backward_recomputation=False) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[0] @@ -370,7 +370,10 @@ def func( try: compiled_func = thunder.jit( - func, executors=(torch_compile_cat_ex, thunder.nvfuser_executor), use_rematerialization=True + func, + executors=(torch_compile_cat_ex, thunder.nvfuser_executor), + use_rematerialization=True, + enable_saved_for_backward_recomputation=False, ) _ = compiled_func( t0, From 7963c9496de61751d9303f0467f9729a015256da Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 18 Dec 2024 11:02:58 +0100 Subject: [PATCH 12/17] fix cuda mem accounting --- thunder/tests/test_grad.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 2eaf67ffb2..ff7fcc600d 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1804,16 +1804,19 @@ def forward(self, *args): inps = torch.randn(512, 1024, requires_grad=True) jm = thunder.jit(m, executors=()) # no rematerialization + res = jm(inps) + res.sum().backward() + + torch.cuda.reset_peak_memory_stats() mem_base 
= torch.cuda.memory_allocated() -    torch.cuda.reset_accumulated_memory_stats() res = jm(inps) res.sum().backward() mem_max = torch.cuda.max_memory_allocated() -    # the rematerialization pass moved all(?) recomputation to the front, -    # making the peak mem about 46MB. +    # without checkpointing the peak mem is about 43MB. # With checkpointing as coded in the model and recomputation where the -    # values are used, we get about 12-20MB, so we put the barrier at 24MB -    assert mem_max - mem_base < 24 * 2**20 +    # values are used, we get a little over 10MB, so we put the barrier at 16MB +    mb_used = (mem_max - mem_base) / 2**20 +    assert mb_used < 16 def test_inconsistent_output_length_grad_transform(): From 9125936321b0ab68d49fda9786462e2dbe447a0a Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 19 Dec 2024 08:24:58 +0100 Subject: [PATCH 13/17] random handling in recompute --- thunder/core/transforms.py | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 70db8ea65c..32cd946edc 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -38,7 +38,7 @@ from thunder.core.compile_data import get_compile_data, get_compile_option from thunder.core.langctxs import langctx, Languages from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass -from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol, BoundSymbolTag +from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol, BoundSymbolTag, has_tags from thunder.core.trace import TraceCtx as Trace from thunder.core.trace import VariableInterface as Variable from thunder.core.trace import ( @@ -2992,6 +2992,7 @@ def unpacking_fn(saved_for_backward, cotangents): ) ) backward_trace.bound_symbols = list((*unpacking_trace.bound_symbols[:-1], *backward_trace_bsyms_without_unpacking)) +    backward_trace.scopes[0] = backward_trace.bound_symbols def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> ForwardBackwardTraces: @@ -3174,10 +3175,18 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr start_time_ns = time.perf_counter_ns() +    cd = get_compile_data() +    have_nvfuser = any([ex.name == "nvfuser" for ex in cd.executors_list]) if cd is not None else False +    if have_nvfuser: +        from thunder.core.rematerialization import replace_uniform + +        fwd_trace = replace_uniform(fwd_trace) + saved_for_bw = get_saved_for_backward_tensors(fwd_trace) fwd_trace_args = {variableify(j) for j in fwd_trace.args} old_saved_for_bwd = {variableify(j) for j in saved_for_bw} + producers = utils.producers(fwd_trace) all_rematerializable = old_saved_for_bwd - fwd_trace_args remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option( @@ -3204,8 +3213,19 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr ) required_fw_args = fwd_trace_args & old_saved_for_bwd +    additional_tensors = set() +    additional_nontensors = set() + recomputed_tensors_from_producers = set() for prod in producers: +        if has_tags(prod, {prims.OpTags.RANDOM_OP}): +            for prod_out in prod.flat_proxy_outs: +                # things that are from inputs?
+ if isinstance(prod_out, TensorProxy): + additional_tensors.add(variableify(prod_out)) + else: + additional_nontensors.add(variableify(prod_out)) + continue for prod_arg in prod.flat_args: prod_arg = variableify(prod_arg) if prod_arg in fwd_trace_args: @@ -3214,12 +3234,13 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr recomputed_tensors_from_producers.add(variableify(prod_out)) required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers - new_saved_for_backward = tuple(unvariableify(i) for i in required_fw_args | required_saved_for_bwd) + new_saved_for_backward = tuple( + unvariableify(i) for i in required_fw_args | required_saved_for_bwd | additional_tensors + ) new_fwd_trace = from_trace(fwd_trace) new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy() - new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1])) - new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None) + new_fwd_trace.scopes[0] = new_fwd_trace.bound_symbols assert bwd_trace.bound_symbols[2].sym.id == prims.PrimIDs.UNPACK_SEQUENCE assert bwd_trace.bound_symbols[2].args[0].name == "saved_for_backward" @@ -3230,6 +3251,12 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr p_saved_for_backward = bwd_trace.bound_symbols[2].args[0] p_c0 = bwd_trace.bound_symbols[4].args[0] + p_c1 = bwd_trace.bound_symbols[5].args[0] + + new_c1 = tuple(unvariableify(i) for i in {variableify(p) for p in p_c1.coll} | additional_nontensors) + + new_return_args = (fwd_trace.output[0], (new_saved_for_backward, new_c1)) + new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None) new_bwd_trace = from_trace(bwd_trace) # In cases where C0 name is carried from previous trace it must be removed @@ -3238,6 +3265,7 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr p_saved_for_backward.coll = (new_saved_for_backward, fwd_trace.output[1][1]) p_c0.coll = new_saved_for_backward + p_c1.coll = new_c1 # Here we make sure that the signature of the backward trace is the same as the one we expect. 
# This part of the trace is the unpacking of the tuple passed from the forward trace, @@ -3264,6 +3292,9 @@ def insert_producer_for_proxy(pname): if idx == 4: new_unpack = prims.unpack_sequence.bind(p_c0, len(new_saved_for_backward), output=new_saved_for_backward) new_bwd_trace.bound_symbols.append(new_unpack) + elif idx == 5: + new_unpack = prims.unpack_sequence.bind(p_c1, len(new_c1), output=new_c1) + new_bwd_trace.bound_symbols.append(new_unpack) else: for p in bsym.flat_proxy_args: insert_producer_for_proxy(p.name) From fd5590391d8eaa0afed9cb214a24ee3929582e88 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 19 Dec 2024 13:33:33 +0100 Subject: [PATCH 14/17] del last used in torchcompile --- thunder/executors/torch_compile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index c94ba3be70..e3df7c9440 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -16,6 +16,7 @@ from thunder.executors.passes import ( update_fusion_call_ctx, transform_for_execution, + del_last_used, ) from thunder.executors.utils import Region from thunder.extend import FusionExecutor, register_executor, ImplInfo, fuse_bound_symbols @@ -95,6 +96,7 @@ def make_compiled( region_trace._siginfo.args = [(a.name, None) for a in region_trace.args] torchex_trace = transform_for_execution(region_trace, executors_list=(pytorch_ex,)) + torchex_trace = del_last_used(torchex_trace) trace_callable = torchex_trace.python_callable(include_decorators=False) torch_compile_fullgraph: None | bool = get_compile_option( From a218e1df4fbf53c9b5d3d6e4ca5df95ab5b83d9f Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 20 Dec 2024 13:22:10 +0100 Subject: [PATCH 15/17] updates --- thunder/benchmarks/benchmark_litgpt.py | 17 +-- thunder/core/transforms.py | 155 +++++++++++-------------- 2 files changed, 79 insertions(+), 93 deletions(-) diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 13e6305b85..50421188a8 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -541,19 +541,15 @@ def setup_distributed(self, model): return model def setup_activation_checkpointing(self): - if "thunder" in self.compile and "dynamo" not in self.compile: - # checkpointing is an option to thunder.jit - return - if any(isinstance(mod, CheckpointWrapper) for mod in self.model.modules()): warnings.warn( "FSDP checkpointing is configured, but the model already contains checkpointed layers." " Checkpointing will be ignored." 
) return - check_fn = lambda submodule: isinstance(submodule, Block) apply_activation_checkpointing(self.model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn) + print(self.model) # TODO(crcrpar): Think of apply `torch.compile` or `thunder.jit` per block/module # like https://github.com/pytorch/torchtitan/blob/cfc0f4e/torchtitan/parallelisms/parallelize_llama.py#L275-L284 @@ -890,12 +886,19 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None print(f"##########\n#Graph{gid}-ThunderFn{subgid} last backward trace\n##########") print(thunder.last_backward_traces(thunder_fn)[-1]) else: + from thunder.examine.memory_calculation import get_alloc_memory + from thunder.executors.passes import del_last_used + for i, f_traces in enumerate(fwd_traces, start=1): print(f"##########\n#{i}-th ThunderModule\n##########") - print(f_traces[-1]) + print(f_traces) for i, b_traces in enumerate(bwd_traces, start=1): print(f"##########\n#{i}-th ThunderModule\n##########") - print(b_traces[-1]) + for tr in b_traces: + dltr = del_last_used(tr) + tr_peak_memory, _ = get_alloc_memory(dltr) + print(f"#the following trace uses ~{tr_peak_memory/(2**30):.2f}GB memory") + print(tr) if global_rank in [0, None]: if return_metrics_as_json: diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 32cd946edc..37c4da08f6 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -59,6 +59,7 @@ unzip2, const_as, sequencify, + OrderedSet, ProxyDict, find_producer_symbols, ) @@ -3183,11 +3184,24 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr fwd_trace = replace_uniform(fwd_trace) saved_for_bw = get_saved_for_backward_tensors(fwd_trace) - fwd_trace_args = {variableify(j) for j in fwd_trace.args} - old_saved_for_bwd = {variableify(j) for j in saved_for_bw} + fwd_trace_args = OrderedSet(variableify(j) for j in fwd_trace.args) + old_saved_for_bwd = OrderedSet(variableify(j) for j in saved_for_bw) - producers = utils.producers(fwd_trace) - all_rematerializable = old_saved_for_bwd - fwd_trace_args + all_proxies = fwd_trace_args.copy() + all_recomputable_proxies = OrderedSet() + + proxy_names_to_producers = {} + for bsym in fwd_trace.bound_symbols: + for o in bsym.flat_proxy_outs: + vo = variableify(o) + + if vo in all_proxies: + continue + + proxy_names_to_producers[o.name] = bsym + all_proxies.add(vo) + if ProxyTag.RECOMPUTE_IN_BACKWARD in o.tags and not has_tags(bsym, {prims.OpTags.RANDOM_OP}): + all_recomputable_proxies.add(vo) remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option( "recomputation_policy", @@ -3195,112 +3209,81 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr ) if remat_policy: - rematerializable = remat_policy(all_rematerializable) + rematerializable = remat_policy(old_saved_for_bwd - fwd_trace_args) else: - rematerializable = { - p - for p in all_rematerializable - if thunder.core.proxies.ProxyTag.RECOMPUTE_IN_BACKWARD in thunder.core.proxies.unvariableify(p).tags - } + rematerializable = old_saved_for_bwd & all_recomputable_proxies if not rematerializable: return fwd_trace, bwd_trace - producers = find_producer_symbols( - fwd_trace, - tuple(unvariableify(i) for i in rematerializable), - tuple(fwd_trace.args) + tuple(unvariableify(i) for i in all_rematerializable - rematerializable), - ) - - required_fw_args = fwd_trace_args & old_saved_for_bwd - additional_tensors = set() - additional_nontensors = set() - - 
recomputed_tensors_from_producers = set() - for prod in producers: - if has_tags(prod, {prims.OpTags.RANDOM_OP}): - for prod_out in prod.flat_proxy_outs: - # things that are from inputs? - if isinstance(prod_out, TensorProxy): - additional_tensors.add(variableify(prod_out)) - else: - additional_nontensors.add(variableify(prod_out)) - continue - for prod_arg in prod.flat_args: - prod_arg = variableify(prod_arg) - if prod_arg in fwd_trace_args: - required_fw_args.add(prod_arg) - for prod_out in prod.flat_outs: - recomputed_tensors_from_producers.add(variableify(prod_out)) - - required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers - new_saved_for_backward = tuple( - unvariableify(i) for i in required_fw_args | required_saved_for_bwd | additional_tensors - ) - - new_fwd_trace = from_trace(fwd_trace) - new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy() - new_fwd_trace.scopes[0] = new_fwd_trace.bound_symbols - + # Here we make sure that the signature of the backward trace is the same as the one we expect. + # This part of the trace is the unpacking of the tuple passed from the forward trace, + # more specifically, C0 unpacks into the saved for backward tensors and C1 into the nontensors + # the cotangents used to compute the vector-Jacobian product are a separate argument. assert bwd_trace.bound_symbols[2].sym.id == prims.PrimIDs.UNPACK_SEQUENCE assert bwd_trace.bound_symbols[2].args[0].name == "saved_for_backward" assert bwd_trace.bound_symbols[4].sym.id == prims.PrimIDs.UNPACK_SEQUENCE assert bwd_trace.bound_symbols[4].args[0].name == "C0" assert bwd_trace.bound_symbols[5].sym.id == prims.PrimIDs.UNPACK_SEQUENCE assert bwd_trace.bound_symbols[5].args[0].name == "C1" - p_saved_for_backward = bwd_trace.bound_symbols[2].args[0] p_c0 = bwd_trace.bound_symbols[4].args[0] p_c1 = bwd_trace.bound_symbols[5].args[0] - new_c1 = tuple(unvariableify(i) for i in {variableify(p) for p in p_c1.coll} | additional_nontensors) - - new_return_args = (fwd_trace.output[0], (new_saved_for_backward, new_c1)) - new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None) + saved_tensors = fwd_trace_args & old_saved_for_bwd + saved_nontensors = OrderedSet(variableify(p) for p in p_c1.coll) new_bwd_trace = from_trace(bwd_trace) - # In cases where C0 name is carried from previous trace it must be removed - # as the proxy needs to register with that specific name to follow the backward - # trace standard signature. - - p_saved_for_backward.coll = (new_saved_for_backward, fwd_trace.output[1][1]) - p_c0.coll = new_saved_for_backward - p_c1.coll = new_c1 - # Here we make sure that the signature of the backward trace is the same as the one we expect. - # This part of the trace is the unpacking of the tuple passed from the forward trace, - # more specifically, C0 unpacks into the saved for backward tensors and C1 into the cotangents - # used to compute the vector-Jacobian product. 
- - proxy_names_to_producers = {} - for bsym in producers: - input_names = {p.name for p in bsym.flat_proxy_args} - for p in bsym.flat_proxy_outs: - if p.name not in input_names and variableify(p) in recomputed_tensors_from_producers: - proxy_names_to_producers[p.name] = bsym - - def insert_producer_for_proxy(pname): - producer_bsym = proxy_names_to_producers.get(pname) - if producer_bsym is not None: - for p in producer_bsym.flat_proxy_args: - insert_producer_for_proxy(p.name) - for o in producer_bsym.flat_proxy_outs: - proxy_names_to_producers.pop(o.name, None) - new_bwd_trace.bound_symbols.append(producer_bsym) + # args will be added from unpack_trivial + have_in_backward = saved_tensors | saved_nontensors + + def compute_proxy_from_producer(p): + vp = variableify(p) + if vp in have_in_backward: + return + if vp not in all_recomputable_proxies: + have_in_backward.add(vp) + if isinstance(p, TensorProxy): + saved_tensors.add(vp) + else: + saved_nontensors.add(vp) + return + producer_bsym = proxy_names_to_producers[p.name] + for p in producer_bsym.flat_proxy_args: + compute_proxy_from_producer(p) + for o in producer_bsym.flat_proxy_outs: + have_in_backward.add(variableify(o)) + new_bwd_trace.bound_symbols.append(producer_bsym) for idx, bsym in enumerate(bwd_trace.bound_symbols): - if idx == 4: - new_unpack = prims.unpack_sequence.bind(p_c0, len(new_saved_for_backward), output=new_saved_for_backward) - new_bwd_trace.bound_symbols.append(new_unpack) - elif idx == 5: - new_unpack = prims.unpack_sequence.bind(p_c1, len(new_c1), output=new_c1) - new_bwd_trace.bound_symbols.append(new_unpack) + if idx in {4, 5}: + # handled later + new_bwd_trace.bound_symbols.append(bsym) else: for p in bsym.flat_proxy_args: - insert_producer_for_proxy(p.name) + compute_proxy_from_producer(p) + for o in bsym.flat_proxy_outs: + have_in_backward.add(variableify(o)) new_bwd_trace.bound_symbols.append(bsym) - new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]] + new_fwd_trace = from_trace(fwd_trace) + new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy() + + # TODO: fix ordering... 
+ new_c0 = tuple(unvariableify(i) for i in saved_tensors) + new_c1 = tuple(unvariableify(i) for i in saved_nontensors) + + new_return_args = (fwd_trace.output[0], (new_c0, new_c1)) + new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None) + + p_saved_for_backward.coll = (new_c0, new_c1) + p_c0.coll = new_c0 + p_c1.coll = new_c1 + + new_bwd_trace.bound_symbols[4] = prims.unpack_sequence.bind(p_c0, len(new_c0), output=new_c0) + new_bwd_trace.bound_symbols[5] = prims.unpack_sequence.bind(p_c1, len(new_c1), output=new_c1) + new_bwd_trace.args = [(new_c0, new_c1), *bwd_trace.args[1:]] elapsed_time_ns = time.perf_counter_ns() - start_time_ns new_bwd_trace.set_provenance( From 4ed48f60ad9bdf88957f99fe6516aa47298b7632 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 20 Dec 2024 14:41:46 +0100 Subject: [PATCH 16/17] memory annotation --- thunder/examine/memory_calculation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/thunder/examine/memory_calculation.py b/thunder/examine/memory_calculation.py index 14c2123ed2..c4255773c6 100644 --- a/thunder/examine/memory_calculation.py +++ b/thunder/examine/memory_calculation.py @@ -147,7 +147,7 @@ def clear_mutable_collection_argument_memory( return memory_size -def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]: +def get_alloc_memory(trc: TraceCtx, *, annotate=False) -> tuple[int, dict[str, int]]: """ Calculate the memory usage based on the executable trace. The memory calculation is based only on the compile-time trace, i.e. the input and output shape @@ -189,6 +189,10 @@ def get_alloc_memory(trc: TraceCtx) -> tuple[int, dict[str, int]]: impl = partial(impl, is_argument=is_argument) allocated += impl(bsym, tensor_to_memory_data, name_to_alloc_memory) + if annotate: + if bsym.header: + bsym.header += " " + bsym.header += f"mem after next op: ~{allocated/(2**30):2f}GB" max_allocated = max(max_allocated, allocated) return max_allocated, name_to_alloc_memory From 53a3277428efd9fbc89ae4902bdd4d7d46660ce7 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 7 Jan 2025 11:08:45 +0100 Subject: [PATCH 17/17] adjust tests details --- thunder/tests/test_grad.py | 5 +++-- thunder/tests/test_networks.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index ff7fcc600d..0d0a3da6da 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1938,6 +1938,7 @@ def func(x): torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual_gr, expected_gr) + @pytest.mark.parametrize("device", ("cuda", "cpu")) def test_backward_recomputation_decomposed_ops(device): if device == "cuda" and not torch.cuda.is_available(): @@ -1951,8 +1952,8 @@ def fn(a): a = torch.randn(2, 2, device=device, requires_grad=True) res = jfn(a) res2 = jfn2(a) - assert len(res.grad_fn.saved_tensors) == 3 # should be decomposed - assert len(res2.grad_fn.saved_tensors) == 1 + assert len(res.grad_fn.next_functions[0][0].saved_tensors) == 3 # should be decomposed + assert len(res2.grad_fn.next_functions[0][0].saved_tensors) == 1 if NVFUSER_AVAILABLE and device == "cuda": # check everything is fused diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 372d2cfe2e..fab3e500e2 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -553,4 +553,4 @@ def test_hf_llama(): top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols} 
# changes this to fewer as needed, the goal is to not have too many fusions - assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7 + assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 8
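
A minimal usage sketch of the switches this series introduces, assuming thunder.jit accepts them as keyword compile options exactly as the tests above pass them (enable_saved_for_backward_recomputation, use_rematerialization); the module structure and shapes are illustrative only, mirroring the checkpointing test added in patch 03:

    import torch
    import torch.utils.checkpoint
    import thunder

    class CheckpointedBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = torch.nn.Sequential(
                torch.nn.Linear(16, 2048),
                torch.nn.Linear(2048, 16),
                torch.nn.ReLU(),
            )

        def forward(self, x):
            # use_reentrant=True is rejected by the new lookaside (patch 01); context_fn,
            # debug, determinism_check and preserve_rng_state are ignored with a warning
            return torch.utils.checkpoint.checkpoint(self.inner, x, use_reentrant=False)

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 16),
        *[CheckpointedBlock() for _ in range(10)],
        torch.nn.Linear(16, 1024),
    )

    jm = thunder.jit(
        model,
        # defaults to True after patch 03: tensors tagged RECOMPUTE_IN_BACKWARD are
        # recomputed in the backward trace instead of being saved from the forward
        enable_saved_for_backward_recomputation=True,
        # the fusion-time rematerialization pass is opt-in after patch 04
        use_rematerialization=False,
    )

    out = jm(torch.randn(512, 1024, requires_grad=True))
    out.sum().backward()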